Skip to content

Commit

Permalink
fixup
Browse files (browse the repository at this point in the history)
  • Loading branch information
TomAugspurger committed Oct 9, 2024
1 parent 483681b commit 7e76e9e
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def create(
if chunks is not None and chunk_shape is not None:
raise ValueError("Only one of chunk_shape or chunks can be provided.")

dtype = parse_dtype(dtype)
dtype = parse_dtype(dtype, zarr_format)
# dtype = np.dtype(dtype)
if chunks:
_chunks = normalize_chunks(chunks, shape, dtype.itemsize)
Expand Down
11 changes: 8 additions & 3 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import numpy as np

from zarr.core.strings import _STRING_DTYPE

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Iterator

Expand Down Expand Up @@ -166,8 +168,11 @@ def parse_order(data: Any) -> Literal["C", "F"]:
raise ValueError(f"Expected one of ('C', 'F'), got {data} instead.")


def parse_dtype(dtype: Any) -> np.dtype[Any]:
def parse_dtype(dtype: Any, zarr_format: ZarrFormat) -> np.dtype[Any]:
if dtype is str or dtype == "str":
# special case as object
return np.dtype("object")
if zarr_format == 2:
# special case as object
return np.dtype("object")
else:
return _STRING_DTYPE
return np.dtype(dtype)
4 changes: 2 additions & 2 deletions src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
Metadata for a Zarr version 2 array.
"""
shape_parsed = parse_shapelike(shape)
data_type_parsed = parse_dtype(dtype)
data_type_parsed = parse_dtype(dtype, zarr_format=2)
chunks_parsed = parse_shapelike(chunks)
compressor_parsed = parse_compressor(compressor)
order_parsed = parse_indexing_order(order)
Expand Down Expand Up @@ -141,7 +141,7 @@ def from_dict(cls, data: dict[str, Any]) -> ArrayV2Metadata:
_data = data.copy()
# check that the zarr_format attribute is correct
_ = parse_zarr_format(_data.pop("zarr_format"))
dtype = parse_dtype(_data["dtype"])
dtype = parse_dtype(_data["dtype"], zarr_format=2)

if dtype.kind in "SV":
fill_value_encoded = _data.get("fill_value")
Expand Down
4 changes: 0 additions & 4 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@ class DataType(Enum):
complex128 = "complex128"
string = "string"
bytes = "bytes"
object = "object"

@property
def byte_count(self) -> None | int:
Expand Down Expand Up @@ -550,7 +549,6 @@ def to_numpy_shortname(self) -> str:
DataType.float64: "f8",
DataType.complex64: "c8",
DataType.complex128: "c16",
DataType.object: "object",
}
return data_type_to_numpy[self]

Expand All @@ -574,8 +572,6 @@ def from_numpy(cls, dtype: np.dtype[Any]) -> DataType:
return DataType.string
elif dtype.kind == "S":
return DataType.bytes
elif dtype.kind == "O":
return DataType.object
dtype_to_data_type = {
"|b1": "bool",
"bool": "bool",
Expand Down
9 changes: 1 addition & 8 deletions tests/v3/test_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pickle
from itertools import accumulate
from typing import Any, Literal
from typing import Literal

import numpy as np
import pytest
Expand Down Expand Up @@ -406,10 +406,3 @@ def test_vlen_errors() -> None:
dtype="<U4",
codecs=[BytesCodec(), VLenBytesCodec()],
)


@pytest.mark.parametrize("zarr_format", [2, 3, None])
@pytest.mark.parametrize("dtype", [str, "str"])
def test_create_dtype_str(dtype: Any, zarr_format: ZarrFormat | None) -> None:
arr = zarr.create(shape=10, dtype=dtype, zarr_format=zarr_format)
assert arr.dtype.kind == "O"
2 changes: 1 addition & 1 deletion tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from zarr.core.strings import _NUMPY_SUPPORTS_VLEN_STRING
from zarr.storage.common import StorePath

numpy_str_dtypes: list[type | None] = [None, str, np.dtypes.StrDType]
numpy_str_dtypes: list[type | str | None] = [None, str, "str", np.dtypes.StrDType]
expected_zarr_string_dtype: np.dtype[Any]
if _NUMPY_SUPPORTS_VLEN_STRING:
numpy_str_dtypes.append(np.dtypes.StringDType)
Expand Down
8 changes: 8 additions & 0 deletions tests/v3/test_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from collections.abc import Iterator
from typing import Any

import numpy as np
import pytest
Expand Down Expand Up @@ -84,3 +85,10 @@ async def test_v2_encode_decode(dtype):
data = zarr.open_array(store=store, path="foo")[:]
expected = np.full((3,), b"X", dtype=dtype)
np.testing.assert_equal(data, expected)


@pytest.mark.parametrize("dtype", [str, "str"])
async def test_create_dtype_str(dtype: Any) -> None:
arr = zarr.create(shape=10, dtype=dtype, zarr_format=2)
assert arr.dtype.kind == "O"
assert arr.metadata.to_dict()["dtype"] == "|O"

0 comments on commit 7e76e9e

Please sign in to comment.