Skip to content

Commit

Permalink
Merge branch 'tom/fix/dtype-str-special-case' into user/tom/feature/c…
Browse files Browse the repository at this point in the history
…onsolidated-metadata
  • Loading branch information
TomAugspurger committed Oct 9, 2024
2 parents 19b9271 + 7e76e9e commit 418bc6b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
4 changes: 3 additions & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ShapeLike,
ZarrFormat,
concurrent_map,
parse_dtype,
parse_shapelike,
product,
)
Expand Down Expand Up @@ -232,7 +233,8 @@ async def create(
if chunks is not None and chunk_shape is not None:
raise ValueError("Only one of chunk_shape or chunks can be provided.")

dtype = np.dtype(dtype)
dtype = parse_dtype(dtype, zarr_format)
# dtype = np.dtype(dtype)
if chunks:
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
else:
Expand Down
14 changes: 14 additions & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
overload,
)

import numpy as np

from zarr.core.strings import _STRING_DTYPE

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

Expand Down Expand Up @@ -163,3 +167,13 @@ def parse_order(data: Any) -> Literal["C", "F"]:
if data in ("C", "F"):
return cast(Literal["C", "F"], data)
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
if dtype is str or dtype == "str":
if zarr_format == 2:
# special case as object
return np.dtype("object")
else:
return _STRING_DTYPE
return np.dtype(dtype)
11 changes: 3 additions & 8 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from zarr.core.array_spec import ArraySpec
from zarr.core.chunk_grids import RegularChunkGrid
from zarr.core.chunk_key_encodings import parse_separator
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
from zarr.core.config import config, parse_indexing_order
from zarr.core.metadata.common import ArrayMetadata, parse_attributes

Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(
Metadata for a Zarr version 2 array.
"""
shape_parsed = parse_shapelike(shape)
data_type_parsed = parse_dtype(dtype)
data_type_parsed = parse_dtype(dtype, zarr_format=2)
chunks_parsed = parse_shapelike(chunks)
compressor_parsed = parse_compressor(compressor)
order_parsed = parse_indexing_order(order)
Expand Down Expand Up @@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
_data = data.copy()
# check that the zarr_format attribute is correct
_ = parse_zarr_format(_data.pop("zarr_format"))
dtype = parse_dtype(_data["dtype"])
dtype = parse_dtype(_data["dtype"], zarr_format=2)

if dtype.kind in "SV":
fill_value_encoded = _data.get("fill_value")
Expand Down Expand Up @@ -201,11 +201,6 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
return replace(self, attributes=attributes)


def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
# todo: real validation
return np.dtype(data)


def parse_zarr_format(data: object) -> Literal[2]:
if data == 2:
return 2
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
from zarr.storage.common import StorePath

numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType]
expected_zarr_string_dtype: np.dtype[Any]
if _NUMPY_SUPPORTS_VLEN_STRING:
numpy_str_dtypes.append(np.dtypes.StringDType)
Expand Down
8 changes: 8 additions & 0 deletions tests/v3/test_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from collections.abc import Iterator
from typing import Any

import numpy as np
import pytest
Expand Down Expand Up @@ -84,3 +85,10 @@ async def test_v2_encode_decode(dtype):
data = zarr.open_array(store=store, path="foo")[:]
expected = np.full((3,), b"X", dtype=dtype)
np.testing.assert_equal(data, expected)


@pytest.mark.parametrize("dtype", [str, "str"])
async def test_create_dtype_str(dtype: Any) -> None:
arr = zarr.create(shape=10, dtype=dtype, zarr_format=2)
assert arr.dtype.kind == "O"
assert arr.metadata.to_dict()["dtype"] == "|O"

0 comments on commit 418bc6b

Please sign in to comment.