diff --git a/roboflow/cli/handlers/infer.py b/roboflow/cli/handlers/infer.py
index 4e3f4492..9f3ddca1 100644
--- a/roboflow/cli/handlers/infer.py
+++ b/roboflow/cli/handlers/infer.py
@@ -65,6 +65,7 @@ def _infer(args):  # noqa: ANN001
     from roboflow.models.keypoint_detection import KeypointDetectionModel
     from roboflow.models.object_detection import ObjectDetectionModel
     from roboflow.models.semantic_segmentation import SemanticSegmentationModel
+    from roboflow.models.vlm import VLMModel
 
     model_class_map = {
         "object-detection": ObjectDetectionModel,
@@ -72,6 +73,7 @@ def _infer(args):  # noqa: ANN001
         "instance-segmentation": InstanceSegmentationModel,
         "semantic-segmentation": SemanticSegmentationModel,
         "keypoint-detection": KeypointDetectionModel,
+        "text-image-pairs": VLMModel,
     }
 
     model_cls = model_class_map.get(project_type)
@@ -97,15 +99,25 @@ def _infer(args):  # noqa: ANN001
         kwargs["overlap"] = int(args.overlap * 100)
 
     try:
-        group = model.predict(args.file, **kwargs)
+        result = model.predict(args.file, **kwargs)
     except Exception as exc:
         output_error(args, f"Inference failed: {exc}")
         return
 
+    # VLM models return raw dict response; pass through as-is.
+    if isinstance(result, dict):
+        if getattr(args, "json", False):
+            output(args, result)
+        else:
+            import json as _json
+
+            output(args, None, text=_json.dumps(result, indent=2))
+        return
+
     # Serialize predictions for JSON output
     if getattr(args, "json", False):
         predictions = []
-        for pred in group:
+        for pred in result:
             if hasattr(pred, "json"):
                 predictions.append(pred.json())
             elif hasattr(pred, "__dict__"):
@@ -114,4 +126,4 @@ def _infer(args):  # noqa: ANN001
             predictions.append(str(pred))
         output(args, predictions)
     else:
-        output(args, None, text=str(group))
+        output(args, None, text=str(result))
diff --git a/roboflow/config.py b/roboflow/config.py
index bc3eaf03..1b3d5a1c 100644
--- a/roboflow/config.py
+++ b/roboflow/config.py
@@ -72,6 +72,7 @@ def get_conditional_configuration_variable(key, default):
 TYPE_INSTANCE_SEGMENTATION = "instance-segmentation"
 TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation"
 TYPE_KEYPOINT_DETECTION = "keypoint-detection"
+TYPE_TEXT_IMAGE_PAIRS = "text-image-pairs"
 
 TASK_DET = "det"
 TASK_SEG = "seg"
diff --git a/roboflow/core/version.py b/roboflow/core/version.py
index 229f2ad9..92817bff 100644
--- a/roboflow/core/version.py
+++ b/roboflow/core/version.py
@@ -22,6 +22,7 @@
     TYPE_KEYPOINT_DETECTION,
     TYPE_OBJECT_DETECTION,
     TYPE_SEMANTIC_SEGMENTATION,
+    TYPE_TEXT_IMAGE_PAIRS,
     UNIVERSE_URL,
 )
 from roboflow.core.dataset import Dataset
@@ -30,6 +31,7 @@
 from roboflow.models.keypoint_detection import KeypointDetectionModel
 from roboflow.models.object_detection import ObjectDetectionModel
 from roboflow.models.semantic_segmentation import SemanticSegmentationModel
+from roboflow.models.vlm import VLMModel
 from roboflow.util.annotations import amend_data_yaml
 from roboflow.util.general import extract_zip, write_line
 from roboflow.util.model_processor import process, validate_model_type_for_project
