Skip to content

Commit

Permalink
Remove kvdriver argument (mis-append from previous CL).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 727049973
  • Loading branch information
timblakely authored and copybara-github committed Feb 14, 2025
1 parent afb5a09 commit fc2ae79
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 13 deletions.
10 changes: 2 additions & 8 deletions connectomics/common/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def __post_init__(self):
def dataclass_from_serialized(
target: Type[T],
serialized: Union[str, PathLike],
kvdriver: str = 'file',
infer_missing_fields: bool = False,
) -> T:
"""Load a dataclass from a serialized instance, file path, or dict.
Expand All @@ -228,7 +227,6 @@ def dataclass_from_serialized(
target: Dataclass to load
serialized: Serialized instance, file path, or dict to create dataclass
from.
kvdriver: Driver to use for loading.
infer_missing_fields: Whether to infer missing fields.
Returns:
Expand All @@ -250,15 +248,13 @@ def dataclass_from_serialized(
return load_dataclass_json(
target,
as_str,
kvdriver=kvdriver,
infer_missing_fields=infer_missing_fields,
)


def load_json(
json_or_path: str | PathLike,
json_path: str | None = None,
kvdriver: str = 'file',
) -> dict[str, Any]:
"""Load a JSON object from a string or file path via TensorStore."""
try:
Expand All @@ -270,7 +266,7 @@ def load_json(
path = json_or_path
spec = {
'driver': 'json',
'kvstore': {'driver': kvdriver, 'path': str(path)},
'kvstore': str(path),
}
if json_path is not None:
if not json_path.startswith('/'):
Expand All @@ -283,7 +279,6 @@ def load_dataclass_json(
dataclass_type: Type[T],
path: PathLike,
json_path: str | None = None,
kvdriver: str = 'file',
infer_missing_fields: bool = False,
) -> T:
"""Load a dataclass from a file path.
Expand All @@ -292,14 +287,13 @@ def load_dataclass_json(
dataclass_type: Dataclass to load
path: Path to load from.
json_path: Optional path to load from within the file.
kvdriver: Driver to use for loading.
infer_missing_fields: Whether to infer missing fields.
Returns:
New dataclass instance.
"""
return dataclass_type.from_dict(
load_json(path, json_path, kvdriver),
load_json(path, json_path),
infer_missing=infer_missing_fields,
)

Expand Down
6 changes: 3 additions & 3 deletions connectomics/common/file_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,20 @@ def test_load_dataclass_json(self):
a = TestDataClass(a=1, b='foo', c=1.0)
fname = os.path.join(FLAGS.test_tmpdir, 'dc_file')
file.save_dataclass_json(a, fname)
b = file.load_dataclass_json(TestDataClass, fname)
b = file.load_dataclass_json(TestDataClass, f'file://{fname}')
self.assertEqual(a, b)

a = AnotherTestDataClass(a=1, b='foo', c=1.0, inner=a)
file.save_dataclass_json(a, fname)

inner = file.load_dataclass_json(TestDataClass, fname, '/inner')
inner = file.load_dataclass_json(TestDataClass, f'file://{fname}', '/inner')
self.assertEqual(inner, a.inner)

def test_dataclass_from_serialized(self):
a = TestDataClass(a=1, b='foo', c=1.0)
fname = os.path.join(FLAGS.test_tmpdir, 'dc_file')
file.save_dataclass_json(a, fname)
b = file.dataclass_from_serialized(TestDataClass, fname)
b = file.dataclass_from_serialized(TestDataClass, f'file://{fname}')
self.assertEqual(a, b)
c = file.dataclass_from_serialized(TestDataClass, a.to_json())
self.assertEqual(a, c)
Expand Down
3 changes: 1 addition & 2 deletions connectomics/volume/subvolume_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def default_config(
config_type: DefaultConfigType | None = None,
overrides: file.PathLike | dict[str, Any] | None = None,
fallback_to_em_2d: bool = True,
kvdriver: str = 'file',
) -> T:
"""Returns a default configuration for a given config type and class."""
if overrides and not isinstance(overrides, dict):
Expand All @@ -381,7 +380,7 @@ def default_config(
if not overrides:
overrides = None
if isinstance(overrides, file.PathLike):
overrides = file.load_json(overrides, kvdriver=kvdriver)
overrides = file.load_json(overrides)
if config_type is None and fallback_to_em_2d:
logging.warning('No default config type specified, falling back to EM_2D.')
config_type = DefaultConfigType.EM_2D
Expand Down

0 comments on commit fc2ae79

Please sign in to comment.