Skip to content

Commit

Permalink
Pull out load_json into its own function for use elsewhere.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 693841745
  • Loading branch information
timblakely authored and copybara-github committed Nov 6, 2024
1 parent e40ceca commit 3ddaf6d
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions connectomics/common/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import json
import pathlib
import typing
from typing import Any, Callable, Optional, Type, TypeVar, Union
from typing import Any, Callable, Type, TypeVar, Union

from absl import logging
import dataclasses_json
Expand All @@ -34,7 +34,7 @@
def save_dataclass_json(
dataclass_instance: T,
path: PathLike,
json_path: Optional[str] = None,
json_path: str | None = None,
kvdriver: str = 'file',
):
"""Save a dataclass to a file.
Expand Down Expand Up @@ -96,10 +96,34 @@ def dataclass_from_serialized(
)


def load_json(
json_or_path: str | PathLike,
json_path: str | None = None,
kvdriver: str = 'file',
) -> dict[str, Any]:
"""Load a JSON object from a string or file path via TensorStore."""
try:
return json.loads(json_or_path)
except json.JSONDecodeError:
logging.warning(
'Could not decode %s as JSON, trying to load as a path', json_or_path
)
path = json_or_path
spec = {
'driver': 'json',
'kvstore': {'driver': kvdriver, 'path': str(path)},
}
if json_path is not None:
if not json_path.startswith('/'):
json_path = f'/{json_path}'
spec['json_pointer'] = json_path
return ts.open(spec).result().read().result().item()


def load_dataclass_json(
dataclass_type: Type[T],
path: PathLike,
json_path: Optional[str] = None,
json_path: str | None = None,
kvdriver: str = 'file',
infer_missing_fields: bool = False,
) -> T:
Expand All @@ -115,16 +139,8 @@ def load_dataclass_json(
Returns:
New dataclass instance.
"""
spec = {
'driver': 'json',
'kvstore': {'driver': kvdriver, 'path': str(path)},
}
if json_path is not None:
if not json_path.startswith('/'):
json_path = f'/{json_path}'
spec['json_pointer'] = json_path
return dataclass_type.from_dict(
ts.open(spec).result().read().result().item(),
load_json(path, json_path, kvdriver),
infer_missing=infer_missing_fields,
)

Expand Down

0 comments on commit 3ddaf6d

Please sign in to comment.