From a7a6653d8bface2305048debf33d7b7fe2470429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Tue, 27 Aug 2024 16:41:20 +0200 Subject: [PATCH 1/3] Improve CLIP block --- inference/core/workflows/core_steps/loader.py | 4 + .../models/foundation/clip_comparison/v2.py | 292 ++++++++++++++++++ .../execution/test_workflow_with_clip.py | 139 +++++++++ 3 files changed, 435 insertions(+) create mode 100644 inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py index ad9d3c7bb9..8d63b5070e 100644 --- a/inference/core/workflows/core_steps/loader.py +++ b/inference/core/workflows/core_steps/loader.py @@ -55,6 +55,9 @@ from inference.core.workflows.core_steps.models.foundation.clip_comparison.v1 import ( ClipComparisonBlockV1, ) +from inference.core.workflows.core_steps.models.foundation.clip_comparison.v2 import ( + ClipComparisonBlockV2, +) from inference.core.workflows.core_steps.models.foundation.cog_vlm.v1 import ( CogVLMBlockV1, ) @@ -283,6 +286,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]: ConvertGrayscaleBlockV1, ImageThresholdBlockV1, ImageContoursDetectionBlockV1, + ClipComparisonBlockV2, ] diff --git a/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py b/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py new file mode 100644 index 0000000000..ce7f69c009 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py @@ -0,0 +1,292 @@ +from functools import partial +from typing import List, Literal, Optional, Type, Union + +import numpy as np +from pydantic import ConfigDict, Field + +from inference.core.entities.requests.clip import ClipCompareRequest +from inference.core.env import ( + HOSTED_CORE_MODEL_URL, + LOCAL_INFERENCE_API_URL, + WORKFLOWS_REMOTE_API_TARGET, + WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, +) +from inference.core.managers.base import ModelManager +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.core_steps.common.utils import ( + load_core_model, + run_in_parallel, +) +from inference.core.workflows.execution_engine.constants import ( + PARENT_ID_KEY, + ROOT_PARENT_ID_KEY, +) +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + OutputDefinition, + WorkflowImageData, +) +from inference.core.workflows.execution_engine.entities.types import ( + BATCH_OF_CLASSIFICATION_PREDICTION_KIND, + BATCH_OF_PARENT_ID_KIND, + FLOAT_ZERO_TO_ONE_KIND, + LIST_OF_VALUES_KIND, + STRING_KIND, + ImageInputField, + StepOutputImageSelector, + WorkflowImageSelector, + WorkflowParameterSelector, +) +from inference.core.workflows.prototypes.block import ( + BlockResult, + WorkflowBlock, + WorkflowBlockManifest, +) +from inference_sdk import InferenceHTTPClient + +LONG_DESCRIPTION = """ +Use the OpenAI CLIP zero-shot classification model to classify images. + +This block accepts an image and a list of text prompts. The block then returns the +similarity of each text label to the provided image. + +This block is useful for classifying images without having to train a fine-tuned +classification model. For example, you could use CLIP to classify the type of vehicle +in an image, or if an image contains NSFW material. +""" + +EXPECTED_OUTPUT_KEYS = {"similarity", "parent_id", "root_parent_id", "prediction_type"} + +ALL_CLIP_VARIANTS = { + "RN101", + "RN50", + "RN50x16", + "RN50x4", + "RN50x64", + "ViT-B-16", + "ViT-B-32", + "ViT-L-14-336px", + "ViT-L-14", +} + + +class BlockManifest(WorkflowBlockManifest): + model_config = ConfigDict( + json_schema_extra={ + "name": "Clip Comparison", + "version": "v2", + "short_description": "Compare CLIP image and text embeddings.", + "long_description": LONG_DESCRIPTION, + "license": "Apache-2.0", + "block_type": "model", + } + ) + type: Literal["roboflow_core/clip_comparison@v2"] + name: str = Field(description="Unique name of step in workflows") + images: Union[WorkflowImageSelector, StepOutputImageSelector] = ImageInputField + classes: Union[WorkflowParameterSelector(kind=[LIST_OF_VALUES_KIND]), List[str]] = ( + Field( + description="List of classes to calculate similarity against each input image", + examples=[["a", "b", "c"], "$inputs.texts"], + ) + ) + version: Union[ + Literal[ + "RN101", + "RN50", + "RN50x16", + "RN50x4", + "RN50x64", + "ViT-B-16", + "ViT-B-32", + "ViT-L-14-336px", + "ViT-L-14", + ], + WorkflowParameterSelector(kind=[STRING_KIND]), + ] = Field( + default="ViT-B-16", + description="Variant of YoloWorld model", + examples=["ViT-B-16", "$inputs.variant"], + ) + + @classmethod + def accepts_batch_input(cls) -> bool: + return True + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [ + OutputDefinition(name="similarities", kind=[LIST_OF_VALUES_KIND]), + OutputDefinition(name="max_similarity", kind=[FLOAT_ZERO_TO_ONE_KIND]), + OutputDefinition(name="most_similar_class", kind=[STRING_KIND]), + OutputDefinition(name="min_similarity", kind=[FLOAT_ZERO_TO_ONE_KIND]), + OutputDefinition(name="least_similar_class", kind=[STRING_KIND]), + OutputDefinition( + name="classification_predictions", + kind=[BATCH_OF_CLASSIFICATION_PREDICTION_KIND], + ), + OutputDefinition(name="inference_id", kind=[STRING_KIND]), + OutputDefinition(name="parent_id", kind=[BATCH_OF_PARENT_ID_KIND]), + OutputDefinition(name="root_parent_id", kind=[BATCH_OF_PARENT_ID_KIND]), + ] + + @classmethod + def get_execution_engine_compatibility(cls) -> Optional[str]: + return ">=1.0.0,<2.0.0" + + +class ClipComparisonBlockV2(WorkflowBlock): + + def __init__( + self, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + ): + self._model_manager = model_manager + self._api_key = api_key + self._step_execution_mode = step_execution_mode + + @classmethod + def get_init_parameters(cls) -> List[str]: + return ["model_manager", "api_key", "step_execution_mode"] + + @classmethod + def get_manifest(cls) -> Type[WorkflowBlockManifest]: + return BlockManifest + + def run( + self, + images: Batch[WorkflowImageData], + classes: List[str], + version: str, + ) -> BlockResult: + if version not in ALL_CLIP_VARIANTS: + raise ValueError(f"Supported CLIP versions do not involve {version}") + if not classes: + raise ValueError("Provided empty class list for CLIP Comparison step") + if self._step_execution_mode is StepExecutionMode.LOCAL: + return self.run_locally(images=images, classes=classes, version=version) + elif self._step_execution_mode is StepExecutionMode.REMOTE: + return self.run_remotely(images=images, classes=classes, version=version) + else: + raise ValueError( + f"Unknown step execution mode: {self._step_execution_mode}" + ) + + def run_locally( + self, + images: Batch[WorkflowImageData], + classes: List[str], + version: str, + ) -> BlockResult: + predictions = [] + for single_image in images: + inference_request = ClipCompareRequest( + clip_version_id=version, + subject=single_image.to_inference_format(numpy_preferred=True), + subject_type="image", + prompt=classes, + prompt_type="text", + api_key=self._api_key, + ) + clip_model_id = load_core_model( + model_manager=self._model_manager, + inference_request=inference_request, + core_model="clip", + ) + prediction = self._model_manager.infer_from_request_sync( + clip_model_id, inference_request + ) + predictions.append(prediction.model_dump()) + return self._post_process_result( + images=images, + predictions=predictions, + classes=classes, + ) + + def run_remotely( + self, + images: Batch[WorkflowImageData], + classes: List[str], + version: str, + ) -> BlockResult: + api_url = ( + LOCAL_INFERENCE_API_URL + if WORKFLOWS_REMOTE_API_TARGET != "hosted" + else HOSTED_CORE_MODEL_URL + ) + client = InferenceHTTPClient( + api_url=api_url, + api_key=self._api_key, + ) + if WORKFLOWS_REMOTE_API_TARGET == "hosted": + client.select_api_v0() + tasks = [ + partial( + client.clip_compare, + subject=single_image.numpy_image, + prompt=classes, + clip_version=version, + ) + for single_image in images + ] + predictions = run_in_parallel( + tasks=tasks, + max_workers=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + return self._post_process_result( + images=images, + predictions=predictions, + classes=classes, + ) + + def _post_process_result( + self, + images: Batch[WorkflowImageData], + predictions: List[dict], + classes: List[str], + ) -> List[dict]: + results = [] + for prediction, image in zip(predictions, images): + similarities = prediction["similarity"] + inference_id = prediction.get("inference_id") + max_similarity = np.max(similarities) + max_similarity_id = np.argmax(similarities) + min_similarity = np.min(similarities) + min_similarity_id = np.argmin(similarities) + most_similar_class_name = classes[max_similarity_id] + least_similar_class_name = classes[min_similarity_id] + prediction[PARENT_ID_KEY] = image.parent_metadata.parent_id + prediction[ROOT_PARENT_ID_KEY] = ( + image.workflow_root_ancestor_metadata.parent_id + ) + classification_predictions = { + "predictions": [ + { + "class": class_name, + "class_id": class_id, + "confidence": confidence, + } + for class_id, (class_name, confidence) in enumerate( + zip(classes, similarities) + ) + ], + "top": most_similar_class_name, + "confidence": max_similarity, + "parent_id": image.parent_metadata.parent_id, + "inference_id": inference_id, + } + result = { + PARENT_ID_KEY: image.parent_metadata.parent_id, + ROOT_PARENT_ID_KEY: image.workflow_root_ancestor_metadata.parent_id, + "inference_id": inference_id, + "similarities": similarities, + "max_similarity": max_similarity, + "most_similar_class": most_similar_class_name, + "min_similarity": min_similarity, + "least_similar_class": least_similar_class_name, + "classification_predictions": classification_predictions, + } + results.append(result) + return results diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_clip.py b/tests/workflows/integration_tests/execution/test_workflow_with_clip.py index 0b8a392b47..5a32fbfd4e 100644 --- a/tests/workflows/integration_tests/execution/test_workflow_with_clip.py +++ b/tests/workflows/integration_tests/execution/test_workflow_with_clip.py @@ -76,3 +76,142 @@ def test_clip_workflow_when_minimal_valid_input_provided( assert ( result[1]["similarity"][0] < result[1]["similarity"][1] ), "Expected to predict `crowd` class for second image" + + +WORKFLOW_WITH_CLIP_COMPARISON_V2 = { + "version": "1.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + {"type": "WorkflowParameter", "name": "reference"}, + {"type": "WorkflowParameter", "name": "version", "default_value": "ViT-B-16"}, + ], + "steps": [ + { + "type": "roboflow_core/clip_comparison@v2", + "name": "comparison", + "images": "$inputs.image", + "classes": "$inputs.reference", + "version": "$inputs.version", + }, + { + "type": "PropertyDefinition", + "name": "property_extraction", + "data": "$steps.comparison.classification_predictions", + "operations": [ + {"type": "ClassificationPropertyExtract", "property_name": "top_class"} + ], + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "clip_output", + "selector": "$steps.comparison.*", + }, + { + "type": "JsonField", + "name": "class_name", + "selector": "$steps.property_extraction.output", + }, + ], +} + + +def test_workflow_with_clip_comparison_v2_and_property_definition_with_valid_input( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_CLIP_COMPARISON_V2, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + result = execution_engine.run( + runtime_parameters={ + "image": [license_plate_image, crowd_image], + "reference": ["car", "crowd"], + } + ) + + # then + assert isinstance(result, list), "Expected list to be delivered" + assert len(result) == 2, "Expected 2 elements in the output for two input images" + assert set(result[0].keys()) == { + "clip_output", + "class_name", + }, "Expected all declared outputs to be delivered" + assert set(result[1].keys()) == { + "clip_output", + "class_name", + }, "Expected all declared outputs to be delivered" + assert np.allclose( + result[0]["clip_output"]["similarities"], + [0.23334351181983948, 0.17259158194065094], + atol=1e-4, + ), "Expected predicted similarities to match values verified at test creation" + assert ( + abs( + result[0]["clip_output"]["similarities"][0] + - result[0]["clip_output"]["max_similarity"] + ) + < 1e-5 + ), "Expected max similarity to be correct" + assert ( + abs( + result[0]["clip_output"]["similarities"][1] + - result[0]["clip_output"]["min_similarity"] + ) + < 1e-5 + ), "Expected max similarity to be correct" + assert ( + result[0]["clip_output"]["most_similar_class"] == "car" + ), "Expected most similar class to be extracted properly" + assert ( + result[0]["clip_output"]["least_similar_class"] == "crowd" + ), "Expected least similar class to be extracted properly" + assert ( + result[0]["clip_output"]["classification_predictions"]["top"] == "car" + ), "Expected classifier output to be shaped correctly" + assert ( + result[0]["class_name"] == "car" + ), "Expected property definition step to cooperate nicely with clip output" + assert np.allclose( + result[1]["clip_output"]["similarities"], + [0.18426208198070526, 0.207647442817688], + atol=1e-4, + ), "Expected predicted similarities to match values verified at test creation" + assert ( + abs( + result[1]["clip_output"]["similarities"][1] + - result[1]["clip_output"]["max_similarity"] + ) + < 1e-5 + ), "Expected max similarity to be correct" + assert ( + abs( + result[1]["clip_output"]["similarities"][0] + - result[1]["clip_output"]["min_similarity"] + ) + < 1e-5 + ), "Expected max similarity to be correct" + assert ( + result[1]["clip_output"]["most_similar_class"] == "crowd" + ), "Expected most similar class to be extracted properly" + assert ( + result[1]["clip_output"]["least_similar_class"] == "car" + ), "Expected least similar class to be extracted properly" + assert ( + result[1]["clip_output"]["classification_predictions"]["top"] == "crowd" + ), "Expected classifier output to be shaped correctly" + assert ( + result[1]["class_name"] == "crowd" + ), "Expected property definition step to cooperate nicely with clip output" From 83ba3530432e65592b2c98e12a432e3a30bd4b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Tue, 27 Aug 2024 16:45:41 +0200 Subject: [PATCH 2/3] Improve tests --- .../models/foundation/clip_comparison/v2.py | 17 +----- .../execution/test_workflow_with_clip.py | 56 +++++++++++++++++++ 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py b/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py index ce7f69c009..dad0d6b471 100644 --- a/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py +++ b/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py @@ -57,18 +57,6 @@ EXPECTED_OUTPUT_KEYS = {"similarity", "parent_id", "root_parent_id", "prediction_type"} -ALL_CLIP_VARIANTS = { - "RN101", - "RN50", - "RN50x16", - "RN50x4", - "RN50x64", - "ViT-B-16", - "ViT-B-32", - "ViT-L-14-336px", - "ViT-L-14", -} - class BlockManifest(WorkflowBlockManifest): model_config = ConfigDict( @@ -88,6 +76,7 @@ class BlockManifest(WorkflowBlockManifest): Field( description="List of classes to calculate similarity against each input image", examples=[["a", "b", "c"], "$inputs.texts"], + min_items=1, ) ) version: Union[ @@ -161,10 +150,6 @@ def run( classes: List[str], version: str, ) -> BlockResult: - if version not in ALL_CLIP_VARIANTS: - raise ValueError(f"Supported CLIP versions do not involve {version}") - if not classes: - raise ValueError("Provided empty class list for CLIP Comparison step") if self._step_execution_mode is StepExecutionMode.LOCAL: return self.run_locally(images=images, classes=classes, version=version) elif self._step_execution_mode is StepExecutionMode.REMOTE: diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_clip.py b/tests/workflows/integration_tests/execution/test_workflow_with_clip.py index 5a32fbfd4e..e07116b71c 100644 --- a/tests/workflows/integration_tests/execution/test_workflow_with_clip.py +++ b/tests/workflows/integration_tests/execution/test_workflow_with_clip.py @@ -4,6 +4,7 @@ from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS from inference.core.managers.base import ModelManager from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.errors import RuntimeInputError from inference.core.workflows.execution_engine.core import ExecutionEngine CLIP_WORKFLOW = { @@ -215,3 +216,58 @@ def test_workflow_with_clip_comparison_v2_and_property_definition_with_valid_inp assert ( result[1]["class_name"] == "crowd" ), "Expected property definition step to cooperate nicely with clip output" + + +def test_workflow_with_clip_comparison_v2_and_property_definition_with_empty_class_list( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_CLIP_COMPARISON_V2, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + with pytest.raises(RuntimeInputError): + _ = execution_engine.run( + runtime_parameters={ + "image": [license_plate_image, crowd_image], + "reference": [], + } + ) + + +def test_workflow_with_clip_comparison_v2_and_property_definition_with_invalid_model_version( + model_manager: ModelManager, + license_plate_image: np.ndarray, + crowd_image: np.ndarray, +) -> None: + # given + workflow_init_parameters = { + "workflows_core.model_manager": model_manager, + "workflows_core.api_key": None, + "workflows_core.step_execution_mode": StepExecutionMode.LOCAL, + } + execution_engine = ExecutionEngine.init( + workflow_definition=WORKFLOW_WITH_CLIP_COMPARISON_V2, + init_parameters=workflow_init_parameters, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + ) + + # when + with pytest.raises(RuntimeInputError): + _ = execution_engine.run( + runtime_parameters={ + "image": [license_plate_image, crowd_image], + "reference": ["car", "crowd"], + "version": "invalid", + } + ) From ef07ceb61291f354ca1de06b441ceea0be763700 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?= Date: Wed, 28 Aug 2024 19:49:23 +0200 Subject: [PATCH 3/3] Fix issues spotted while testing --- .../core_steps/models/foundation/clip_comparison/v2.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py b/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py index dad0d6b471..0ba9772d38 100644 --- a/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py +++ b/inference/core/workflows/core_steps/models/foundation/clip_comparison/v2.py @@ -55,8 +55,6 @@ in an image, or if an image contains NSFW material. """ -EXPECTED_OUTPUT_KEYS = {"similarity", "parent_id", "root_parent_id", "prediction_type"} - class BlockManifest(WorkflowBlockManifest): model_config = ConfigDict( @@ -94,7 +92,7 @@ class BlockManifest(WorkflowBlockManifest): WorkflowParameterSelector(kind=[STRING_KIND]), ] = Field( default="ViT-B-16", - description="Variant of YoloWorld model", + description="Variant of CLIP model", examples=["ViT-B-16", "$inputs.variant"], ) @@ -114,7 +112,6 @@ def describe_outputs(cls) -> List[OutputDefinition]: name="classification_predictions", kind=[BATCH_OF_CLASSIFICATION_PREDICTION_KIND], ), - OutputDefinition(name="inference_id", kind=[STRING_KIND]), OutputDefinition(name="parent_id", kind=[BATCH_OF_PARENT_ID_KIND]), OutputDefinition(name="root_parent_id", kind=[BATCH_OF_PARENT_ID_KIND]), ] @@ -235,7 +232,6 @@ def _post_process_result( results = [] for prediction, image in zip(predictions, images): similarities = prediction["similarity"] - inference_id = prediction.get("inference_id") max_similarity = np.max(similarities) max_similarity_id = np.argmax(similarities) min_similarity = np.min(similarities) @@ -260,12 +256,10 @@ def _post_process_result( "top": most_similar_class_name, "confidence": max_similarity, "parent_id": image.parent_metadata.parent_id, - "inference_id": inference_id, } result = { PARENT_ID_KEY: image.parent_metadata.parent_id, ROOT_PARENT_ID_KEY: image.workflow_root_ancestor_metadata.parent_id, - "inference_id": inference_id, "similarities": similarities, "max_similarity": max_similarity, "most_similar_class": most_similar_class_name,