diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc index c32a6ef6de93e..2a54d28c6fb64 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_extension.cc @@ -56,8 +56,9 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou std::shared_ptr GetCastToExtension(std::string name) { auto func = std::make_shared(std::move(name), Type::EXTENSION); for (Type::type in_ty : AllTypeIds()) { - DCHECK_OK( - func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, CastToExtension)); + DCHECK_OK(func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, + CastToExtension, NullHandling::COMPUTED_NO_PREALLOCATE, + MemAllocation::NO_PREALLOCATE)); } return func; } diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 9863d96058947..1c4d0175a2d97 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -194,6 +194,21 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized): return cls(storage_type) +class MyFixedListType(pa.ExtensionType): + + def __init__(self, storage_type): + assert isinstance(storage_type, pa.FixedSizeListType) + super().__init__(storage_type, 'pyarrow.tests.MyFixedListType') + + def __arrow_ext_serialize__(self): + return b'' + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + assert serialized == b'' + return cls(storage_type) + + class AnnotatedType(pa.ExtensionType): """ Generic extension type that can store any storage type. @@ -738,6 +753,36 @@ def test_casting_dict_array_to_extension_type(): UUID('30313233-3435-3637-3839-616263646566')] +def test_cast_to_extension_with_nested_storage(): + # https://github.com/apache/arrow/issues/37669 + + # With fixed-size list + array = pa.array([[1, 2], [3, 4], [5, 6]], pa.list_(pa.float64(), 2)) + result = array.cast(MyFixedListType(pa.list_(pa.float64(), 2))) + expected = pa.ExtensionArray.from_storage(MyFixedListType(array.type), array) + assert result.equals(expected) + + ext_type = MyFixedListType(pa.list_(pa.float32(), 2)) + result = array.cast(ext_type) + expected = pa.ExtensionArray.from_storage( + ext_type, array.cast(ext_type.storage_type) + ) + assert result.equals(expected) + + # With variable-size list + array = pa.array([[1, 2], [3], [4, 5, 6]], pa.list_(pa.float64())) + result = array.cast(MyListType(pa.list_(pa.float64()))) + expected = pa.ExtensionArray.from_storage(MyListType(array.type), array) + assert result.equals(expected) + + ext_type = MyListType(pa.list_(pa.float32())) + result = array.cast(ext_type) + expected = pa.ExtensionArray.from_storage( + ext_type, array.cast(ext_type.storage_type) + ) + assert result.equals(expected) + + def test_concat(): arr1 = pa.array([1, 2, 3], IntegerType()) arr2 = pa.array([4, 5, 6], IntegerType()) @@ -1500,6 +1545,21 @@ def test_tensor_type_equality(): assert not tensor_type == tensor_type3 +def test_tensor_type_cast(): + tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3]) + inner = pa.array(range(18), pa.int8()) + storage = pa.FixedSizeListArray.from_arrays(inner, 6) + + # cast storage -> extension type + result = storage.cast(tensor_type) + expected = pa.ExtensionArray.from_storage(tensor_type, storage) + assert result.equals(expected) + + # cast extension type -> storage type + storage_result = result.cast(storage.type) + assert storage_result.equals(storage) + + @pytest.mark.pandas def test_extension_to_pandas_storage_type(registered_period_type): period_type, _ = registered_period_type