diff --git a/python/pyarrow/src/arrow/python/numpy_to_arrow.cc b/python/pyarrow/src/arrow/python/numpy_to_arrow.cc index 460b1d0ce3fa6..afd781f7efe07 100644 --- a/python/pyarrow/src/arrow/python/numpy_to_arrow.cc +++ b/python/pyarrow/src/arrow/python/numpy_to_arrow.cc @@ -224,9 +224,11 @@ class NumPyConverter { // NumPy ascii string arrays Status Visit(const BinaryType& type); + Status Visit(const LargeBinaryType& type); // NumPy unicode arrays Status Visit(const StringType& type); + Status Visit(const LargeStringType& type); Status Visit(const StructType& type); @@ -284,6 +286,12 @@ class NumPyConverter { return PushArray(arr_data); } + template + Status VisitBinary(T* builder); + + template + Status VisitString(T* builder); + Status TypeNotImplemented(std::string type_name) { return Status::NotImplemented("NumPyConverter doesn't implement <", type_name, "> conversion. "); @@ -553,24 +561,23 @@ inline Status NumPyConverter::ConvertData(std::shared_ptr* d // Create 16MB chunks for binary data constexpr int32_t kBinaryChunksize = 1 << 24; -Status NumPyConverter::Visit(const BinaryType& type) { - ::arrow::internal::ChunkedBinaryBuilder builder(kBinaryChunksize, pool_); - +template +Status NumPyConverter::VisitBinary(T* builder) { auto data = reinterpret_cast(PyArray_DATA(arr_)); - auto AppendNotNull = [&builder, this](const uint8_t* data) { + auto AppendNotNull = [builder, this](const uint8_t* data) { // This is annoying. NumPy allows strings to have nul-terminators, so // we must check for them here const size_t item_size = strnlen(reinterpret_cast(data), static_cast(itemsize_)); - return builder.Append(data, static_cast(item_size)); + return builder->Append(data, static_cast(item_size)); }; if (mask_ != nullptr) { Ndarray1DIndexer mask_values(mask_); for (int64_t i = 0; i < length_; ++i) { if (mask_values[i]) { - RETURN_NOT_OK(builder.AppendNull()); + RETURN_NOT_OK(builder->AppendNull()); } else { RETURN_NOT_OK(AppendNotNull(data)); } @@ -583,6 +590,14 @@ Status NumPyConverter::Visit(const BinaryType& type) { } } + return Status::OK(); +} + +Status NumPyConverter::Visit(const BinaryType& type) { + ::arrow::internal::ChunkedBinaryBuilder builder(kBinaryChunksize, pool_); + + RETURN_NOT_OK(VisitBinary(&builder)); + ArrayVector result; RETURN_NOT_OK(builder.Finish(&result)); for (auto arr : result) { @@ -591,6 +606,16 @@ Status NumPyConverter::Visit(const BinaryType& type) { return Status::OK(); } +Status NumPyConverter::Visit(const LargeBinaryType& type) { + ::arrow::LargeBinaryBuilder builder(pool_); + + RETURN_NOT_OK(VisitBinary(&builder)); + + std::shared_ptr result; + RETURN_NOT_OK(builder.Finish(&result)); + return PushArray(result->data()); +} + Status NumPyConverter::Visit(const FixedSizeBinaryType& type) { auto byte_width = type.byte_width(); @@ -630,8 +655,8 @@ namespace { // NumPy unicode is UCS4/UTF32 always constexpr int kNumPyUnicodeSize = 4; -Status AppendUTF32(const char* data, int64_t itemsize, int byteorder, - ::arrow::internal::ChunkedStringBuilder* builder) { +template +Status AppendUTF32(const char* data, int64_t itemsize, int byteorder, T* builder) { // The binary \x00\x00\x00\x00 indicates a nul terminator in NumPy unicode, // so we need to detect that here to truncate if necessary. Yep. Py_ssize_t actual_length = 0; @@ -659,11 +684,8 @@ Status AppendUTF32(const char* data, int64_t itemsize, int byteorder, } // namespace -Status NumPyConverter::Visit(const StringType& type) { - util::InitializeUTF8(); - - ::arrow::internal::ChunkedStringBuilder builder(kBinaryChunksize, pool_); - +template +Status NumPyConverter::VisitString(T* builder) { auto data = reinterpret_cast(PyArray_DATA(arr_)); char numpy_byteorder = dtype_->byteorder; @@ -707,7 +729,7 @@ Status NumPyConverter::Visit(const StringType& type) { auto AppendNonNullValue = [&](const uint8_t* data) { if (is_binary_type) { if (ARROW_PREDICT_TRUE(util::ValidateUTF8(data, itemsize_))) { - return builder.Append(data, static_cast(itemsize_)); + return builder->Append(data, static_cast(itemsize_)); } else { return Status::Invalid("Encountered non-UTF8 binary value: ", HexEncode(data, itemsize_)); @@ -715,7 +737,7 @@ Status NumPyConverter::Visit(const StringType& type) { } else { // is_unicode_type case return AppendUTF32(reinterpret_cast(data), itemsize_, byteorder, - &builder); + builder); } }; @@ -723,7 +745,7 @@ Status NumPyConverter::Visit(const StringType& type) { Ndarray1DIndexer mask_values(mask_); for (int64_t i = 0; i < length_; ++i) { if (mask_values[i]) { - RETURN_NOT_OK(builder.AppendNull()); + RETURN_NOT_OK(builder->AppendNull()); } else { RETURN_NOT_OK(AppendNonNullValue(data)); } @@ -736,6 +758,16 @@ Status NumPyConverter::Visit(const StringType& type) { } } + return Status::OK(); +} + +Status NumPyConverter::Visit(const StringType& type) { + util::InitializeUTF8(); + + ::arrow::internal::ChunkedStringBuilder builder(kBinaryChunksize, pool_); + + RETURN_NOT_OK(VisitString(&builder)); + ArrayVector result; RETURN_NOT_OK(builder.Finish(&result)); for (auto arr : result) { @@ -744,6 +776,19 @@ Status NumPyConverter::Visit(const StringType& type) { return Status::OK(); } +Status NumPyConverter::Visit(const LargeStringType& type) { + util::InitializeUTF8(); + + ::arrow::LargeStringBuilder builder(pool_); + + RETURN_NOT_OK(VisitString(&builder)); + + std::shared_ptr result; + RETURN_NOT_OK(builder.Finish(&result)); + RETURN_NOT_OK(PushArray(result->data())); + return Status::OK(); +} + Status NumPyConverter::Visit(const StructType& type) { std::vector sub_converters; std::vector sub_arrays; diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index b89e0ace157af..78982628ea80f 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -2355,32 +2355,36 @@ def test_array_from_numpy_timedelta_incorrect_unit(): pa.array(data) -def test_array_from_numpy_ascii(): +@pytest.mark.parametrize('binary_type', [None, pa.binary(), pa.large_binary()]) +def test_array_from_numpy_ascii(binary_type): + # Default when no type is specified should be binary + expected_type = binary_type or pa.binary() + arr = np.array(['abcde', 'abc', ''], dtype='|S5') - arrow_arr = pa.array(arr) - assert arrow_arr.type == 'binary' - expected = pa.array(['abcde', 'abc', ''], type='binary') + arrow_arr = pa.array(arr, binary_type) + assert arrow_arr.type == expected_type + expected = pa.array(['abcde', 'abc', ''], type=expected_type) assert arrow_arr.equals(expected) mask = np.array([False, True, False]) - arrow_arr = pa.array(arr, mask=mask) - expected = pa.array(['abcde', None, ''], type='binary') + arrow_arr = pa.array(arr, binary_type, mask=mask) + expected = pa.array(['abcde', None, ''], type=expected_type) assert arrow_arr.equals(expected) # Strided variant arr = np.array(['abcde', 'abc', ''] * 5, dtype='|S5')[::2] mask = np.array([False, True, False] * 5)[::2] - arrow_arr = pa.array(arr, mask=mask) + arrow_arr = pa.array(arr, binary_type, mask=mask) expected = pa.array(['abcde', '', None, 'abcde', '', None, 'abcde', ''], - type='binary') + type=expected_type) assert arrow_arr.equals(expected) # 0 itemsize arr = np.array(['', '', ''], dtype='|S0') - arrow_arr = pa.array(arr) - expected = pa.array(['', '', ''], type='binary') + arrow_arr = pa.array(arr, binary_type) + expected = pa.array(['', '', ''], type=expected_type) assert arrow_arr.equals(expected) @@ -2499,35 +2503,39 @@ def test_interval_array_from_dateoffset(): assert list(actual_list[0]) == expected_from_pandas -def test_array_from_numpy_unicode(): +@pytest.mark.parametrize('string_type', [None, pa.utf8(), pa.large_utf8()]) +def test_array_from_numpy_unicode(string_type): + # Default when no type is specified should be utf8 + expected_type = string_type or pa.utf8() + dtypes = ['U5'] for dtype in dtypes: arr = np.array(['abcde', 'abc', ''], dtype=dtype) - arrow_arr = pa.array(arr) - assert arrow_arr.type == 'utf8' - expected = pa.array(['abcde', 'abc', ''], type='utf8') + arrow_arr = pa.array(arr, string_type) + assert arrow_arr.type == expected_type + expected = pa.array(['abcde', 'abc', ''], type=expected_type) assert arrow_arr.equals(expected) mask = np.array([False, True, False]) - arrow_arr = pa.array(arr, mask=mask) - expected = pa.array(['abcde', None, ''], type='utf8') + arrow_arr = pa.array(arr, string_type, mask=mask) + expected = pa.array(['abcde', None, ''], type=expected_type) assert arrow_arr.equals(expected) # Strided variant arr = np.array(['abcde', 'abc', ''] * 5, dtype=dtype)[::2] mask = np.array([False, True, False] * 5)[::2] - arrow_arr = pa.array(arr, mask=mask) + arrow_arr = pa.array(arr, string_type, mask=mask) expected = pa.array(['abcde', '', None, 'abcde', '', None, - 'abcde', ''], type='utf8') + 'abcde', ''], type=expected_type) assert arrow_arr.equals(expected) # 0 itemsize arr = np.array(['', '', ''], dtype='