Make linter happy and fix assertions mentioned in #904
PawelPeczek-Roboflow committed Jan 7, 2025
1 parent a2e19ab commit c69334f
Showing 3 changed files with 20 additions and 20 deletions.
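Most of the 20 changed lines are whitespace normalization around parameter annotations and keyword arguments. As a minimal sketch of the two rules being applied, assuming a pycodestyle-style linter (these findings are commonly reported as E203 and E251; the function names below are illustrative, not from the repository):

from typing import Optional

# Flagged by the linter (extra whitespace around ":" and "="):
#     def run(top_p : Optional[float]) -> None: ...
#     create(top_p = top_p)

# Compliant, as rewritten throughout this commit:
def run(top_p: Optional[float] = None) -> None:
    print(top_p)

run(top_p=0.9)  # keyword arguments take "=" with no surrounding spaces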
@@ -84,6 +84,7 @@
     "structured-answering",
 }
 
+
 class BlockManifest(WorkflowBlockManifest):
     model_config = ConfigDict(
         json_schema_extra={
@@ -93,7 +94,7 @@ class BlockManifest(WorkflowBlockManifest):
             "long_description": LONG_DESCRIPTION,
             "license": "Apache-2.0",
             "block_type": "model",
-            "search_keywords": ["LMM", "VLM", "Llama","Vision","Meta"],
+            "search_keywords": ["LMM", "VLM", "Llama", "Vision", "Meta"],
             "is_vlm_block": True,
             "task_type_property": "task_type",
         },
@@ -172,9 +173,7 @@ class BlockManifest(WorkflowBlockManifest):
         examples=["xxx-xxx", "$inputs.llama_api_key"],
         private=True,
     )
-    model_version: Union[
-        Selector(kind=[STRING_KIND]), Literal["vision-11B"]
-    ] = Field(
+    model_version: Union[Selector(kind=[STRING_KIND]), Literal["vision-11B"]] = Field(
         default="Llama-Vision-11B",
         description="Model to be used",
         examples=["Llama-Vision-11B", "$inputs.llama_model"],
@@ -274,7 +273,7 @@ def run(
         model_version: str,
         max_tokens: int,
         temperature: float,
-        top_p : Optional[float],
+        top_p: Optional[float],
         max_concurrent_requests: Optional[int],
     ) -> BlockResult:
         inference_images = [i.to_inference_format() for i in images]
@@ -288,7 +287,7 @@
             llama_model_version=model_version,
             max_tokens=max_tokens,
             temperature=temperature,
-            top_p = top_p,
+            top_p=top_p,
             max_concurrent_requests=max_concurrent_requests,
         )
         return [
@@ -306,7 +305,7 @@ def run_llama_vision_32_llm_prompting(
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p : Optional[float],
+    top_p: Optional[float],
     max_concurrent_requests: Optional[int],
 ) -> List[str]:
     if task_type not in PROMPT_BUILDERS:
@@ -330,7 +329,7 @@ def run_llama_vision_32_llm_prompting(
         llama_model_version=llama_model_version,
         max_tokens=max_tokens,
         temperature=temperature,
-        top_p = top_p,
+        top_p=top_p,
         max_concurrent_requests=max_concurrent_requests,
     )
 
@@ -341,16 +340,15 @@ def execute_llama_vision_32_requests(
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p : Optional[float],
+    top_p: Optional[float],
     max_concurrent_requests: Optional[int],
 ) -> List[str]:
     llama_model_version = MODEL_NAME_MAPPING.get(llama_model_version)
     if not llama_model_version:
         raise ValueError(
             f"Invalid model name: '{llama_model_version}'. Please use one of {list(MODEL_NAME_MAPPING.keys())}."
         )
-    client = OpenAI(base_url="https://openrouter.ai/api/v1",
-                    api_key=llama_api_key)
+    client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=llama_api_key)
     tasks = [
         partial(
             execute_llama_vision_32_request,
@@ -359,7 +357,7 @@
             llama_model_version=llama_model_version,
             max_tokens=max_tokens,
             temperature=temperature,
-            top_p = top_p,
+            top_p=top_p,
         )
         for prompt in llama_prompts
     ]
@@ -372,13 +370,14 @@
         max_workers=max_workers,
     )
 
+
 def execute_llama_vision_32_request(
     client: OpenAI,
     prompt: List[dict],
     llama_model_version: str,
     max_tokens: int,
     temperature: float,
-    top_p : Optional[float],
+    top_p: Optional[float],
 ) -> str:
     if temperature is None:
         temperature = 1
@@ -387,7 +386,7 @@ def execute_llama_vision_32_request(
         messages=prompt,
         max_tokens=max_tokens,
         temperature=temperature,
-        top_p = top_p,
+        top_p=top_p,
     )
     return response.choices[0].message.content

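The hunks above sit inside execute_llama_vision_32_requests, which points the stock OpenAI client at OpenRouter's OpenAI-compatible endpoint, builds one partially-applied request per prompt, and fans the requests out across a worker pool. A self-contained sketch of that pattern, assuming a thread-pool stand-in for the repository's run_in_parallel helper and using a placeholder API key and model slug:

from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Callable, List, Optional

from openai import OpenAI


def run_in_parallel(tasks: List[Callable[[], str]], max_workers: int) -> List[str]:
    # Stand-in for the helper the diff calls with max_workers=max_workers:
    # each task is a zero-argument callable produced by functools.partial.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(lambda task: task(), tasks))


def execute_request(
    client: OpenAI,
    prompt: List[dict],
    model: str,
    max_tokens: int,
    temperature: float,
    top_p: Optional[float],
) -> str:
    response = client.chat.completions.create(
        model=model,
        messages=prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,  # linter-clean keyword spacing, as in the diff
    )
    return response.choices[0].message.content


# OpenRouter speaks the OpenAI wire protocol, so overriding base_url suffices.
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key="xxx-xxx")
prompts = [[{"role": "user", "content": "Describe the image."}]]
tasks = [
    partial(
        execute_request,
        client=client,
        prompt=prompt,
        model="meta-llama/llama-3.2-11b-vision-instruct",  # placeholder slug
        max_tokens=512,
        temperature=1.0,
        top_p=0.9,
    )
    for prompt in prompts
]
answers = run_in_parallel(tasks=tasks, max_workers=4)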
@@ -413,6 +412,7 @@ def prepare_unconstrained_prompt(
         }
     ]
 
+
 def prepare_classification_prompt(
     base64_image: str, classes: List[str], gpt_image_detail: str, **kwargs
 ) -> List[dict]:
@@ -501,6 +501,7 @@ def prepare_vqa_prompt(
         },
     ]
 
+
 def prepare_ocr_prompt(
     base64_image: str, gpt_image_detail: str, **kwargs
 ) -> List[dict]:
@@ -592,6 +593,3 @@ def prepare_structured_answering_prompt(
     "multi-label-classification": prepare_multi_label_classification_prompt,
     "structured-answering": prepare_structured_answering_prompt,
 }
-
-
-
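The tail of this file is the PROMPT_BUILDERS table that the earlier `if task_type not in PROMPT_BUILDERS:` guard consults: task-type strings map to prompt-builder functions. A hedged sketch of that dispatch pattern (the key name and builder body here are illustrative, not copied from the repository):

from typing import Callable, Dict, List


def prepare_unconstrained_prompt(base64_image: str, **kwargs) -> List[dict]:
    # Illustrative builder; the real ones also cover classification, OCR, VQA, etc.
    image_url = {"url": f"data:image/jpeg;base64,{base64_image}"}
    return [{"role": "user", "content": [{"type": "image_url", "image_url": image_url}]}]


PROMPT_BUILDERS: Dict[str, Callable[..., List[dict]]] = {
    "unconstrained": prepare_unconstrained_prompt,  # assumed key name
}


def build_prompt(task_type: str, base64_image: str, **kwargs) -> List[dict]:
    if task_type not in PROMPT_BUILDERS:  # mirrors the guard shown above
        raise ValueError(f"Unsupported task type: {task_type}")
    return PROMPT_BUILDERS[task_type](base64_image=base64_image, **kwargs)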
@@ -124,8 +124,8 @@ def test_multi_class_classification_workflow(
         {"cat", "dog"},
     ],
     PlatformEnvironment.ROBOFLOW_PLATFORM: [
-        {"dog"},
-        set(),
+        {"cat", "dog"},
+        {"cat", "dog"}
     ],
 }

@@ -15,7 +15,9 @@ def get_perturbed_value(initial_value: np.ndarray, perturbation: float) -> np.ndarray:
     )
 
 
-@pytest.mark.skip(reason="Solve flakiness of the block: https://github.com/roboflow/inference/issues/901")
+@pytest.mark.skip(
+    reason="Solve flakiness of the block: https://github.com/roboflow/inference/issues/901"
+)
 def test_identify_outliers() -> None:
     # given
     identify_changes_block = IdentifyOutliersBlockV1()