@@ -133,6 +135,16 @@ def __init__(
             self.model = SemanticSegmentationModel(self.__api_key, self.id)
         elif self.type == TYPE_KEYPOINT_DETECTION:
             self.model = KeypointDetectionModel(self.__api_key, self.id, version=version_without_workspace)
+        elif self.type == TYPE_TEXT_IMAGE_PAIRS:
+            self.model = VLMModel(
+                self.__api_key,
+                self.id,
+                self.name,
+                version_without_workspace,
+                local=local,
+                colors=self.colors,
+                preprocessing=self.preprocessing,
+            )
         else:
             self.model = None
 
diff --git a/roboflow/models/vlm.py b/roboflow/models/vlm.py
new file mode 100644
index 00000000..b24c3ceb
--- /dev/null
+++ b/roboflow/models/vlm.py
@@ -0,0 +1,102 @@
+"""Vision-language (text-image-pairs) hosted inference.
+
+Wraps the serverless endpoint for VLM-style projects (e.g. PaliGemma).
+Unlike detection/classification models, the response shape is free-form:
+captions, VQA answers, OCR text, or tokenized detections depending on the
+underlying model. `predict` returns the raw serverless JSON unchanged so
+callers can interpret the payload for their specific model.
+"""
+
+from __future__ import annotations
+
+import base64
+import io
+import os
+import urllib.parse
+from typing import Any, Optional
+
+import requests
+from PIL import Image
+
+from roboflow.models.inference import InferenceModel
+from roboflow.util.image_utils import check_image_url
+
+# Seconds before a hosted-inference HTTP call is abandoned (avoids hanging the CLI).
+_REQUEST_TIMEOUT = 60
+
+
+class VLMModel(InferenceModel):
+    """Run inference on a hosted text-image-pairs (VLM) model."""
+
+    def __init__(
+        self,
+        api_key: str,
+        id: str,
+        name: Optional[str] = None,
+        version: Optional[str] = None,
+        local: Optional[str] = None,
+        colors: Optional[dict] = None,
+        preprocessing: Optional[dict] = None,
+    ) -> None:
+        super().__init__(api_key, id, version=version)
+        self.__api_key = api_key
+        self.id = id
+        self.name = name
+        self.version = version
+        self.base_url = local if local else "https://serverless.roboflow.com/"
+        self.colors = {} if colors is None else colors
+        self.preprocessing = {} if preprocessing is None else preprocessing
+
+    def _endpoint(self) -> str:
+        """Build the serverless URL for this model (project/version)."""
+        parts = self.id.split("/")
+        without_workspace = parts[1]
+        version = self.version
+        if not version and len(parts) > 2:
+            version = parts[2]
+        if version is None:
+            raise Exception(f"Cannot determine model version from id '{self.id}'")
+        base = self.base_url if self.base_url.endswith("/") else self.base_url + "/"
+        return f"{base}{without_workspace}/{version}"
+
+    def predict(self, image_path: str, **kwargs: Any) -> dict:  # type: ignore[override]
+        """Run inference and return the raw serverless response.
+
+        Args:
+            image_path: local path or http(s) URL to an image.
+            **kwargs: extra query params forwarded to the endpoint.
+
+        Returns:
+            The raw JSON response as a dict. Shape depends on the underlying
+            VLM (e.g. `{"response": {">": "..."}}` for PaliGemma).
+        """
+        is_url = urllib.parse.urlparse(image_path).scheme in ("http", "https")
+
+        params: dict[str, Any] = {"api_key": self.__api_key}
+        params.update(kwargs)
+
+        if is_url:
+            if not check_image_url(image_path):
+                raise Exception(f"Image URL is not reachable: {image_path}")
+            params["image"] = image_path
+            url = f"{self._endpoint()}?{urllib.parse.urlencode(params)}"
+            resp = requests.get(url, timeout=_REQUEST_TIMEOUT)
+        else:
+            if not os.path.exists(image_path):
+                raise Exception(f"Image does not exist at {image_path}!")
+            image = Image.open(image_path).convert("RGB")
+            buffered = io.BytesIO()
+            image.save(buffered, quality=90, format="JPEG")
+            img_b64 = base64.b64encode(buffered.getvalue()).decode("ascii")
+            url = f"{self._endpoint()}?{urllib.parse.urlencode(params)}"
+            resp = requests.post(
+                url,
+                data=img_b64,
+                headers={"Content-Type": "application/x-www-form-urlencoded"},
+                timeout=_REQUEST_TIMEOUT,
+            )
+
+        if resp.status_code != 200:
+            raise Exception(resp.text)
+
+        return resp.json()
diff --git a/tests/cli/test_infer_handler.py b/tests/cli/test_infer_handler.py
index d3023c2b..52a53f43 100644
--- a/tests/cli/test_infer_handler.py
+++ b/tests/cli/test_infer_handler.py
@@ -149,5 +149,63 @@ def test_infer_confidence_converted_to_percentage(self, mock_model_cls: MagicMoc
         mock_model.predict.assert_called_once_with("test.jpg", confidence=70, overlap=30)
 
 
+class TestInferVLM(unittest.TestCase):
+    """VLM (text-image-pairs) path returns raw dict passthrough."""
+
+    def _make_args(self, **kwargs: object) -> types.SimpleNamespace:
+        defaults = {
+            "json": False,
+            "api_key": "test-key",
+            "workspace": "test-ws",
+            "model": "test-project/1",
+            "file": "https://example.com/img.jpg",
+            "confidence": 0.5,
+            "overlap": 0.5,
+            "type": "text-image-pairs",
+        }
+        defaults.update(kwargs)
+        return types.SimpleNamespace(**defaults)
+
+    @patch("roboflow.models.vlm.VLMModel")
+    def test_infer_vlm_json_passthrough(self, mock_model_cls: MagicMock) -> None:
+        from roboflow.cli.handlers.infer import _infer
+
+        raw = {"inference_id": "abc", "response": {">": "caption text"}}
+        mock_model = MagicMock()
+        mock_model.predict.return_value = raw
+        mock_model_cls.return_value = mock_model
+
+        args = self._make_args(json=True)
+        buf = io.StringIO()
+        old_stdout = sys.stdout
+        sys.stdout = buf
+        try:
+            _infer(args)
+        finally:
+            sys.stdout = old_stdout
+
+        result = json.loads(buf.getvalue())
+        self.assertEqual(result, raw)
+
+    @patch("roboflow.models.vlm.VLMModel")
+    def test_infer_vlm_skips_confidence_overlap(self, mock_model_cls: MagicMock) -> None:
+        from roboflow.cli.handlers.infer import _infer
+
+        mock_model = MagicMock()
+        mock_model.predict.return_value = {"ok": True}
+        mock_model_cls.return_value = mock_model
+
+        args = self._make_args(confidence=0.7, overlap=0.3)
+        buf = io.StringIO()
+        old_stdout = sys.stdout
+        sys.stdout = buf
+        try:
+            _infer(args)
+        finally:
+            sys.stdout = old_stdout
+
+        mock_model.predict.assert_called_once_with("https://example.com/img.jpg")
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/models/test_vlm.py b/tests/models/test_vlm.py
new file mode 100644
index 00000000..56c48dc2
--- /dev/null
+++ b/tests/models/test_vlm.py
@@ -0,0 +1,82 @@
+"""Unit tests for roboflow.models.vlm.VLMModel."""
+
+from __future__ import annotations
+
+import unittest
+from unittest.mock import MagicMock, patch
+
+from roboflow.models.vlm import VLMModel
+
+
+class TestVLMModel(unittest.TestCase):
+    def _make(self) -> VLMModel:
+        return VLMModel(api_key="k", id="ws/proj/3", name="proj", version="3")
+
+    @patch("roboflow.models.vlm.check_image_url", return_value=True)
+    @patch("roboflow.models.vlm.requests.get")
+    def test_predict_url_returns_raw_dict(self, mock_get: MagicMock, _chk: MagicMock) -> None:
+        mock_get.return_value = MagicMock(
+            status_code=200,
+            json=lambda: {"response": {">": "box"}},
+        )
+        model = self._make()
+        result = model.predict("https://example.com/img.jpg")
+
+        self.assertEqual(result, {"response": {">": "box"}})
+        called_url = mock_get.call_args[0][0]
+        self.assertIn("https://serverless.roboflow.com/proj/3", called_url)
+        self.assertIn("api_key=k", called_url)
+        self.assertIn("image=", called_url)
+
+    @patch("roboflow.models.vlm.check_image_url", return_value=True)
+    @patch("roboflow.models.vlm.requests.get")
+    def test_predict_forwards_extra_kwargs_as_query(self, mock_get: MagicMock, _chk: MagicMock) -> None:
+        mock_get.return_value = MagicMock(status_code=200, json=lambda: {"ok": True})
+        self._make().predict("https://example.com/img.jpg", prompt="caption")
+
+        called_url = mock_get.call_args[0][0]
+        self.assertIn("prompt=caption", called_url)
+
+    @patch("roboflow.models.vlm.check_image_url", return_value=True)
+    @patch("roboflow.models.vlm.requests.get")
+    def test_predict_non_200_raises(self, mock_get: MagicMock, _chk: MagicMock) -> None:
+        mock_get.return_value = MagicMock(status_code=401, text="unauthorized")
+        with self.assertRaises(Exception) as ctx:
+            self._make().predict("https://example.com/img.jpg")
+        self.assertIn("unauthorized", str(ctx.exception))
+
+    @patch("roboflow.models.vlm.os.path.exists", return_value=True)
+    @patch("roboflow.models.vlm.Image.open")
+    @patch("roboflow.models.vlm.requests.post")
+    def test_predict_local_path_posts_base64(
+        self, mock_post: MagicMock, mock_open: MagicMock, _exists: MagicMock
+    ) -> None:
+        mock_img = MagicMock()
+        mock_img.convert.return_value = mock_img
+
+        def _save(buf: object, **_kw: object) -> None:
+            buf.write(b"fakejpeg")  # type: ignore[attr-defined]
+
+        mock_img.save.side_effect = _save
+        mock_open.return_value = mock_img
+        mock_post.return_value = MagicMock(status_code=200, json=lambda: {"ok": True})
+
+        result = self._make().predict("/tmp/x.jpg")
+        self.assertEqual(result, {"ok": True})
+        _, kwargs = mock_post.call_args
+        self.assertEqual(kwargs["headers"], {"Content-Type": "application/x-www-form-urlencoded"})
+        self.assertIsInstance(kwargs["data"], str)
+
+    def test_predict_missing_local_file_raises(self) -> None:
+        with self.assertRaises(Exception) as ctx:
+            self._make().predict("/definitely/not/a/real/path.jpg")
+        self.assertIn("does not exist", str(ctx.exception))
+
+    def test_endpoint_uses_id_parts_when_version_unset(self) -> None:
+        model = VLMModel(api_key="k", id="ws/proj/7")
+        model.version = None
+        self.assertEqual(model._endpoint(), "https://serverless.roboflow.com/proj/7")
+
+
+if __name__ == "__main__":
+    unittest.main()