Skip to content

Commit

Permalink
Merge pull request #632 from roboflow/keep-distinct-exec-sessions-for…
Browse files Browse the repository at this point in the history
…-inf-pipeline-usage

Keep distinct exec sessions for inf pipeline usage
  • Loading branch information
grzegorz-roboflow authored Sep 5, 2024
2 parents 04642f7 + 8e1ff4d commit ad91156
Show file tree
Hide file tree
Showing 5 changed files with 334 additions and 35 deletions.
28 changes: 7 additions & 21 deletions inference/usage_tracking/collector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import atexit
import hashlib
import json
import mimetypes
import socket
Expand All @@ -10,19 +9,10 @@
from functools import wraps
from queue import Queue
from threading import Event, Lock, Thread
from typing import Tuple
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, TypeVar
from uuid import uuid4

from typing_extensions import (
Any,
Callable,
DefaultDict,
Dict,
List,
Optional,
ParamSpec,
TypeVar,
)
from typing_extensions import ParamSpec

from inference.core.env import API_KEY, LAMBDA, REDIS_HOST
from inference.core.logger import logger
Expand All @@ -43,6 +33,7 @@
SystemDetails,
UsagePayload,
send_usage_payload,
sha256_hash,
zip_usage_payloads,
)
from .redis_queue import RedisQueue
Expand Down Expand Up @@ -167,11 +158,6 @@ def _dump_usage_queue_with_lock(self) -> List[APIKeyUsage]:
usage_payloads = self._dump_usage_queue_no_lock()
return usage_payloads

@staticmethod
def _hash(payload: str, length=5):
payload_hash = hashlib.sha256(payload.encode())
return payload_hash.hexdigest()[:length]

def _calculate_api_key_hash(self, api_key: APIKey) -> APIKeyHash:
api_key_hash = ""
if not api_key:
Expand All @@ -180,15 +166,15 @@ def _calculate_api_key_hash(self, api_key: APIKey) -> APIKeyHash:
api_key_hash = self._hashed_api_keys.get(api_key)
if not api_key_hash:
if self._api_keys_hashing_enabled:
api_key_hash = UsageCollector._hash(api_key)
api_key_hash = sha256_hash(api_key)
else:
api_key_hash = api_key
self._hashed_api_keys[api_key] = api_key_hash
return api_key_hash

@staticmethod
def _calculate_resource_hash(resource_details: Dict[str, Any]) -> str:
return UsageCollector._hash(json.dumps(resource_details, sort_keys=True))
return sha256_hash(json.dumps(resource_details, sort_keys=True))

