Skip to content

Commit

Permalink
Maybe fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAugspurger committed Nov 5, 2024
1 parent 1cdfd6d commit 7cbc500
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
11 changes: 9 additions & 2 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable

from zarr.core.buffer.core import default_buffer_prototype
from zarr.core.common import concurrent_map
from zarr.core.config import config

if TYPE_CHECKING:
from collections.abc import AsyncGenerator, Iterable
Expand Down Expand Up @@ -453,9 +455,14 @@ async def getsize_prefix(self, prefix: str) -> int:
-----
``getsize_prefix`` is just provided as a potentially faster alternative to
listing all the keys under a prefix calling :meth:`Store.getsize` on each.
In general, ``prefix`` should be the path of an Array or Group in the Store.
Implementations may differ on the behavior when some other ``prefix``
is provided.
"""
keys = [x async for x in self.list_prefix(prefix)]
sizes = await gather(*[self.getsize(key) for key in keys])
keys = ((x,) async for x in self.list_prefix(prefix))
limit = config.get("async.concurrency")
sizes = await concurrent_map(keys, self.getsize, limit=limit)
return sum(sizes)


Expand Down
16 changes: 12 additions & 4 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import functools
import operator
from collections.abc import Iterable, Mapping
from collections.abc import AsyncIterable, Iterable, Mapping
from enum import Enum
from itertools import starmap
from typing import (
Expand Down Expand Up @@ -50,10 +50,15 @@ def product(tup: ChunkCoords) -> int:


async def concurrent_map(
items: Iterable[T], func: Callable[..., Awaitable[V]], limit: int | None = None
items: Iterable[T] | AsyncIterable[T],
func: Callable[..., Awaitable[V]],
limit: int | None = None,
) -> list[V]:
if limit is None:
return await asyncio.gather(*list(starmap(func, items)))
if isinstance(items, AsyncIterable):
return await asyncio.gather(*list(starmap(func, [x async for x in items])))
else:
return await asyncio.gather(*list(starmap(func, items)))

else:
sem = asyncio.Semaphore(limit)
Expand All @@ -62,7 +67,10 @@ async def run(item: tuple[Any]) -> V:
async with sem:
return await func(*item)

return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])
if isinstance(items, AsyncIterable):
return await asyncio.gather(*[asyncio.ensure_future(run(item)) async for item in items])
else:
return await asyncio.gather(*[asyncio.ensure_future(run(item)) for item in items])


E = TypeVar("E", bound=Enum)
Expand Down
13 changes: 12 additions & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,19 @@ async def test_getsize(self, store: S) -> None:
await self.set(store, key, data)

result = await store.getsize(key)
assert result == 10
assert isinstance(result, int)
assert result > 0

async def test_getsize_raises(self, store: S) -> None:
with pytest.raises(FileNotFoundError):
await store.getsize("not-a-real-key")

async def test_getsize_prefix(self, store: S) -> None:
prefix = "array/c/"
for i in range(10):
data = self.buffer_cls.from_bytes(b"0" * 10)
await self.set(store, f"{prefix}/{i}", data)

result = await store.getsize_prefix(prefix)
assert isinstance(result, int)
assert result > 0

0 comments on commit 7cbc500

Please sign in to comment.