Skip to content

Commit

Permalink
add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
brokkoli71 committed Nov 6, 2024
1 parent 955f06c commit a7f8bbf
Showing 1 changed file with 158 additions and 11 deletions.
169 changes: 158 additions & 11 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,23 @@


def parse_zarr_format(data: Any) -> ZarrFormat:
"""Parse the zarr_format field from metadata."""
if data in (2, 3):
return cast(Literal[2, 3], data)
msg = f"Invalid zarr_format. Expected one of 2 or 3. Got {data}."
raise ValueError(msg)


def parse_node_type(data: Any) -> NodeType:
"""Parse the node_type field from metadata."""
if data in ("array", "group"):
return cast(Literal["array", "group"], data)
raise MetadataValidationError("node_type", "array or group", data)


# todo: convert None to empty dict
def parse_attributes(data: Any) -> dict[str, Any]:
"""Parse the attributes field from metadata."""
if data is None:
return {}
elif isinstance(data, dict) and all(isinstance(k, str) for k in data):
Expand All @@ -88,9 +91,7 @@ def _parse_async_node(node: AsyncGroup) -> Group: ...
def _parse_async_node(
node: AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup,
) -> Array | Group:
"""
Wrap an AsyncArray in an Array, or an AsyncGroup in a Group.
"""
"""Wrap an AsyncArray in an Array, or an AsyncGroup in a Group."""
if isinstance(node, AsyncArray):
return Array(node)
elif isinstance(node, AsyncGroup):
Expand Down Expand Up @@ -297,6 +298,10 @@ def flatten(

@dataclass(frozen=True)
class GroupMetadata(Metadata):
"""
Metadata for a Group.
"""

attributes: dict[str, Any] = field(default_factory=dict)
zarr_format: ZarrFormat = 3
consolidated_metadata: ConsolidatedMetadata | None = None
Expand Down Expand Up @@ -391,6 +396,10 @@ def to_dict(self) -> dict[str, Any]:

@dataclass(frozen=True)
class AsyncGroup:
"""
Asynchronous Group object.
"""

metadata: GroupMetadata
store_path: StorePath

Expand Down Expand Up @@ -620,6 +629,18 @@ async def getitem(
self,
key: str,
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup:
"""
Get a subarray or subgroup from the group.
Parameters
----------
key : str
Array or group name
Returns
-------
AsyncArray or AsyncGroup
"""
store_path = self.store_path / key
logger.debug("key=%s, store_path=%s", key, store_path)

Expand Down Expand Up @@ -725,6 +746,13 @@ def _getitem_consolidated(
return AsyncArray(metadata=metadata, store_path=store_path)

async def delitem(self, key: str) -> None:
"""Delete a group member.
Parameters
----------
key : str
Array or group name
"""
store_path = self.store_path / key
if self.metadata.zarr_format == 3:
await (store_path / ZARR_JSON).delete()
Expand Down Expand Up @@ -834,6 +862,21 @@ async def create_group(
exists_ok: bool = False,
attributes: dict[str, Any] | None = None,
) -> AsyncGroup:
"""Create a sub-group.
Parameters
----------
name : str
Group name.
exists_ok : bool, optional
If True, do not raise an error if the group already exists.
attributes : dict, optional
Group attributes.
Returns
-------
g : AsyncGroup
"""
attributes = attributes or {}
return await type(self).from_store(
self.store_path / name,
Expand Down Expand Up @@ -875,7 +918,17 @@ async def require_group(self, name: str, overwrite: bool = False) -> AsyncGroup:
return grp

async def require_groups(self, *names: str) -> tuple[AsyncGroup, ...]:
"""Convenience method to require multiple groups in a single call."""
"""Convenience method to require multiple groups in a single call.
Parameters
----------
*names : str
Group names.
Returns
-------
Tuple[AsyncGroup, ...]
"""
if not names:
return ()
return tuple(await asyncio.gather(*(self.require_group(name) for name in names)))
Expand Down Expand Up @@ -1083,6 +1136,17 @@ async def require_array(
return ds

async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
"""Update group attributes.
Parameters
----------
new_attributes : dict
New attributes to set on the group.
Returns
-------
self : AsyncGroup
"""
# metadata.attributes is "frozen" so we simply clear and update the dict
self.metadata.attributes.clear()
self.metadata.attributes.update(new_attributes)
Expand Down Expand Up @@ -1241,10 +1305,22 @@ def _members_consolidated(
yield from obj._members_consolidated(max_depth, current_depth + 1, prefix=key)

async def keys(self) -> AsyncGenerator[str, None]:
"""Iterate over member names."""
async for key, _ in self.members():
yield key

async def contains(self, member: str) -> bool:
"""Check if a member exists in the group.
Parameters
----------
member : str
Member name.
Returns
-------
bool
"""
# TODO: this can be made more efficient.
try:
await self.getitem(member)
Expand All @@ -1254,15 +1330,18 @@ async def contains(self, member: str) -> bool:
return True

async def groups(self) -> AsyncGenerator[tuple[str, AsyncGroup], None]:
"""Iterate over subgroups."""
async for name, value in self.members():
if isinstance(value, AsyncGroup):
yield name, value

async def group_keys(self) -> AsyncGenerator[str, None]:
"""Iterate over group names."""
async for key, _ in self.groups():
yield key

async def group_values(self) -> AsyncGenerator[AsyncGroup, None]:
"""Iterate over group values."""
async for _, group in self.groups():
yield group

Expand All @@ -1271,21 +1350,25 @@ async def arrays(
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]], None
]:
"""Iterate over arrays."""
async for key, value in self.members():
if isinstance(value, AsyncArray):
yield key, value

async def array_keys(self) -> AsyncGenerator[str, None]:
"""Iterate over array names."""
async for key, _ in self.arrays():
yield key

async def array_values(
self,
) -> AsyncGenerator[AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata], None]:
"""Iterate over array values."""
async for _, array in self.arrays():
yield array

async def tree(self, expand: bool = False, level: int | None = None) -> Any:
"""Return a nested representation of the group hierarchy."""
raise NotImplementedError

async def empty(
Expand Down Expand Up @@ -1467,7 +1550,12 @@ async def full_like(
return await async_api.full_like(a=data, store=self.store_path, path=name, **kwargs)

async def move(self, source: str, dest: str) -> None:
"""Not implemented"""
"""Move a sub-group or sub-array from one path to another.
Notes
-----
Not implemented
"""
raise NotImplementedError


Expand Down Expand Up @@ -1609,7 +1697,22 @@ def get(self, path: str, default: DefaultT | None = None) -> Array | Group | Def
return default

def __delitem__(self, key: str) -> None:
"""Delete a group member."""
"""Delete a group member.
Parameters
----------
key : str
Group member name.
Examples
--------
>>> import zarr
>>> group = Group.from_store(zarr.storage.MemoryStore(mode="w"))
>>> group.create_array(name="subarray", shape=(10,), chunk_shape=(10,))
>>> del group["subarray"]
>>> "subarray" in group
False
"""
self._sync(self._async_group.delitem(key))

def __iter__(self) -> Iterator[str]:
Expand Down Expand Up @@ -1639,6 +1742,22 @@ def __setitem__(self, key: str, value: Any) -> None:
"""Fastpath for creating a new array.
New arrays will be created using default settings for the array type.
If you need to create an array with custom settings, use the `create_array` method.
Parameters
----------
key : str
Array name.
value : Any
Array data.
Examples
--------
>>> import zarr
>>> group = zarr.group()
>>> group["foo"] = zarr.zeros((10,))
>>> group["foo"]
<Array memory://132270269438272/foo shape=(10,) dtype=float64>
"""
self._sync(self._async_group.setitem(key, value))

Expand All @@ -1647,6 +1766,7 @@ def __repr__(self) -> str:

async def update_attributes_async(self, new_attributes: dict[str, Any]) -> Group:
"""Update the attributes of this group.
Example
-------
>>> import zarr
Expand Down Expand Up @@ -1697,6 +1817,7 @@ def attrs(self) -> Attributes:

@property
def info(self) -> None:
"""Group information."""
raise NotImplementedError

@property
Expand Down Expand Up @@ -1757,6 +1878,7 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group],

def keys(self) -> Generator[str, None]:
"""Return an iterator over group member names.
Examples
--------
>>> import zarr
Expand Down Expand Up @@ -1795,6 +1917,7 @@ def __contains__(self, member: str) -> bool:

def groups(self) -> Generator[tuple[str, Group], None]:
"""Return the sub-groups of this group as a generator of (name, group) pairs.
Example
-------
>>> import zarr
Expand All @@ -1809,6 +1932,7 @@ def groups(self) -> Generator[tuple[str, Group], None]:

def group_keys(self) -> Generator[str, None]:
"""Return an iterator over group member names.
Examples
--------
>>> import zarr
Expand All @@ -1823,6 +1947,7 @@ def group_keys(self) -> Generator[str, None]:

def group_values(self) -> Generator[Group, None]:
"""Return an iterator over group members.
Examples
--------
>>> import zarr
Expand All @@ -1836,8 +1961,8 @@ def group_values(self) -> Generator[Group, None]:
yield group

def arrays(self) -> Generator[tuple[str, Array], None]:
"""
Return the sub-arrays of this group as a generator of (name, array) pairs
"""Return the sub-arrays of this group as a generator of (name, array) pairs
Examples
--------
>>> import zarr
Expand All @@ -1852,6 +1977,7 @@ def arrays(self) -> Generator[tuple[str, Array], None]:

def array_keys(self) -> Generator[str, None]:
"""Return an iterator over group member names.
Examples
--------
>>> import zarr
Expand All @@ -1867,6 +1993,7 @@ def array_keys(self) -> Generator[str, None]:

def array_values(self) -> Generator[Array, None]:
"""Return an iterator over group members.
Examples
--------
>>> import zarr
Expand All @@ -1880,7 +2007,12 @@ def array_values(self) -> Generator[Array, None]:
yield array

def tree(self, expand: bool = False, level: int | None = None) -> Any:
"""Not implemented"""
"""Return a nested representation of the group hierarchy.
Notes
-----
Not implemented
"""
return self._sync(self._async_group.tree(expand=expand, level=level))

def create_group(self, name: str, **kwargs: Any) -> Group:
Expand Down Expand Up @@ -1920,7 +2052,17 @@ def require_group(self, name: str, **kwargs: Any) -> Group:
return Group(self._sync(self._async_group.require_group(name, **kwargs)))

def require_groups(self, *names: str) -> tuple[Group, ...]:
"""Convenience method to require multiple groups in a single call."""
"""Convenience method to require multiple groups in a single call.
Parameters
----------
*names : str
Group names.
Returns
-------
groups : tuple of Groups
"""
return tuple(map(Group, self._sync(self._async_group.require_groups(*names))))

def create(self, *args: Any, **kwargs: Any) -> Array:
Expand Down Expand Up @@ -2259,7 +2401,12 @@ def full_like(self, *, name: str, data: async_api.ArrayLike, **kwargs: Any) -> A
return Array(self._sync(self._async_group.full_like(name=name, data=data, **kwargs)))

def move(self, source: str, dest: str) -> None:
"""Not implemented"""
"""Move a sub-group or sub-array from one path to another.
Notes
-----
Not implemented
"""
return self._sync(self._async_group.move(source, dest))

@deprecated("Use Group.create_array instead.")
Expand Down

0 comments on commit a7f8bbf

Please sign in to comment.