def _enqueue_payload(self, payload: UsagePayload):
logger.debug("Enqueuing usage payload %s", payload)
Expand Down Expand Up @@ -235,7 +221,7 @@ def system_info(
ip_address: Optional[str] = None,
) -> SystemDetails:
if ip_address:
ip_address_hash_hex = UsageCollector._hash(ip_address)
ip_address_hash_hex = sha256_hash(ip_address)
else:
try:
ip_address: str = socket.gethostbyname(socket.gethostname())
Expand All @@ -251,7 +237,7 @@ def system_info(
if s:
s.close()

ip_address_hash_hex = UsageCollector._hash(ip_address)
ip_address_hash_hex = sha256_hash(ip_address)

return {
"ip_address_hash": ip_address_hash_hex,
Expand Down
62 changes: 53 additions & 9 deletions inference/usage_tracking/payload_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
from typing import Any, DefaultDict, Dict, List, Optional, Set, Union

import requests
Expand Down Expand Up @@ -48,8 +49,10 @@ def get_api_key_usage_containing_resource(


def zip_usage_payloads(usage_payloads: List[APIKeyUsage]) -> List[APIKeyUsage]:
merged_api_key_usage_payloads: APIKeyUsage = {}
system_info_payload = None
usage_by_exec_session_id: Dict[
APIKeyHash, Dict[ResourceID, Dict[str, List[ResourceUsage]]]
] = {}
for usage_payload in usage_payloads:
for api_key_hash, resource_payloads in usage_payload.items():
if api_key_hash == "":
Expand All @@ -75,6 +78,9 @@ def zip_usage_payloads(usage_payloads: List[APIKeyUsage]) -> List[APIKeyUsage]:
v["resource_id"] = resource_id
if "category" not in v or not v["category"]:
v["category"] = category
api_key_usage_by_exec_session_id = usage_by_exec_session_id.setdefault(
api_key_hash, {}
)
for (
resource_usage_key,
resource_usage_payload,
Expand All @@ -93,18 +99,51 @@ def zip_usage_payloads(usage_payloads: List[APIKeyUsage]) -> List[APIKeyUsage]:
resource_usage_payload["api_key_hash"] = api_key_hash
resource_usage_payload["resource_id"] = resource_id
resource_usage_payload["category"] = category
merged_api_key_payload = merged_api_key_usage_payloads.setdefault(
api_key_hash, {}

resource_usage_exec_session_id = (
api_key_usage_by_exec_session_id.setdefault(resource_usage_key, {})
)
merged_resource_payload = merged_api_key_payload.setdefault(
resource_usage_key, {}
if not resource_usage_payload.get("fps"):
resource_usage_exec_session_id.setdefault("", []).append(
resource_usage_payload
)
continue
exec_session_id = resource_usage_payload.get("exec_session_id", "")
resource_usage_exec_session_id.setdefault(exec_session_id, []).append(
resource_usage_payload
)

merged_exec_session_id_usage_payloads: Dict[str, APIKeyUsage] = {}
for (
api_key_hash,
api_key_usage_by_exec_session_id,
) in usage_by_exec_session_id.items():
for (
resource_usage_key,
resource_usage_exec_session_id,
) in api_key_usage_by_exec_session_id.items():
for (
exec_session_id,
usage_payloads,
) in resource_usage_exec_session_id.items():
merged_api_key_usage_payloads = (
merged_exec_session_id_usage_payloads.setdefault(
exec_session_id, {}
)
)
merged_api_key_payload[resource_usage_key] = merge_usage_dicts(
merged_resource_payload,
resource_usage_payload,
merged_api_key_payload = merged_api_key_usage_payloads.setdefault(
api_key_hash, {}
)
for resource_usage_payload in usage_payloads:
merged_resource_payload = merged_api_key_payload.setdefault(
resource_usage_key, {}
)
merged_api_key_payload[resource_usage_key] = merge_usage_dicts(
merged_resource_payload,
resource_usage_payload,
)

zipped_payloads = [merged_api_key_usage_payloads]
zipped_payloads = list(merged_exec_session_id_usage_payloads.values())
if system_info_payload:
system_info_api_key_hash = next(iter(system_info_payload.values()))[
"api_key_hash"
Expand Down Expand Up @@ -151,3 +190,8 @@ def send_usage_payload(
api_keys_hashes_failed.add(api_key_hash)
continue
return api_keys_hashes_failed


def sha256_hash(payload: str, length=5):
    """Return the first *length* hex characters of the SHA-256 digest of *payload*.

    Used to anonymize identifiers (API keys, IP addresses, resource details)
    before they are included in usage payloads; the short prefix keeps the
    payload compact while remaining stable for a given input.
    """
    digest_hex = hashlib.sha256(payload.encode()).hexdigest()
    return digest_hex[:length]
3 changes: 1 addition & 2 deletions inference/usage_tracking/redis_queue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
import time
from threading import Lock
from typing import Any, Dict, List, Optional
from uuid import uuid4

from typing_extensions import Any, Dict, List, Optional

from inference.core.cache import cache
from inference.core.cache.redis import RedisCache
from inference.core.logger import logger
Expand Down
3 changes: 1 addition & 2 deletions inference/usage_tracking/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import inspect

from typing_extensions import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable

from inference.core.logger import logger

Expand Down
Loading

0 comments on commit ad91156

Please sign in to comment.