diff --git a/inference/core/entities/requests/workflows.py b/inference/core/entities/requests/workflows.py
index cb9db3eba3..0a94d612e6 100644
--- a/inference/core/entities/requests/workflows.py
+++ b/inference/core/entities/requests/workflows.py
@@ -22,6 +22,10 @@ class WorkflowInferenceRequest(BaseModel):
 
 class WorkflowSpecificationInferenceRequest(WorkflowInferenceRequest):
     specification: dict
+    is_preview: bool = Field(
+        default=False,
+        description="Reserved, used internally by Roboflow to distinguish between preview and non-preview runs",
+    )
 
 
 class DescribeBlocksRequest(BaseModel):
diff --git a/inference/core/interfaces/stream/inference_pipeline.py b/inference/core/interfaces/stream/inference_pipeline.py
index db160aa87b..7e5dae4ca5 100644
--- a/inference/core/interfaces/stream/inference_pipeline.py
+++ b/inference/core/interfaces/stream/inference_pipeline.py
@@ -54,7 +54,6 @@
 from inference.core.workflows.core_steps.common.entities import StepExecutionMode
 from inference.models.aliases import resolve_roboflow_model_alias
 from inference.models.utils import ROBOFLOW_MODEL_TYPES, get_model
-from inference.usage_tracking.collector import usage_collector
 
 INFERENCE_PIPELINE_CONTEXT = "inference_pipeline"
 SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT = "SOURCE_CONNECTION_ATTEMPT_FAILED"
diff --git a/inference/core/interfaces/stream/model_handlers/roboflow_models.py b/inference/core/interfaces/stream/model_handlers/roboflow_models.py
index 85a00654ab..3145b135ba 100644
--- a/inference/core/interfaces/stream/model_handlers/roboflow_models.py
+++ b/inference/core/interfaces/stream/model_handlers/roboflow_models.py
@@ -15,6 +15,8 @@ def default_process_frame(
     predictions = wrap_in_list(
         model.infer(
             [f.image for f in video_frame],
+            usage_fps=video_frame[0].fps,
+            usage_api_key=model.api_key,
             **postprocessing_args,
         )
     )
diff --git a/inference/core/models/base.py b/inference/core/models/base.py
index bef2d83369..5dbdec3da6 100644
--- a/inference/core/models/base.py
+++ b/inference/core/models/base.py
@@ -7,6 +7,7 @@
 from inference.core.entities.requests.inference import InferenceRequest
 from inference.core.entities.responses.inference import InferenceResponse
 from inference.core.models.types import PreprocessReturnMetadata
+from inference.usage_tracking.collector import usage_collector
 
 
 class BaseInference:
@@ -15,6 +16,7 @@ class BaseInference:
     This class provides a basic interface for inference tasks.
     """
 
+    @usage_collector
     def infer(self, image: Any, **kwargs) -> Any:
         """Runs inference on given data.
         - image:
diff --git a/inference/core/models/roboflow.py b/inference/core/models/roboflow.py
index 5a28b3821d..a6e9dfc8de 100644
--- a/inference/core/models/roboflow.py
+++ b/inference/core/models/roboflow.py
@@ -667,7 +667,7 @@ def validate_model(self) -> None:
     def run_test_inference(self) -> None:
         test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8)
         logger.debug(f"Running test inference. Image size: {test_image.shape}")
-        result = self.infer(test_image)
+        result = self.infer(test_image, usage_inference_test_run=True)
         logger.debug(f"Test inference finished.")
         return result
 
diff --git a/inference/core/version.py b/inference/core/version.py
index 8b9c9d1e0a..11b4081a90 100644
--- a/inference/core/version.py
+++ b/inference/core/version.py
@@ -1,4 +1,4 @@
-__version__ = "0.16.2"
+__version__ = "0.16.3"
 
 
 if __name__ == "__main__":
diff --git a/inference/core/workflows/execution_engine/core.py b/inference/core/workflows/execution_engine/core.py
index e9d9ad4010..7b7f78fbde 100644
--- a/inference/core/workflows/execution_engine/core.py
+++ b/inference/core/workflows/execution_engine/core.py
@@ -61,10 +61,12 @@ def run(
         self,
         runtime_parameters: Dict[str, Any],
         fps: float = 0,
+        _is_preview: bool = False,
     ) -> List[Dict[str, Any]]:
         return self._engine.run(
             runtime_parameters=runtime_parameters,
             fps=fps,
+            _is_preview=_is_preview,
         )
 
 
diff --git a/inference/core/workflows/execution_engine/entities/engine.py b/inference/core/workflows/execution_engine/entities/engine.py
index 06a31880d2..021fba5b21 100644
--- a/inference/core/workflows/execution_engine/entities/engine.py
+++ b/inference/core/workflows/execution_engine/entities/engine.py
@@ -21,5 +21,6 @@ def run(
         self,
         runtime_parameters: Dict[str, Any],
         fps: float = 0,
+        _is_preview: bool = False,
     ) -> List[Dict[str, Any]]:
         pass
diff --git a/inference/core/workflows/execution_engine/v1/core.py b/inference/core/workflows/execution_engine/v1/core.py
index 72ed61ec96..20426d8621 100644
--- a/inference/core/workflows/execution_engine/v1/core.py
+++ b/inference/core/workflows/execution_engine/v1/core.py
@@ -61,6 +61,7 @@ def run(
         self,
         runtime_parameters: Dict[str, Any],
         fps: float = 0,
+        _is_preview: bool = False,
     ) -> List[Dict[str, Any]]:
         runtime_parameters = assemble_runtime_parameters(
             runtime_parameters=runtime_parameters,
@@ -77,4 +78,5 @@
             max_concurrent_steps=self._max_concurrent_steps,
             usage_fps=fps,
             usage_workflow_id=self._workflow_id,
+            usage_workflow_preview=_is_preview,
         )
diff --git a/inference/usage_tracking/collector.py b/inference/usage_tracking/collector.py
index 9536ced7e2..6eac36af1f 100644
--- a/inference/usage_tracking/collector.py
+++ b/inference/usage_tracking/collector.py
@@ -12,7 +12,6 @@
 from threading import Event, Lock, Thread
 from uuid import uuid4
 
-import requests
 from typing_extensions import (
     Any,
     Callable,
@@ -20,9 +19,9 @@
     Dict,
     List,
     Optional,
-    Set,
+    ParamSpec,
     Tuple,
-    Union,
+    TypeVar,
 )
 
 from inference.core.env import API_KEY, LAMBDA, REDIS_HOST
@@ -34,17 +33,22 @@
 from inference.usage_tracking.utils import collect_func_params
 
 from .config import TelemetrySettings, get_telemetry_settings
+from .payload_helpers import (
+    APIKey,
+    APIKeyHash,
+    APIKeyUsage,
+    ResourceDetails,
+    ResourceID,
+    SystemDetails,
+    UsagePayload,
+    send_usage_payload,
+    zip_usage_payloads,
+)
+from .redis_queue import RedisQueue
 from .sqlite_queue import SQLiteQueue
 
-ResourceID = str
-Usage = Union[DefaultDict[str, Any], Dict[str, Any]]
-ResourceUsage = Union[DefaultDict[ResourceID, Usage], Dict[ResourceID, Usage]]
-APIKey = str
-APIKeyHash = str
-APIKeyUsage = Union[DefaultDict[APIKey, ResourceUsage], Dict[APIKey, ResourceUsage]]
-ResourceDetails = Dict[str, Any]
-SystemDetails = Dict[str, Any]
-UsagePayload = Union[APIKeyUsage, ResourceDetails, SystemDetails]
+T = TypeVar("T")
+P = ParamSpec("P")
 
 
 class UsageCollector:
@@ -76,22 +80,32 @@ def __init__(self):
             exec_session_id=self._exec_session_id
         )
 
-        if LAMBDA or
self._settings.opt_out: + self._hashed_api_keys: Dict[APIKey, APIKeyHash] = {} + self._api_keys_hashing_enabled = True + + if LAMBDA and REDIS_HOST: + logger.debug("Persistence through RedisQueue") + self._queue: "Queue[UsagePayload]" = RedisQueue() + self._api_keys_hashing_enabled = False + elif LAMBDA or self._settings.opt_out: + logger.debug("No persistence") self._queue: "Queue[UsagePayload]" = Queue( maxsize=self._settings.queue_size ) + self._api_keys_hashing_enabled = False else: try: self._queue = SQLiteQueue() + logger.debug("Persistence through SQLiteQueue") except Exception as exc: logger.debug("Unable to create instance of SQLiteQueue, %s", exc) + logger.debug("No persistence") self._queue: "Queue[UsagePayload]" = Queue( maxsize=self._settings.queue_size ) + self._api_keys_hashing_enabled = False self._queue_lock = Lock() - self._hashed_api_keys: Dict[APIKey, APIKeyHash] = {} - self._system_info_sent: bool = False self._resource_details_lock = Lock() self._resource_details: DefaultDict[APIKey, Dict[ResourceID, bool]] = ( @@ -128,23 +142,6 @@ def empty_usage_dict(exec_session_id: str) -> APIKeyUsage: ) ) - @staticmethod - def _merge_usage_dicts(d1: UsagePayload, d2: UsagePayload): - merged = {} - if d1 and d2 and d1.get("resource_id") != d2.get("resource_id"): - raise ValueError("Cannot merge usage for different resource IDs") - if "timestamp_start" in d1 and "timestamp_start" in d2: - merged["timestamp_start"] = min( - d1["timestamp_start"], d2["timestamp_start"] - ) - if "timestamp_stop" in d1 and "timestamp_stop" in d2: - merged["timestamp_stop"] = max(d1["timestamp_stop"], d2["timestamp_stop"]) - if "processed_frames" in d1 and "processed_frames" in d2: - merged["processed_frames"] = d1["processed_frames"] + d2["processed_frames"] - if "source_duration" in d1 and "source_duration" in d2: - merged["source_duration"] = d1["source_duration"] + d2["source_duration"] - return {**d1, **d2, **merged} - def _dump_usage_queue_no_lock(self) -> List[APIKeyUsage]: usage_payloads: List[APIKeyUsage] = [] while self._queue: @@ -163,100 +160,6 @@ def _dump_usage_queue_with_lock(self) -> List[APIKeyUsage]: usage_payloads = self._dump_usage_queue_no_lock() return usage_payloads - @staticmethod - def _get_api_key_usage_containing_resource( - api_key_hash: APIKey, usage_payloads: List[APIKeyUsage] - ) -> Optional[ResourceUsage]: - for usage_payload in usage_payloads: - for other_api_key_hash, resource_payloads in usage_payload.items(): - if api_key_hash and other_api_key_hash != api_key_hash: - continue - if other_api_key_hash == "": - continue - for resource_id, resource_usage in resource_payloads.items(): - if not resource_id: - continue - if not resource_usage or "resource_id" not in resource_usage: - continue - return resource_usage - return - - @staticmethod - def _zip_usage_payloads(usage_payloads: List[APIKeyUsage]) -> List[APIKeyUsage]: - merged_api_key_usage_payloads: APIKeyUsage = {} - system_info_payload = None - for usage_payload in usage_payloads: - for api_key_hash, resource_payloads in usage_payload.items(): - if api_key_hash == "": - if ( - resource_payloads - and len(resource_payloads) > 1 - or list(resource_payloads.keys()) != [""] - ): - logger.debug( - "Dropping usage payload %s due to missing API key", - resource_payloads, - ) - continue - api_key_usage_with_resource = ( - UsageCollector._get_api_key_usage_containing_resource( - api_key_hash=api_key_hash, - usage_payloads=usage_payloads, - ) - ) - if not api_key_usage_with_resource: - system_info_payload = resource_payloads 
- continue - api_key_hash = api_key_usage_with_resource["api_key_hash"] - resource_id = api_key_usage_with_resource["resource_id"] - category = api_key_usage_with_resource.get("category") - for v in resource_payloads.values(): - v["api_key_hash"] = api_key_hash - if "resource_id" not in v or not v["resource_id"]: - v["resource_id"] = resource_id - if "category" not in v or not v["category"]: - v["category"] = category - for ( - resource_usage_key, - resource_usage_payload, - ) in resource_payloads.items(): - if resource_usage_key == "": - api_key_usage_with_resource = ( - UsageCollector._get_api_key_usage_containing_resource( - api_key_hash=api_key_hash, - usage_payloads=usage_payloads, - ) - ) - if not api_key_usage_with_resource: - system_info_payload = {"": resource_usage_payload} - continue - resource_id = api_key_usage_with_resource["resource_id"] - category = api_key_usage_with_resource.get("category") - resource_usage_key = f"{category}:{resource_id}" - resource_usage_payload["api_key_hash"] = api_key_hash - resource_usage_payload["resource_id"] = resource_id - resource_usage_payload["category"] = category - merged_api_key_payload = merged_api_key_usage_payloads.setdefault( - api_key_hash, {} - ) - merged_resource_payload = merged_api_key_payload.setdefault( - resource_usage_key, {} - ) - merged_api_key_payload[resource_usage_key] = ( - UsageCollector._merge_usage_dicts( - merged_resource_payload, - resource_usage_payload, - ) - ) - - zipped_payloads = [merged_api_key_usage_payloads] - if system_info_payload: - system_info_api_key_hash = next(iter(system_info_payload.values()))[ - "api_key_hash" - ] - zipped_payloads.append({system_info_api_key_hash: system_info_payload}) - return zipped_payloads - @staticmethod def _hash(payload: str, length=5): payload_hash = hashlib.sha256(payload.encode()) @@ -269,7 +172,10 @@ def _calculate_api_key_hash(self, api_key: APIKey) -> APIKeyHash: if api_key: api_key_hash = self._hashed_api_keys.get(api_key) if not api_key_hash: - api_key_hash = UsageCollector._hash(api_key) + if self._api_keys_hashing_enabled: + api_key_hash = UsageCollector._hash(api_key) + else: + api_key_hash = api_key self._hashed_api_keys[api_key] = api_key_hash return api_key_hash @@ -287,7 +193,7 @@ def _enqueue_payload(self, payload: UsagePayload): else: usage_payloads = self._dump_usage_queue_no_lock() usage_payloads.append(payload) - merged_usage_payloads = self._zip_usage_payloads( + merged_usage_payloads = zip_usage_payloads( usage_payloads=usage_payloads, ) for usage_payload in merged_usage_payloads: @@ -429,6 +335,7 @@ def _update_usage_payload( api_key: APIKey = "", resource_details: Optional[Dict[str, Any]] = None, resource_id: str = "", + inference_test_run: bool = False, fps: float = 0, enterprise: bool = False, ): @@ -441,9 +348,11 @@ def _update_usage_payload( if not source_usage["timestamp_start"]: source_usage["timestamp_start"] = time.time_ns() source_usage["timestamp_stop"] = time.time_ns() - source_usage["processed_frames"] += frames + source_usage["processed_frames"] += frames if not inference_test_run else 0 source_usage["fps"] = round(fps, 2) - source_usage["source_duration"] += frames / fps if fps else 0 + source_usage["source_duration"] += ( + frames / fps if fps and not inference_test_run else 0 + ) source_usage["category"] = category source_usage["resource_id"] = resource_id source_usage["api_key_hash"] = api_key_hash @@ -459,6 +368,7 @@ def record_usage( api_key: APIKey = "", resource_details: Optional[Dict[str, Any]] = None, resource_id: str = "", + 
inference_test_run: bool = False, fps: float = 0, ) -> DefaultDict[str, Any]: if self._settings.opt_out and not enterprise: @@ -481,6 +391,7 @@ def record_usage( api_key=api_key, resource_details=resource_details, resource_id=resource_id, + inference_test_run=inference_test_run, fps=fps, enterprise=enterprise, ) @@ -494,6 +405,7 @@ async def async_record_usage( api_key: APIKey = "", resource_details: Optional[Dict[str, Any]] = None, resource_id: str = "", + inference_test_run: bool = False, fps: float = 0, ) -> DefaultDict[str, Any]: if self._async_lock: @@ -506,6 +418,7 @@ async def async_record_usage( api_key=api_key, resource_details=resource_details, resource_id=resource_id, + inference_test_run=inference_test_run, fps=fps, ) else: @@ -517,6 +430,7 @@ async def async_record_usage( api_key=api_key, resource_details=resource_details, resource_id=resource_id, + inference_test_run=inference_test_run, fps=fps, ) @@ -561,45 +475,18 @@ def _offload_to_api(self, payloads: List[APIKeyUsage]): hashes_to_api_keys = dict(a[::-1] for a in self._hashed_api_keys.items()) - api_keys_hashes_failed = set() for payload in payloads: - for api_key_hash, workflow_payloads in payload.items(): - if api_key_hash not in hashes_to_api_keys: - api_keys_hashes_failed.add(api_key_hash) - continue - api_key = hashes_to_api_keys[api_key_hash] - if any("processed_frames" not in w for w in workflow_payloads.values()): - api_keys_hashes_failed.add(api_key_hash) - continue - try: - for workflow_payload in workflow_payloads.values(): - if api_key_hash in workflow_payload: - del workflow_payload["api_key_hash"] - workflow_payload["api_key"] = api_key - logger.debug( - "Offloading usage to %s, payload: %s", - self._settings.api_usage_endpoint_url, - workflow_payloads, - ) - response = requests.post( - self._settings.api_usage_endpoint_url, - json=list(workflow_payloads.values()), - verify=ssl_verify, - headers={"Authorization": f"Bearer {api_key}"}, - timeout=1, - ) - except Exception as exc: - logger.debug("Failed to send usage - %s", exc) - api_keys_hashes_failed.add(api_key_hash) - continue - if response.status_code != 200: - logger.debug( - "Failed to send usage - got %s status code (%s)", - response.status_code, - response.raw, - ) - api_keys_hashes_failed.add(api_key_hash) - continue + api_keys_hashes_failed = send_usage_payload( + payloads=payloads, + api_usage_endpoint_url=self._settings.api_usage_endpoint_url, + hashes_to_api_keys=hashes_to_api_keys, + ssl_verify=ssl_verify, + ) + if api_keys_hashes_failed: + logger.debug( + "Failed to send usage following usage payloads: %s", + api_keys_hashes_failed, + ) for api_key_hash in list(payload.keys()): if api_key_hash not in api_keys_hashes_failed: del payload[api_key_hash] @@ -621,7 +508,7 @@ async def async_push_usage_payloads(self): @staticmethod def _resource_details_from_workflow_json( workflow_json: Dict[str, Any] - ) -> Tuple[ResourceID, ResourceDetails]: + ) -> ResourceDetails: if not isinstance(workflow_json, dict): raise ValueError("workflow_json must be dict") return { @@ -637,6 +524,8 @@ def _extract_usage_params_from_func_kwargs( usage_fps: float, usage_api_key: str, usage_workflow_id: str, + usage_workflow_preview: bool, + usage_inference_test_run: bool, func: Callable[[Any], Any], args: List[Any], kwargs: Dict[str, Any], @@ -645,12 +534,13 @@ def _extract_usage_params_from_func_kwargs( resource_details = {} resource_id = "" category = None + enterprise = False + # TODO: add requires_api_key, True if workflow definition comes from platform or model comes 
from workspace if "workflow" in func_kwargs: workflow: CompiledWorkflow = func_kwargs["workflow"] if hasattr(workflow, "workflow_definition"): - # TODO: handle enterprise blocks here + # TODO: extend ParsedWorkflowDefinition to expose `enterprise` workflow_definition = workflow.workflow_definition - enterprise = False if hasattr(workflow, "init_parameters"): init_parameters = workflow.init_parameters if "workflows_core.api_key" in init_parameters: @@ -661,15 +551,34 @@ def _extract_usage_params_from_func_kwargs( resource_details = UsageCollector._resource_details_from_workflow_json( workflow_json=workflow_json, ) + resource_details["is_preview"] = usage_workflow_preview resource_id = usage_workflow_id if not resource_id and resource_details: usage_workflow_id = UsageCollector._calculate_resource_hash( resource_details=resource_details ) category = "workflows" - elif "model_id" in func_kwargs: - # TODO: handle model - pass + elif "self" in func_kwargs: + _self = func_kwargs["self"] + if hasattr(_self, "dataset_id") and hasattr(_self, "version_id"): + model_id = f"{_self.dataset_id}/{_self.version_id}" + category = "model" + resource_id = model_id + elif isinstance(kwargs, dict) and "model_id" in kwargs: + model_id = kwargs["model_id"] + category = "model" + resource_id = model_id + else: + resource_id = "unknown" + category = "unknown" + if isinstance(kwargs, dict) and "source" in kwargs: + resource_details["source"] = kwargs["source"] + if hasattr(_self, "task_type"): + resource_details["task_type"] = _self.task_type + else: + resource_id = "unknown" + category = "unknown" + source = None runtime_parameters = func_kwargs.get("runtime_parameters") if ( @@ -690,24 +599,29 @@ def _extract_usage_params_from_func_kwargs( "category": category, "resource_details": resource_details, "resource_id": resource_id, + "inference_test_run": usage_inference_test_run, "fps": usage_fps, "enterprise": enterprise, } - def __call__(self, func: Callable[[Any], Any]): + def __call__(self, func: Callable[P, T]) -> Callable[P, T]: @wraps(func) def sync_wrapper( - *args, + *args: P.args, usage_fps: float = 0, usage_api_key: APIKey = "", usage_workflow_id: str = "", - **kwargs, - ): + usage_workflow_preview: bool = False, + usage_inference_test_run: bool = False, + **kwargs: P.kwargs, + ) -> T: self.record_usage( **self._extract_usage_params_from_func_kwargs( usage_fps=usage_fps, usage_api_key=usage_api_key, usage_workflow_id=usage_workflow_id, + usage_workflow_preview=usage_workflow_preview, + usage_inference_test_run=usage_inference_test_run, func=func, args=args, kwargs=kwargs, @@ -717,17 +631,21 @@ def sync_wrapper( @wraps(func) async def async_wrapper( - *args, + *args: P.args, usage_fps: float = 0, usage_api_key: APIKey = "", usage_workflow_id: str = "", - **kwargs, - ): + usage_workflow_preview: bool = False, + usage_inference_test_run: bool = False, + **kwargs: P.kwargs, + ) -> T: await self.async_record_usage( **self._extract_usage_params_from_func_kwargs( usage_fps=usage_fps, usage_api_key=usage_api_key, usage_workflow_id=usage_workflow_id, + usage_workflow_preview=usage_workflow_preview, + usage_inference_test_run=usage_inference_test_run, func=func, args=args, kwargs=kwargs, diff --git a/inference/usage_tracking/config.py b/inference/usage_tracking/config.py index 108209088b..80f97df280 100644 --- a/inference/usage_tracking/config.py +++ b/inference/usage_tracking/config.py @@ -13,7 +13,7 @@ class TelemetrySettings(BaseSettings): model_config = SettingsConfigDict(env_prefix="telemetry_") - 
api_usage_endpoint_url: str = "https://api.roboflow.one/usage/inference" + api_usage_endpoint_url: str = "https://api.roboflow.com/usage/inference" flush_interval: int = Field(default=10, ge=10, le=300) opt_out: Optional[bool] = False queue_size: int = Field(default=10, ge=10, le=10000) diff --git a/inference/usage_tracking/payload_helpers.py b/inference/usage_tracking/payload_helpers.py new file mode 100644 index 0000000000..996424ac32 --- /dev/null +++ b/inference/usage_tracking/payload_helpers.py @@ -0,0 +1,152 @@ +from typing import Any, DefaultDict, Dict, List, Optional, Set, Union + +import requests + +ResourceID = str +Usage = Union[DefaultDict[str, Any], Dict[str, Any]] +ResourceUsage = Union[DefaultDict[ResourceID, Usage], Dict[ResourceID, Usage]] +APIKey = str +APIKeyHash = str +APIKeyUsage = Union[DefaultDict[APIKey, ResourceUsage], Dict[APIKey, ResourceUsage]] +ResourceDetails = Dict[str, Any] +SystemDetails = Dict[str, Any] +UsagePayload = Union[APIKeyUsage, ResourceDetails, SystemDetails] + + +def merge_usage_dicts(d1: UsagePayload, d2: UsagePayload): + merged = {} + if d1 and d2 and d1.get("resource_id") != d2.get("resource_id"): + raise ValueError("Cannot merge usage for different resource IDs") + if "timestamp_start" in d1 and "timestamp_start" in d2: + merged["timestamp_start"] = min(d1["timestamp_start"], d2["timestamp_start"]) + if "timestamp_stop" in d1 and "timestamp_stop" in d2: + merged["timestamp_stop"] = max(d1["timestamp_stop"], d2["timestamp_stop"]) + if "processed_frames" in d1 and "processed_frames" in d2: + merged["processed_frames"] = d1["processed_frames"] + d2["processed_frames"] + if "source_duration" in d1 and "source_duration" in d2: + merged["source_duration"] = d1["source_duration"] + d2["source_duration"] + return {**d1, **d2, **merged} + + +def get_api_key_usage_containing_resource( + api_key_hash: APIKey, usage_payloads: List[APIKeyUsage] +) -> Optional[ResourceUsage]: + for usage_payload in usage_payloads: + for other_api_key_hash, resource_payloads in usage_payload.items(): + if api_key_hash and other_api_key_hash != api_key_hash: + continue + if other_api_key_hash == "": + continue + for resource_id, resource_usage in resource_payloads.items(): + if not resource_id: + continue + if not resource_usage or "resource_id" not in resource_usage: + continue + return resource_usage + return + + +def zip_usage_payloads(usage_payloads: List[APIKeyUsage]) -> List[APIKeyUsage]: + merged_api_key_usage_payloads: APIKeyUsage = {} + system_info_payload = None + for usage_payload in usage_payloads: + for api_key_hash, resource_payloads in usage_payload.items(): + if api_key_hash == "": + if ( + resource_payloads + and len(resource_payloads) > 1 + or list(resource_payloads.keys()) != [""] + ): + continue + api_key_usage_with_resource = get_api_key_usage_containing_resource( + api_key_hash=api_key_hash, + usage_payloads=usage_payloads, + ) + if not api_key_usage_with_resource: + system_info_payload = resource_payloads + continue + api_key_hash = api_key_usage_with_resource["api_key_hash"] + resource_id = api_key_usage_with_resource["resource_id"] + category = api_key_usage_with_resource.get("category") + for v in resource_payloads.values(): + v["api_key_hash"] = api_key_hash + if "resource_id" not in v or not v["resource_id"]: + v["resource_id"] = resource_id + if "category" not in v or not v["category"]: + v["category"] = category + for ( + resource_usage_key, + resource_usage_payload, + ) in resource_payloads.items(): + if resource_usage_key == "": + 
api_key_usage_with_resource = get_api_key_usage_containing_resource( + api_key_hash=api_key_hash, + usage_payloads=usage_payloads, + ) + if not api_key_usage_with_resource: + system_info_payload = {"": resource_usage_payload} + continue + resource_id = api_key_usage_with_resource["resource_id"] + category = api_key_usage_with_resource.get("category") + resource_usage_key = f"{category}:{resource_id}" + resource_usage_payload["api_key_hash"] = api_key_hash + resource_usage_payload["resource_id"] = resource_id + resource_usage_payload["category"] = category + merged_api_key_payload = merged_api_key_usage_payloads.setdefault( + api_key_hash, {} + ) + merged_resource_payload = merged_api_key_payload.setdefault( + resource_usage_key, {} + ) + merged_api_key_payload[resource_usage_key] = merge_usage_dicts( + merged_resource_payload, + resource_usage_payload, + ) + + zipped_payloads = [merged_api_key_usage_payloads] + if system_info_payload: + system_info_api_key_hash = next(iter(system_info_payload.values()))[ + "api_key_hash" + ] + zipped_payloads.append({system_info_api_key_hash: system_info_payload}) + return zipped_payloads + + +def send_usage_payload( + payload: UsagePayload, + api_usage_endpoint_url: str, + hashes_to_api_keys: Optional[Dict[APIKeyHash, APIKey]] = None, + ssl_verify: bool = False, +) -> Set[APIKeyHash]: + hashes_to_api_keys = hashes_to_api_keys or {} + api_keys_hashes_failed = set() + for api_key_hash, workflow_payloads in payload.items(): + if hashes_to_api_keys and api_key_hash not in hashes_to_api_keys: + api_keys_hashes_failed.add(api_key_hash) + continue + api_key = hashes_to_api_keys.get(api_key_hash) or api_key_hash + if not api_key: + api_keys_hashes_failed.add(api_key_hash) + continue + complete_workflow_payloads = [ + w for w in workflow_payloads.values() if "processed_frames" in w + ] + try: + for workflow_payload in complete_workflow_payloads: + if "api_key_hash" in workflow_payload: + del workflow_payload["api_key_hash"] + workflow_payload["api_key"] = api_key + response = requests.post( + api_usage_endpoint_url, + json=complete_workflow_payloads, + verify=ssl_verify, + headers={"Authorization": f"Bearer {api_key}"}, + timeout=1, + ) + except Exception: + api_keys_hashes_failed.add(api_key_hash) + continue + if response.status_code != 200: + api_keys_hashes_failed.add(api_key_hash) + continue + return api_keys_hashes_failed diff --git a/inference/usage_tracking/redis_queue.py b/inference/usage_tracking/redis_queue.py new file mode 100644 index 0000000000..206d58c7ba --- /dev/null +++ b/inference/usage_tracking/redis_queue.py @@ -0,0 +1,61 @@ +import json +import time +from threading import Lock +from uuid import uuid4 + +from typing_extensions import Any, Dict, List, Optional + +from inference.core.cache import cache +from inference.core.cache.redis import RedisCache +from inference.core.logger import logger + + +class RedisQueue: + """ + Store and forget, keys with specified hash tag are handled by external service + """ + + def __init__( + self, + hash_tag: str = "UsageCollector", + redis_cache: Optional[RedisCache] = None, + ): + # prefix must contain hash-tag to avoid CROSSLOT errors when using mget + # hash-tag is common part of the key wrapped within '{}' + # removing hash-tag will cause clients utilizing mget to fail + self._prefix: str = f"{{{hash_tag}}}:{uuid4().hex[:5]}:{time.time()}" + self._redis_cache: RedisCache = redis_cache or cache + self._increment: int = 0 + self._lock: Lock = Lock() + + def put(self, payload: Any): + if not 
isinstance(payload, str): + try: + payload = json.dumps(payload) + except Exception as exc: + logger.error("Failed to parse payload '%s' to JSON - %s", payload, exc) + return + with self._lock: + try: + self._increment += 1 + redis_key = f"{self._prefix}:{self._increment}" + self._redis_cache.client.set( + name=redis_key, + value=payload, + ) + self._redis_cache.client.zadd( + name="UsageCollector", + mapping={redis_key: time.time()}, + ) + except Exception as exc: + logger.error("Failed to store usage records '%s', %s", payload, exc) + + @staticmethod + def full() -> bool: + return False + + def empty(self) -> bool: + return True + + def get_nowait(self) -> List[Dict[str, Any]]: + return [] diff --git a/inference/usage_tracking/sqlite_queue.py b/inference/usage_tracking/sqlite_queue.py index 121bc69e71..f0f42657d6 100644 --- a/inference/usage_tracking/sqlite_queue.py +++ b/inference/usage_tracking/sqlite_queue.py @@ -19,6 +19,8 @@ def __init__( self._db_file_path: str = db_file_path if not connection: + if not os.path.exists(MODEL_CACHE_DIR): + os.makedirs(MODEL_CACHE_DIR) connection: sqlite3.Connection = sqlite3.connect(db_file_path, timeout=1) self._create_table(connection=connection) connection.close() @@ -70,7 +72,7 @@ def put(self, payload: Any, connection: Optional[sqlite3.Connection] = None): self._insert(payload=payload_str, connection=connection) connection.close() except Exception as exc: - logger.debug("Failed to store usage records, %s", exc) + logger.debug("Failed to store usage records '%s', %s", payload, exc) return [] else: self._insert(payload=payload_str, connection=connection) @@ -95,7 +97,7 @@ def _count_rows(self, connection: sqlite3.Connection) -> int: count = int(cursor.fetchone()[0]) connection.commit() except Exception as exc: - logger.debug("Failed to store usage payload, %s", exc) + logger.debug("Failed to obtain records count, %s", exc) connection.rollback() cursor.close() @@ -118,10 +120,12 @@ def empty(self, connection: Optional[sqlite3.Connection] = None) -> bool: return rows_count == 0 - def _flush_db(self, connection: sqlite3.Connection) -> List[Dict[str, Any]]: + def _flush_db( + self, connection: sqlite3.Connection, limit: int = 100 + ) -> List[Dict[str, Any]]: cursor = connection.cursor() - sql_select = f"SELECT {self._col_name} FROM {self._tbl_name}" - sql_delete = f"DELETE FROM {self._tbl_name}" + sql_select = f"SELECT id, {self._col_name} FROM {self._tbl_name} ORDER BY id ASC LIMIT {limit}" + sql_delete = f"DELETE FROM {self._tbl_name} WHERE id >= ? and id <= ?" 
try: cursor.execute("BEGIN EXCLUSIVE") @@ -133,22 +137,33 @@ def _flush_db(self, connection: sqlite3.Connection) -> List[Dict[str, Any]]: try: cursor.execute(sql_select) payloads = cursor.fetchall() - cursor.execute(sql_delete) - connection.commit() - cursor.close() except Exception as exc: - logger.debug("Failed to store usage payload, %s", exc) + logger.debug("Failed to obtain records, %s", exc) connection.rollback() return [] parsed_payloads = [] - for (payload,) in payloads: + top_id = -1 + bottom_id = -1 + for _id, payload in payloads: + top_id = max(top_id, _id) + if bottom_id == -1: + bottom_id = _id + bottom_id = min(bottom_id, _id) try: parsed_payload = json.loads(payload) parsed_payloads.append(parsed_payload) except Exception as exc: logger.debug("Failed to parse usage payload %s, %s", payload, exc) + try: + cursor.execute(sql_delete, [bottom_id, top_id]) + connection.commit() + cursor.close() + except Exception as exc: + logger.debug("Failed to obtain records, %s", exc) + connection.rollback() + return parsed_payloads def get_nowait( diff --git a/inference/usage_tracking/utils.py b/inference/usage_tracking/utils.py index edeacf37f1..6c78ea3b5c 100644 --- a/inference/usage_tracking/utils.py +++ b/inference/usage_tracking/utils.py @@ -20,7 +20,13 @@ def collect_func_params( for default_arg in defaults: params[default_arg] = signature.parameters[default_arg].default - if set(params) != set(signature.parameters): - logger.error("Params mismatch for %s.%s", func.__module__, func.__name__) + signature_params = set(signature.parameters) + if set(params) != signature_params: + if "kwargs" in signature_params: + params["kwargs"] = kwargs + if "args" in signature_params: + params["args"] = args + if not set(params).issuperset(signature_params): + logger.error("Params mismatch for %s.%s", func.__module__, func.__name__) return params diff --git a/setup.py b/setup.py index f2193dffd4..f14488f312 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,6 @@ def read_requirements(path): ), extras_require={ "sam": read_requirements("requirements/requirements.sam.txt"), - "sam2": read_requirements("requirements/requirements.sam2.txt"), }, classifiers=[ "Programming Language :: Python :: 3", diff --git a/tests/inference/unit_tests/core/interfaces/stream/test_interface_pipeline.py b/tests/inference/unit_tests/core/interfaces/stream/test_interface_pipeline.py index 983170b58d..cff719b09d 100644 --- a/tests/inference/unit_tests/core/interfaces/stream/test_interface_pipeline.py +++ b/tests/inference/unit_tests/core/interfaces/stream/test_interface_pipeline.py @@ -116,6 +116,8 @@ def __next__(self) -> VideoFrame: class ModelStub: + def __init__(self): + self.api_key = None def infer(self, image: Any, **kwargs) -> List[ObjectDetectionInferenceResponse]: return [ ObjectDetectionInferenceResponse( diff --git a/tests/inference/unit_tests/usage_tracking/test_collector.py b/tests/inference/unit_tests/usage_tracking/test_collector.py index c19c143d5d..f0a68a37ce 100644 --- a/tests/inference/unit_tests/usage_tracking/test_collector.py +++ b/tests/inference/unit_tests/usage_tracking/test_collector.py @@ -5,6 +5,7 @@ from inference.core.env import LAMBDA from inference.usage_tracking.collector import UsageCollector +from inference.usage_tracking.payload_helpers import get_api_key_usage_containing_resource, merge_usage_dicts, zip_usage_payloads def test_create_empty_usage_dict(): @@ -45,7 +46,7 @@ def test_merge_usage_dicts_raises_on_mismatched_resource_id(): usage_payload_2 = {"resource_id": "other"} with 
pytest.raises(ValueError): - UsageCollector._merge_usage_dicts(d1=usage_payload_1, d2=usage_payload_2) + merge_usage_dicts(d1=usage_payload_1, d2=usage_payload_2) def test_merge_usage_dicts_merge_with_empty(): @@ -61,11 +62,11 @@ def test_merge_usage_dicts_merge_with_empty(): usage_payload_2 = {"resource_id": "some", "api_key_hash": "some"} assert ( - UsageCollector._merge_usage_dicts(d1=usage_payload_1, d2=usage_payload_2) + merge_usage_dicts(d1=usage_payload_1, d2=usage_payload_2) == usage_payload_1 ) assert ( - UsageCollector._merge_usage_dicts(d1=usage_payload_2, d2=usage_payload_1) + merge_usage_dicts(d1=usage_payload_2, d2=usage_payload_1) == usage_payload_1 ) @@ -89,7 +90,7 @@ def test_merge_usage_dicts(): "source_duration": 1, } - assert UsageCollector._merge_usage_dicts( + assert merge_usage_dicts( d1=usage_payload_1, d2=usage_payload_2 ) == { "resource_id": "some", @@ -119,7 +120,7 @@ def test_get_api_key_usage_containing_resource_with_no_payload_containing_api_ke ] # when - api_key_usage_with_resource = UsageCollector._get_api_key_usage_containing_resource( + api_key_usage_with_resource = get_api_key_usage_containing_resource( api_key_hash="fake", usage_payloads=usage_payloads ) @@ -167,7 +168,7 @@ def test_get_api_key_usage_containing_resource_with_no_payload_containing_resour ] # when - api_key_usage_with_resource = UsageCollector._get_api_key_usage_containing_resource( + api_key_usage_with_resource = get_api_key_usage_containing_resource( api_key_hash="fake_api2_hash", usage_payloads=usage_payloads ) @@ -205,7 +206,7 @@ def test_get_api_key_usage_containing_resource(): ] # when - api_key_usage_with_resource = UsageCollector._get_api_key_usage_containing_resource( + api_key_usage_with_resource = get_api_key_usage_containing_resource( api_key_hash="fake_api2_hash", usage_payloads=usage_payloads ) @@ -303,7 +304,7 @@ def test_zip_usage_payloads(): ] # when - zipped_usage_payloads = UsageCollector._zip_usage_payloads( + zipped_usage_payloads = zip_usage_payloads( usage_payloads=dumped_usage_payloads ) @@ -395,7 +396,7 @@ def test_zip_usage_payloads_with_system_info_missing_resource_id_and_no_resource ] # when - zipped_usage_payloads = UsageCollector._zip_usage_payloads( + zipped_usage_payloads = zip_usage_payloads( usage_payloads=dumped_usage_payloads ) @@ -458,7 +459,7 @@ def test_zip_usage_payloads_with_system_info_missing_resource_id(): ] # when - zipped_usage_payloads = UsageCollector._zip_usage_payloads( + zipped_usage_payloads = zip_usage_payloads( usage_payloads=dumped_usage_payloads ) @@ -513,7 +514,7 @@ def test_zip_usage_payloads_with_system_info_missing_resource_id_and_api_key(): ] # when - zipped_usage_payloads = UsageCollector._zip_usage_payloads( + zipped_usage_payloads = zip_usage_payloads( usage_payloads=dumped_usage_payloads )
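
Note (not part of the patch): the `@usage_collector` decorator added to `BaseInference.infer` declares `usage_fps`, `usage_api_key`, `usage_workflow_id`, `usage_workflow_preview` and `usage_inference_test_run` as keyword-only parameters of its wrapper, so callers such as `default_process_frame` and `run_test_inference` can pass them without the wrapped `infer` ever receiving them. A minimal sketch of that pattern, with a stand-in `record` callback instead of the real `UsageCollector` (all names below are illustrative, not the actual collector API):

from functools import wraps
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


def track_usage(record: Callable[..., None]) -> Callable[[Callable[P, T]], Callable[P, T]]:
    # `record` stands in for UsageCollector.record_usage
    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @wraps(func)
        def wrapper(
            *args: P.args,
            usage_fps: float = 0,
            usage_api_key: str = "",
            usage_inference_test_run: bool = False,
            **kwargs: P.kwargs,
        ) -> T:
            # usage_* kwargs are consumed here and never forwarded to `func`
            record(
                fps=usage_fps,
                api_key=usage_api_key,
                inference_test_run=usage_inference_test_run,
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@track_usage(record=lambda **kw: print("usage:", kw))
def infer(image: str, confidence: float = 0.5) -> dict:
    return {"image": image, "confidence": confidence}


# the usage_* kwargs are stripped by the wrapper, mirroring model.infer(..., usage_fps=...)
print(infer("frame-0", confidence=0.4, usage_fps=30, usage_inference_test_run=True))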
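
Note (not part of the patch): the merging logic extracted into `payload_helpers` keeps the semantics of the old private methods: start/stop timestamps are min-/max-ed, frame counters and source durations are summed, and merging payloads for different `resource_id`s raises. A quick illustration against the new module (values are made up):

from inference.usage_tracking.payload_helpers import merge_usage_dicts

a = {
    "resource_id": "workspace/workflow-1",
    "timestamp_start": 100,
    "timestamp_stop": 200,
    "processed_frames": 10,
    "source_duration": 2.0,
}
b = {
    "resource_id": "workspace/workflow-1",
    "timestamp_start": 150,
    "timestamp_stop": 400,
    "processed_frames": 5,
    "source_duration": 1.0,
}

merged = merge_usage_dicts(a, b)
assert merged["timestamp_start"] == 100   # earliest start wins
assert merged["timestamp_stop"] == 400    # latest stop wins
assert merged["processed_frames"] == 15   # counters are summed
assert merged["source_duration"] == 3.0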
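
Note (not part of the patch): the new `RedisQueue` wraps every payload key in the `{UsageCollector}` hash tag and indexes it in a sorted set precisely so that an external service can drain records with multi-key commands without CROSSSLOT errors on Redis Cluster. That consumer is outside this diff; a rough sketch of what it could look like, assuming redis-py and the key layout from `redis_queue.py`:

import json

import redis  # redis-py client, assumed available

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

# RedisQueue.put registers each payload key in the "UsageCollector" sorted set
# with its insertion time as the score, so keys can be read in order.
keys = r.zrangebyscore("UsageCollector", min="-inf", max="+inf")
if keys:
    # all payload keys share the {UsageCollector} hash tag, so MGET hits a single slot
    raw_payloads = r.mget(keys)
    payloads = [json.loads(value) for value in raw_payloads if value is not None]
    # drop what was consumed
    r.delete(*keys)
    r.zrem("UsageCollector", *keys)
    print(f"Drained {len(payloads)} usage payload(s)")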