Skip to content

Commit

Permalink
Merge pull request #555 from roboflow/cache-workflow-spec-offline
Browse files Browse the repository at this point in the history
Cache workflow specification for offline use
  • Loading branch information
grzegorz-roboflow authored Jul 29, 2024
2 parents 138b395 + 4dda8a3 commit 787c76c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
52 changes: 50 additions & 2 deletions inference/core/roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
VersionID,
WorkspaceID,
)
from inference.core.env import API_BASE_URL
from inference.core.env import API_BASE_URL, MODEL_CACHE_DIR
from inference.core.exceptions import (
MalformedRoboflowAPIResponseError,
MalformedWorkflowResponseError,
Expand All @@ -31,6 +31,7 @@
RoboflowAPIUnsuccessfulRequestError,
WorkspaceLoadError,
)
from inference.core.utils.file_system import sanitize_path_segment
from inference.core.utils.requests import api_key_safe_raise_for_status
from inference.core.utils.url_utils import wrap_url

Expand Down Expand Up @@ -357,6 +358,47 @@ def get_roboflow_labeling_jobs(
return _get_from_url(url=api_url)


def get_workflow_cache_file(workspace_id: WorkspaceID, workflow_id: str):
sanitized_workspace_id = sanitize_path_segment(workspace_id)
sanitized_workflow_id = sanitize_path_segment(workflow_id)
return os.path.join(
MODEL_CACHE_DIR,
"workflow",
sanitized_workspace_id,
f"{sanitized_workflow_id}.json",
)


def cache_workflow_response(
workspace_id: WorkspaceID, workflow_id: str, response: dict
):
workflow_cache_file = get_workflow_cache_file(workspace_id, workflow_id)
workflow_cache_dir = os.path.dirname(workflow_cache_file)
if not os.path.exists(workflow_cache_dir):
os.makedirs(workflow_cache_dir, exist_ok=True)
with open(workflow_cache_file, "w") as f:
json.dump(response, f)


def delete_cached_workflow_response_if_exists(
workspace_id: WorkspaceID, workflow_id: str
) -> None:
workflow_cache_file = get_workflow_cache_file(workspace_id, workflow_id)
if os.path.exists(workflow_cache_file):
os.remove(workflow_cache_file)


def load_cached_workflow_response(workspace_id: WorkspaceID, workflow_id: str) -> dict:
workflow_cache_file = get_workflow_cache_file(workspace_id, workflow_id)
if not os.path.exists(workflow_cache_file):
return None
try:
with open(workflow_cache_file, "r") as f:
return json.load(f)
except:
delete_cached_workflow_response_if_exists(workspace_id, workflow_id)


@wrap_roboflow_api_errors()
def get_workflow_specification(
api_key: str,
Expand All @@ -367,7 +409,13 @@ def get_workflow_specification(
url=f"{API_BASE_URL}/{workspace_id}/workflows/{workflow_id}",
params=[("api_key", api_key)],
)
response = _get_from_url(url=api_url)
try:
response = _get_from_url(url=api_url)
cache_workflow_response(workspace_id, workflow_id, response)
except (requests.exceptions.ConnectionError, ConnectionError) as error:
response = load_cached_workflow_response(workspace_id, workflow_id)
if response is None:
raise error
if "workflow" not in response or "config" not in response["workflow"]:
raise MalformedWorkflowResponseError(
f"Could not find workflow specification in API response"
Expand Down
6 changes: 6 additions & 0 deletions inference/core/utils/file_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os.path
import re
from typing import List, Optional, Union


Expand Down Expand Up @@ -63,3 +64,8 @@ def ensure_parent_dir_exists(path: str) -> None:
def ensure_write_is_allowed(path: str, allow_override: bool) -> None:
if os.path.exists(path) and not allow_override:
raise RuntimeError(f"File {path} exists and override is forbidden.")


def sanitize_path_segment(path_segment: str) -> str:
# Keep only letters, numbers, underscores and dashes
return re.sub(r"[^A-Za-z0-9_-]", "_", path_segment)
30 changes: 30 additions & 0 deletions tests/inference/unit_tests/core/test_roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@
get_roboflow_model_data,
get_roboflow_model_type,
get_roboflow_workspace,
delete_cached_workflow_response_if_exists,
get_workflow_specification,
raise_from_lambda,
register_image_at_roboflow,
wrap_roboflow_api_errors,
)
from inference.core.utils.url_utils import wrap_url
import json


class TestException(Exception):
Expand Down Expand Up @@ -1691,6 +1693,7 @@ def test_get_workflow_specification_when_connection_error_occurs(
get_mock: MagicMock,
) -> None:
# given
delete_cached_workflow_response_if_exists("my_workspace", "some_workflow")
get_mock.side_effect = ConnectionError()

# when
Expand All @@ -1702,6 +1705,33 @@ def test_get_workflow_specification_when_connection_error_occurs(
)


@mock.patch.object(roboflow_api.requests, "get")
def test_get_workflow_specification_when_connection_error_occurs_but_file_is_cached(
get_mock: MagicMock,
) -> None:
# given
delete_cached_workflow_response_if_exists("my_workspace", "some_workflow")

get_mock.return_value = MagicMock(
status_code=200,
json= MagicMock(return_value={"workflow": {"config": json.dumps({"specification": "some"}) }}),
)

_ = get_workflow_specification(
api_key="my_api_key",
workspace_id="my_workspace",
workflow_id="some_workflow",
)

get_mock.side_effect = ConnectionError()

_ = get_workflow_specification(
api_key="my_api_key",
workspace_id="my_workspace",
workflow_id="some_workflow",
)


def test_get_workflow_specification_when_wrong_api_key_used(
requests_mock: Mocker,
) -> None:
Expand Down

0 comments on commit 787c76c

Please sign in to comment.