From c69334f4e411bc1d18e0a4a3f8115c3397921fa3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?=
Date: Tue, 7 Jan 2025 13:26:15 +0100
Subject: [PATCH] Make linter happy and fix assertions mentioned in
 https://github.com/roboflow/inference/issues/904

---
 .../models/foundation/llama_vision/v1.py      | 32 +++++++++----------
 .../v2/test_workflow_for_classification.py    |  4 +--
 .../sampling/test_identify_outliers.py        |  4 ++-
 3 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/inference/core/workflows/core_steps/models/foundation/llama_vision/v1.py b/inference/core/workflows/core_steps/models/foundation/llama_vision/v1.py
index 7c36298cd..a8e289a1a 100644
--- a/inference/core/workflows/core_steps/models/foundation/llama_vision/v1.py
+++ b/inference/core/workflows/core_steps/models/foundation/llama_vision/v1.py
@@ -84,6 +84,7 @@
     "structured-answering",
 }
 
+
 class BlockManifest(WorkflowBlockManifest):
     model_config = ConfigDict(
         json_schema_extra={
@@ -93,7 +94,7 @@ class BlockManifest(WorkflowBlockManifest):
             "long_description": LONG_DESCRIPTION,
             "license": "Apache-2.0",
             "block_type": "model",
-            "search_keywords": ["LMM", "VLM", "Llama","Vision","Meta"],
+            "search_keywords": ["LMM", "VLM", "Llama", "Vision", "Meta"],
             "is_vlm_block": True,
             "task_type_property": "task_type",
         },
@@ -172,9 +173,7 @@ class BlockManifest(WorkflowBlockManifest):
         examples=["xxx-xxx", "$inputs.llama_api_key"],
         private=True,
     )
-    model_version: Union[
-        Selector(kind=[STRING_KIND]), Literal["vision-11B"]
-    ] = Field(
+    model_version: Union[Selector(kind=[STRING_KIND]), Literal["vision-11B"]] = Field(
         default="Llama-Vision-11B",
         description="Model to be used",
         examples=["Llama-Vision-11B", "$inputs.llama_model"],
@@ -274,7 +273,7 @@ def run(
         model_version: str,
         max_tokens: int,
         temperature: float,
-        top_p : Optional[float],
+        top_p: Optional[float],
         max_concurrent_requests: Optional[int],
     ) -> BlockResult:
         inference_images = [i.to_inference_format() for i in images]
@@ -288,7 +287,7 @@ def run(
             llama_model_version=model_version,
             max_tokens=max_tokens,
             temperature=temperature,
-            top_p = top_p,
+            top_p=top_p,
             max_concurrent_requests=max_concurrent_requests,
         )
         return [
@@ -306,7 +305,7 @@ def run_llama_vision_32_llm_prompting(
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p : Optional[float],
+    top_p: Optional[float],
     max_concurrent_requests: Optional[int],
 ) -> List[str]:
     if task_type not in PROMPT_BUILDERS:
@@ -330,7 +329,7 @@ def run_llama_vision_32_llm_prompting(
         llama_model_version=llama_model_version,
         max_tokens=max_tokens,
         temperature=temperature,
-        top_p = top_p,
+        top_p=top_p,
         max_concurrent_requests=max_concurrent_requests,
     )
 
@@ -341,7 +340,7 @@ def execute_llama_vision_32_requests(
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p : Optional[float],
+    top_p: Optional[float],
     max_concurrent_requests: Optional[int],
 ) -> List[str]:
     llama_model_version = MODEL_NAME_MAPPING.get(llama_model_version)
@@ -349,8 +348,7 @@ def execute_llama_vision_32_requests(
         raise ValueError(
             f"Invalid model name: '{llama_model_version}'. Please use one of {list(MODEL_NAME_MAPPING.keys())}."
         )
-    client = OpenAI(base_url="https://openrouter.ai/api/v1",
-                    api_key=llama_api_key)
+    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=llama_api_key)
     tasks = [
         partial(
             execute_llama_vision_32_request,
@@ -359,7 +357,7 @@ def execute_llama_vision_32_requests(
             llama_model_version=llama_model_version,
             max_tokens=max_tokens,
             temperature=temperature,
-            top_p = top_p,
+            top_p=top_p,
         )
         for prompt in llama_prompts
     ]
@@ -372,13 +370,14 @@ def execute_llama_vision_32_requests(
         max_workers=max_workers,
     )
 
+
 def execute_llama_vision_32_request(
     client: OpenAI,
     prompt: List[dict],
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p : Optional[float],
+    top_p: Optional[float],
 ) -> str:
     if temperature is None:
         temperature = 1
@@ -387,7 +386,7 @@ def execute_llama_vision_32_request(
         messages=prompt,
         max_tokens=max_tokens,
         temperature=temperature,
-        top_p = top_p,
+        top_p=top_p,
     )
     return response.choices[0].message.content
 
@@ -413,6 +412,7 @@ def prepare_unconstrained_prompt(
         }
     ]
 
+
 def prepare_classification_prompt(
     base64_image: str, classes: List[str], gpt_image_detail: str, **kwargs
 ) -> List[dict]:
@@ -501,6 +501,7 @@ def prepare_vqa_prompt(
         },
     ]
 
+
 def prepare_ocr_prompt(
     base64_image: str, gpt_image_detail: str, **kwargs
 ) -> List[dict]:
@@ -592,6 +593,3 @@ def prepare_structured_answering_prompt(
     "multi-label-classification": prepare_multi_label_classification_prompt,
     "structured-answering": prepare_structured_answering_prompt,
 }
-
-
-
diff --git a/tests/inference/hosted_platform_tests/workflows_examples/roboflow_models/v2/test_workflow_for_classification.py b/tests/inference/hosted_platform_tests/workflows_examples/roboflow_models/v2/test_workflow_for_classification.py
index 86aefe010..8418033c6 100644
--- a/tests/inference/hosted_platform_tests/workflows_examples/roboflow_models/v2/test_workflow_for_classification.py
+++ b/tests/inference/hosted_platform_tests/workflows_examples/roboflow_models/v2/test_workflow_for_classification.py
@@ -124,8 +124,8 @@ def test_multi_class_classification_workflow(
         {"cat", "dog"},
     ],
     PlatformEnvironment.ROBOFLOW_PLATFORM: [
-        {"dog"},
-        set(),
+        {"cat", "dog"},
+        {"cat", "dog"},
     ],
 }
 
diff --git a/tests/workflows/unit_tests/core_steps/sampling/test_identify_outliers.py b/tests/workflows/unit_tests/core_steps/sampling/test_identify_outliers.py
index 2e5e77f97..b7b1e45ee 100644
--- a/tests/workflows/unit_tests/core_steps/sampling/test_identify_outliers.py
+++ b/tests/workflows/unit_tests/core_steps/sampling/test_identify_outliers.py
@@ -15,7 +15,9 @@ def get_perturbed_value(initial_value: np.ndarray, perturbation: float) -> np.nd
     )
 
 
-@pytest.mark.skip(reason="Solve flakiness of the block: https://github.com/roboflow/inference/issues/901")
+@pytest.mark.skip(
+    reason="Solve flakiness of the block: https://github.com/roboflow/inference/issues/901"
+)
 def test_identify_outliers() -> None:
     # given
     identify_changes_block = IdentifyOutliersBlockV1()