diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h index 9270c4dea3fb6..8daa8bafaaf39 100644 --- a/cpp/src/arrow/buffer.h +++ b/cpp/src/arrow/buffer.h @@ -444,10 +444,27 @@ class ARROW_EXPORT ResizableBuffer : public MutableBuffer { return Reserve(sizeof(T) * new_nb_elements); } + public: + uint8_t* offsetBuffer; + int64_t offsetCapacity; + uint8_t* validityBuffer; + uint8_t* outerValidityBuffer; + protected: - ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {} + ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) { + offsetBuffer = nullptr; + offsetCapacity = 0; + validityBuffer = nullptr; + outerValidityBuffer = nullptr; + + } ResizableBuffer(uint8_t* data, int64_t size, std::shared_ptr mm) - : MutableBuffer(data, size, std::move(mm)) {} + : MutableBuffer(data, size, std::move(mm)) { + offsetBuffer = nullptr; + offsetCapacity = 0; + validityBuffer = nullptr; + outerValidityBuffer = nullptr; + } }; /// \defgroup buffer-allocation-functions Functions for allocating buffers diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 6a92224e9113d..dc0c427f48d23 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -61,9 +61,11 @@ set(SRC_FILES expression_registry.cc exported_funcs_registry.cc filter.cc + array_ops.cc function_ir_builder.cc function_registry.cc function_registry_arithmetic.cc + function_registry_array.cc function_registry_datetime.cc function_registry_hash.cc function_registry_math_ops.cc @@ -249,7 +251,8 @@ add_gandiva_test(internals-test random_generator_holder_test.cc hash_utils_test.cc gdv_function_stubs_test.cc - interval_holder_test.cc) + interval_holder_test.cc + array_ops_test.cc) add_subdirectory(precompiled) add_subdirectory(tests) diff --git a/cpp/src/gandiva/annotator.cc b/cpp/src/gandiva/annotator.cc index b341fdde3a3f4..abd5ba6b1a4bf 100644 --- a/cpp/src/gandiva/annotator.cc +++ b/cpp/src/gandiva/annotator.cc @@ -46,15 +46,27 @@ FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field, bool is_output) { int data_idx = buffer_count_++; int validity_idx = buffer_count_++; int offsets_idx = FieldDescriptor::kInvalidIdx; + int child_offsets_idx = FieldDescriptor::kInvalidIdx; if (arrow::is_binary_like(field->type()->id())) { offsets_idx = buffer_count_++; } + + if (field->type()->id() == arrow::Type::LIST) { + offsets_idx = buffer_count_++; + if (arrow::is_binary_like(field->type()->field(0)->type()->id())) { + child_offsets_idx = buffer_count_++; + } + } int data_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; if (is_output) { data_buffer_ptr_idx = buffer_count_++; } + int child_valid_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; + if (field->type()->id() == arrow::Type::LIST) { + child_valid_buffer_ptr_idx = buffer_count_++; + } return std::make_shared(field, data_idx, validity_idx, offsets_idx, - data_buffer_ptr_idx); + data_buffer_ptr_idx, child_offsets_idx, child_valid_buffer_ptr_idx); } int Annotator::AddHolderPointer(void* holder) { @@ -80,17 +92,76 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, if (desc.HasOffsetsIdx()) { uint8_t* offsets_buf = const_cast(array_data.buffers[buffer_idx]->data()); eval_batch->SetBuffer(desc.offsets_idx(), offsets_buf, array_data.offset); - ++buffer_idx; + + if (desc.HasChildOffsetsIdx()) { + if (is_output) { + // if list field is output field, we should put buffer pointer into eval batch + // for resizing + uint8_t* child_offsets_buf = reinterpret_cast( + array_data.child_data.at(0)->buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_offsets_buf, + array_data.child_data.at(0)->offset); + + uint8_t* child_valid_buf = reinterpret_cast( + array_data.child_data.at(0)->buffers[0].get()); + eval_batch->SetBuffer(desc.child_data_validity_idx(), child_valid_buf, + array_data.child_data.at(0)->offset); + + } else { + // if list field is input field, just put buffer data into eval batch + uint8_t* child_offsets_buf = const_cast( + array_data.child_data.at(0)->buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_offsets_buf, + array_data.child_data.at(0)->offset); + + uint8_t* child_valid_buf = const_cast( + array_data.child_data.at(0)->buffers[0]->data()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_valid_buf, + array_data.child_data.at(0)->offset); + } + } + if (array_data.type->id() != arrow::Type::LIST || + arrow::is_binary_like(array_data.type->field(0)->type()->id())) { + // primitive type list data buffer index is 1 + // binary like type list data buffer index is 2 + ++buffer_idx; + } + } + + int const childDataIndex = 0; + if (array_data.type->id() != arrow::Type::LIST) { + uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); + } else { + uint8_t* data_buf = + const_cast(array_data.child_data.at(childDataIndex)->buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.child_data.at(0)->offset); + + int const childDataBufferIndex = 0; + if (array_data.child_data.at(childDataIndex)->buffers[childDataBufferIndex] ) { + uint8_t* child_valid_buf = const_cast( + array_data.child_data.at(childDataIndex)->buffers[childDataBufferIndex]->data()); + eval_batch->SetBuffer(desc.child_data_validity_idx(), child_valid_buf, 0); + } + } - uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); - eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); if (is_output) { // pass in the Buffer object for output data buffers. Can be used for resizing. - uint8_t* data_buf_ptr = - reinterpret_cast(array_data.buffers[buffer_idx].get()); - eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + + if (array_data.type->id() != arrow::Type::LIST) { + uint8_t* data_buf_ptr = + reinterpret_cast(array_data.buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + } else { + // list data buffer is in child data buffer + uint8_t* data_buf_ptr = reinterpret_cast( + array_data.child_data.at(0)->buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, + array_data.child_data.at(0)->offset); + } } + } EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, @@ -106,7 +177,6 @@ EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, // skip columns not involved in the expression. continue; } - PrepareBuffersForField(*(found->second), *(record_batch.column_data(i)), eval_batch.get(), false /*is_output*/); } diff --git a/cpp/src/gandiva/array_ops.cc b/cpp/src/gandiva/array_ops.cc new file mode 100644 index 0000000000000..7170534342085 --- /dev/null +++ b/cpp/src/gandiva/array_ops.cc @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/array_ops.h" + +#include +#include +#include + +#include "arrow/util/value_parsing.h" + +#include "gandiva/gdv_function_stubs.h" +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" + +/// Stub functions that can be accessed from LLVM or the pre-compiled library. + +template +Type* array_remove_template(int64_t context_ptr, const Type* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + Type remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr) +{ + std::vector newInts; + + const int32_t* entry_validityAdjusted = entry_validity - (loop_var ); + int64_t validityBitIndex = 0; + //The validity index already has the current row length added to it, so decrement. + validityBitIndex = validity_index_var - entry_len; + std::vector outValid; + for (int i = 0; i < entry_len; i++) { + Type entry_item = *(entry_buf + i); + if (remove_data_valid && entry_item == remove_data) { + //Do not add the item to remove. + } else if (!arrow::bit_util::GetBit(reinterpret_cast(entry_validityAdjusted), validityBitIndex + i)) { + outValid.push_back(false); + newInts.push_back(0); + } else { + outValid.push_back(true); + newInts.push_back(entry_item); + } + } + + *out_len = (int)newInts.size(); + + //Since this function can remove values we don't know the length ahead of time. + //A fast way to compute Math.ceil(input / 8.0). + int validByteSize = (unsigned int)((*out_len) + 7) >> 3; + + uint8_t* validRet = gdv_fn_context_arena_malloc(context_ptr, validByteSize); + for (size_t i = 0; i < outValid.size(); i++) { + arrow::bit_util::SetBitTo(validRet, i, outValid[i]); + } + + int32_t outBufferLength = (int)*out_len * sizeof(Type); + //length is number of items, but buffers must account for byte size. + uint8_t* ret = gdv_fn_context_arena_malloc(context_ptr, outBufferLength); + memcpy(ret, newInts.data(), outBufferLength); + *valid_row = true; + + //Return null if the input array is null or the data to remove is null. + if (!combined_row_validity || !remove_data_valid) { + *out_len = 0; + *valid_row = false; //this one is what works for the top level validity. + } + + *valid_ptr = reinterpret_cast(validRet); + return reinterpret_cast(ret); +} + +template +bool array_contains_template(const Type* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + Type contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + if (!combined_row_validity || !contains_data_valid) { + *valid_row = false; + return false; + } + *valid_row = true; + + const int32_t* entry_validityAdjusted = entry_validity - (loop_var ); + int64_t validityBitIndex = validity_index_var - entry_len; + + bool found_null_in_data = false; + for (int i = 0; i < entry_len; i++) { + if (!arrow::bit_util::GetBit(reinterpret_cast(entry_validityAdjusted), validityBitIndex + i)) { + found_null_in_data = true; + continue; + } + Type entry_item = *(entry_buf + i); + if (contains_data_valid && entry_item == contains_data) { + return true; + } + } + //If there is null in the input and the item is not found the result is null. + if (found_null_in_data) { + *valid_row = false; + } + return false; +} + +extern "C" { + +bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + +bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + +bool array_float32_contains_float32(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + +bool array_float64_contains_float64(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + + + +int32_t* array_int32_remove(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr) { + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} + +int64_t* array_int64_remove(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr){ + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} + +float* array_float32_remove(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr){ + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} + + +double* array_float64_remove(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr){ + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} +} + +namespace gandiva { +void ExportedArrayFunctions::AddMappings(Engine* engine) const { + std::vector args; + auto types = engine->types(); + + args = {types->i64_type(), // int64_t execution_context + types->i64_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i32_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_int32_contains_int32", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_int32_contains_int32)); + + args = {types->i64_type(), // int64_t execution_context + types->i64_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i64_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_int64_contains_int64", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_int64_contains_int64)); + + args = {types->i64_type(), // int64_t execution_context + types->float_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->float_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_float32_contains_float32", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_float32_contains_float32)); + + args = {types->i64_type(), // int64_t execution_context + types->double_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->double_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_float64_contains_float64", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_float64_contains_float64)); + //Array remove. + args = {types->i64_type(), // int64_t execution_context + types->i32_ptr_type(), // int8_t* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i32_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + engine->AddGlobalMappingForFunc("array_int32_remove", + types->i32_ptr_type(), args, + reinterpret_cast(array_int32_remove)); + + args = {types->i64_type(), // int64_t execution_context + types->i64_ptr_type(), // int8_t* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i64_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + + engine->AddGlobalMappingForFunc("array_int64_remove", + types->i64_ptr_type(), args, + reinterpret_cast(array_int64_remove)); + + args = {types->i64_type(), // int64_t execution_context + types->float_ptr_type(), // float* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->float_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + + engine->AddGlobalMappingForFunc("array_float32_remove", + types->float_ptr_type(), args, + reinterpret_cast(array_float32_remove)); + + args = {types->i64_type(), // int64_t execution_context + types->double_ptr_type(), // int8_t* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->double_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + + engine->AddGlobalMappingForFunc("array_float64_remove", + types->double_ptr_type(), args, + reinterpret_cast(array_float64_remove)); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/array_ops.h b/cpp/src/gandiva/array_ops.h new file mode 100644 index 0000000000000..c0de72a39472b --- /dev/null +++ b/cpp/src/gandiva/array_ops.h @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "gandiva/visibility.h" + +namespace llvm { +class VectorType; +} + +/// Array functions that can be accessed from LLVM. +extern "C" { + +GANDIVA_EXPORT +bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); +GANDIVA_EXPORT +bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); + +GANDIVA_EXPORT +bool array_float32_contains_float32(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); + +GANDIVA_EXPORT +bool array_float64_contains_float64(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); + +GANDIVA_EXPORT +int32_t* array_int32_remove(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +GANDIVA_EXPORT +int64_t* array_int64_remove(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +GANDIVA_EXPORT +float* array_float32_remove(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +GANDIVA_EXPORT +double* array_float64_remove(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +} diff --git a/cpp/src/gandiva/array_ops_test.cc b/cpp/src/gandiva/array_ops_test.cc new file mode 100644 index 0000000000000..bf01c1fe0a091 --- /dev/null +++ b/cpp/src/gandiva/array_ops_test.cc @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestArrayOps, TestInt32ContainsInt32) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t data[] = {1, 2, 3, 4}; + int32_t entry_offsets_len = 3; + int32_t contains_data = 2; + int32_t entry_validity = 15; + bool valid = false; + + EXPECT_EQ( + array_int32_contains_int32(ctx_ptr, data, entry_offsets_len, &entry_validity, + true, contains_data, true, 0, 3, &valid), + true); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/dex.h b/cpp/src/gandiva/dex.h index 2998c2131769a..95053ddabfb75 100644 --- a/cpp/src/gandiva/dex.h +++ b/cpp/src/gandiva/dex.h @@ -80,6 +80,23 @@ class GANDIVA_EXPORT VectorReadFixedLenValueDex : public VectorReadBaseDex { void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } }; +/// value component of a fixed-len list ValueVector +class GANDIVA_EXPORT VectorReadFixedLenValueListDex : public VectorReadBaseDex { + public: + explicit VectorReadFixedLenValueListDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + int ValidityIdx() const { return field_desc_->validity_idx(); } + + int ChildValidityIdx() const { return field_desc_->child_data_validity_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// value component of a variable-len ValueVector class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { public: @@ -93,6 +110,21 @@ class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } }; +/// value component of a variable-len list ValueVector +class GANDIVA_EXPORT VectorReadVarLenValueListDex : public VectorReadBaseDex { + public: + explicit VectorReadVarLenValueListDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + int ChildOffsetsIdx() const { return field_desc_->child_data_offsets_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// validity based on a local bitmap. class GANDIVA_EXPORT LocalBitMapValidityDex : public Dex { public: diff --git a/cpp/src/gandiva/dex_visitor.h b/cpp/src/gandiva/dex_visitor.h index 5d160bb22ca68..4a03b9c21fc8a 100644 --- a/cpp/src/gandiva/dex_visitor.h +++ b/cpp/src/gandiva/dex_visitor.h @@ -28,7 +28,9 @@ namespace gandiva { class VectorReadValidityDex; class VectorReadFixedLenValueDex; +class VectorReadFixedLenValueListDex; class VectorReadVarLenValueDex; +class VectorReadVarLenValueListDex; class LocalBitMapValidityDex; class LiteralDex; class TrueDex; @@ -49,7 +51,9 @@ class GANDIVA_EXPORT DexVisitor { virtual void Visit(const VectorReadValidityDex& dex) = 0; virtual void Visit(const VectorReadFixedLenValueDex& dex) = 0; + virtual void Visit(const VectorReadFixedLenValueListDex& dex) = 0; virtual void Visit(const VectorReadVarLenValueDex& dex) = 0; + virtual void Visit(const VectorReadVarLenValueListDex& dex) = 0; virtual void Visit(const LocalBitMapValidityDex& dex) = 0; virtual void Visit(const TrueDex& dex) = 0; virtual void Visit(const FalseDex& dex) = 0; @@ -75,7 +79,9 @@ class GANDIVA_EXPORT DexVisitor { class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor { VISIT_DCHECK(VectorReadValidityDex) VISIT_DCHECK(VectorReadFixedLenValueDex) + VISIT_DCHECK(VectorReadFixedLenValueListDex) VISIT_DCHECK(VectorReadVarLenValueDex) + VISIT_DCHECK(VectorReadVarLenValueListDex) VISIT_DCHECK(LocalBitMapValidityDex) VISIT_DCHECK(TrueDex) VISIT_DCHECK(FalseDex) diff --git a/cpp/src/gandiva/exported_funcs.h b/cpp/src/gandiva/exported_funcs.h index 5a14c52162156..55145b301e78c 100644 --- a/cpp/src/gandiva/exported_funcs.h +++ b/cpp/src/gandiva/exported_funcs.h @@ -32,6 +32,12 @@ class ExportedFuncsBase { virtual void AddMappings(Engine* engine) const = 0; }; +// Class for exporting Array functions +class ExportedArrayFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedArrayFunctions); + // Class for exporting Stub functions class ExportedStubFunctions : public ExportedFuncsBase { void AddMappings(Engine* engine) const override; diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc index 957d9d046bd57..719d4006e65ae 100644 --- a/cpp/src/gandiva/expr_decomposer.cc +++ b/cpp/src/gandiva/expr_decomposer.cc @@ -39,8 +39,17 @@ Status ExprDecomposer::Visit(const FieldNode& node) { DexPtr validity_dex = std::make_shared(desc); DexPtr value_dex; - if (desc->HasOffsetsIdx()) { - value_dex = std::make_shared(desc); + if (desc->HasChildOffsetsIdx()) { + // handle list type + value_dex = std::make_shared(desc); + } else if (desc->HasOffsetsIdx()) { + if (desc->field()->type()->id() == arrow::Type::LIST) { + // handle list type + auto p = std::make_shared(desc); + value_dex = p; + } else { + value_dex = std::make_shared(desc); + } } else { value_dex = std::make_shared(desc); } diff --git a/cpp/src/gandiva/expr_validator.cc b/cpp/src/gandiva/expr_validator.cc index 35a13494523d0..265f2c119cd0e 100644 --- a/cpp/src/gandiva/expr_validator.cc +++ b/cpp/src/gandiva/expr_validator.cc @@ -67,7 +67,7 @@ Status ExprValidator::Validate(const ExpressionPtr& expr) { } Status ExprValidator::Visit(const FieldNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); + auto llvm_type = types_->DataVecType(node.return_type()); ARROW_RETURN_IF(llvm_type == nullptr, Status::ExpressionValidationError("Field ", node.field()->name(), " has unsupported data type ", @@ -136,7 +136,7 @@ Status ExprValidator::Visit(const IfNode& node) { } Status ExprValidator::Visit(const LiteralNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); + auto llvm_type = types_->DataVecType(node.return_type()); ARROW_RETURN_IF(llvm_type == nullptr, Status::ExpressionValidationError("Value ", ToString(node.holder()), " has unsupported data type ", diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc index 9bff97f5ad269..12ac0d0b154e8 100644 --- a/cpp/src/gandiva/expression_registry.cc +++ b/cpp/src/gandiva/expression_registry.cc @@ -166,6 +166,13 @@ static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector case arrow::Type::type::INTERVAL_DAY_TIME: vector.push_back(arrow::day_time_interval()); break; + case arrow::Type::type::LIST: + vector.push_back(arrow::list(arrow::int32())); + vector.push_back(arrow::list(arrow::int64())); + vector.push_back(arrow::list(arrow::float32())); + vector.push_back(arrow::list(arrow::float64())); + vector.push_back(arrow::list(arrow::utf8())); + break; default: // Unsupported types. test ensures that // when one of these are added build breaks. diff --git a/cpp/src/gandiva/field_descriptor.h b/cpp/src/gandiva/field_descriptor.h index 0fe6fe37f4dd3..dfcf6872d501d 100644 --- a/cpp/src/gandiva/field_descriptor.h +++ b/cpp/src/gandiva/field_descriptor.h @@ -30,12 +30,16 @@ class FieldDescriptor { static const int kInvalidIdx = -1; FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx, - int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx) + int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx, + int child_offsets_idx = kInvalidIdx, int child_validity_idx = kInvalidIdx) : field_(field), data_idx_(data_idx), validity_idx_(validity_idx), offsets_idx_(offsets_idx), - data_buffer_ptr_idx_(data_buffer_ptr_idx) {} + data_buffer_ptr_idx_(data_buffer_ptr_idx), + child_offsets_idx_(child_offsets_idx), + child_validity_idx_(child_validity_idx) { + } /// Index of validity array in the array-of-buffers int validity_idx() const { return validity_idx_; } @@ -49,6 +53,12 @@ class FieldDescriptor { /// Index of data buffer pointer in the array-of-buffers int data_buffer_ptr_idx() const { return data_buffer_ptr_idx_; } + /// Index of list type child data offsets + int child_data_offsets_idx() const { return child_offsets_idx_; } + int child_data_validity_idx() const { return child_validity_idx_; } + void set_child_data_validity_idx(int val) { + child_validity_idx_ = val; + } FieldPtr field() const { return field_; } const std::string& Name() const { return field_->name(); } @@ -58,12 +68,18 @@ class FieldDescriptor { bool HasDataBufferPtrIdx() const { return data_buffer_ptr_idx_ != kInvalidIdx; } + bool HasChildOffsetsIdx() const { return child_offsets_idx_ != kInvalidIdx; } + + bool HasChildValidityIdx() const { return child_validity_idx_ != kInvalidIdx; } + private: FieldPtr field_; int data_idx_; int validity_idx_; int offsets_idx_; int data_buffer_ptr_idx_; + int child_offsets_idx_; + int child_validity_idx_; }; } // namespace gandiva diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index 67b7b404b325c..9180e8c33ca33 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -16,17 +16,19 @@ // under the License. #include "gandiva/function_registry.h" + +#include +#include +#include + #include "gandiva/function_registry_arithmetic.h" +#include "gandiva/function_registry_array.h" #include "gandiva/function_registry_datetime.h" #include "gandiva/function_registry_hash.h" #include "gandiva/function_registry_math_ops.h" #include "gandiva/function_registry_string.h" #include "gandiva/function_registry_timestamp_arithmetic.h" -#include -#include -#include - namespace gandiva { FunctionRegistry::iterator FunctionRegistry::begin() const { @@ -64,6 +66,10 @@ SignatureMap FunctionRegistry::InitPCMap() { auto v6 = GetDateTimeArithmeticFunctionRegistry(); pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); + + auto v7 = GetArrayFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v7.begin(), v7.end()); + for (auto& elem : pc_registry_) { for (auto& func_signature : elem.signatures()) { map.insert(std::make_pair(&(func_signature), &elem)); diff --git a/cpp/src/gandiva/function_registry_array.cc b/cpp/src/gandiva/function_registry_array.cc new file mode 100644 index 0000000000000..893ba6e3d2b04 --- /dev/null +++ b/cpp/src/gandiva/function_registry_array.cc @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry_array.h" + +#include "gandiva/function_registry_common.h" + +namespace gandiva { +std::vector GetArrayFunctionRegistry() { + static std::vector array_fn_registry_ = { + NativeFunction("array_contains", {}, DataTypeVector{list(int32()), int32()}, + boolean(), kResultNullInternal, "array_int32_contains_int32", + NativeFunction::kNeedsContext), + NativeFunction("array_contains", {}, DataTypeVector{list(int64()), int64()}, + boolean(), kResultNullInternal, "array_int64_contains_int64", + NativeFunction::kNeedsContext), + NativeFunction("array_contains", {}, DataTypeVector{list(float32()), float32()}, + boolean(), kResultNullInternal, "array_float32_contains_float32", + NativeFunction::kNeedsContext), + NativeFunction("array_contains", {}, DataTypeVector{list(float64()), float64()}, + boolean(), kResultNullInternal, "array_float64_contains_float64", + NativeFunction::kNeedsContext), + + NativeFunction("array_remove", {}, DataTypeVector{list(int32()), int32()}, + list(int32()), kResultNullInternal, "array_int32_remove", + NativeFunction::kNeedsContext), + NativeFunction("array_remove", {}, DataTypeVector{list(int64()), int64()}, + list(int64()), kResultNullInternal, "array_int64_remove", + NativeFunction::kNeedsContext), + NativeFunction("array_remove", {}, DataTypeVector{list(float32()), float32()}, + list(float32()), kResultNullInternal, "array_float32_remove", + NativeFunction::kNeedsContext), + NativeFunction("array_remove", {}, DataTypeVector{list(float64()), float64()}, + list(float64()), kResultNullInternal, "array_float64_remove", + NativeFunction::kNeedsContext), + }; + return array_fn_registry_; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/function_registry_array.h b/cpp/src/gandiva/function_registry_array.h new file mode 100644 index 0000000000000..9b8e4553702a8 --- /dev/null +++ b/cpp/src/gandiva/function_registry_array.h @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector GetArrayFunctionRegistry(); + +} // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 5146f7fa1990a..2ca9529fa846b 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -36,8 +36,6 @@ #include "gandiva/random_generator_holder.h" #include "gandiva/to_date_holder.h" -/// Stub functions that can be accessed from LLVM or the pre-compiled library. - extern "C" { static char mask_array[256] = { @@ -161,6 +159,96 @@ int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr, return 0; } +/// Stub functions that can be accessed from LLVM or the pre-compiled library. +#define POPULATE_NUMERIC_LIST_TYPE_VECTOR(TYPE, SCALE) \ + int32_t gdv_fn_populate_list_##TYPE##_vector(int64_t context_ptr, int8_t* data_ptr, \ + int32_t* offsets, int64_t slot, \ + TYPE* entry_buf, int32_t entry_len, int32_t** valid_ptr) { \ + auto buffer = reinterpret_cast(data_ptr); \ + int32_t offset = static_cast(buffer->size()); \ + auto status = buffer->Resize(offset + entry_len * SCALE, false /*shrink*/); \ + if (!status.ok()) { \ + gandiva::ExecutionContext* context = \ + reinterpret_cast(context_ptr); \ + context->set_error_msg(status.message().c_str()); \ + return -1; \ + } \ + memcpy(buffer->mutable_data() + offset, (char*)entry_buf, entry_len * SCALE); \ + int validbitIndex = offset / SCALE; \ + for (int i = 0; i < entry_len; i++) { \ + arrow::bit_util::SetBitTo(buffer->validityBuffer, validbitIndex + i, arrow::bit_util::GetBit(reinterpret_cast(valid_ptr), i)); \ + } \ + offsets = reinterpret_cast(buffer->offsetBuffer); \ + offsets[slot] = offset / SCALE; \ + offsets[slot + 1] = offset / SCALE + entry_len; \ + return 0; \ + }\ + +POPULATE_NUMERIC_LIST_TYPE_VECTOR(int32_t, 4) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(int64_t, 8) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(float, 4) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(double, 8) + +int32_t gdv_fn_populate_list_varlen_vector(int64_t context_ptr, int8_t* data_ptr, + int32_t* offsets, int32_t* child_offsets, + int64_t slot, const char* entry_buf, + int32_t* entry_child_offsets, + int32_t entry_offsets_len) { + // we should calculate varlen list type varlen offset + // copy from entry child offsets + // it should be noted that, + // buffer size unit is byte(8 bit), + // offset element unit is int32(32 bit) + auto child_offsets_buffer = reinterpret_cast(child_offsets); + int32_t child_offsets_buffer_offset = + static_cast(child_offsets_buffer->size()); + + // data buffer elelment is char(8 bit) + auto data_buffer = reinterpret_cast(data_ptr); + int32_t data_buffer_offset = static_cast(data_buffer->size()); + + // sets the size in the child offsets buffer + // offsets element is int32, we should resize buffer by extra offsets_len * 4 + auto status = child_offsets_buffer->Resize( + child_offsets_buffer_offset + entry_offsets_len * 4, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new child offsets entry to child offsets buffer + // offsets buffer last offset number indicating data length + // we should take this extra offset into consider + // so the initialize child_offsets_buffer length is 1(int32) + memcpy(child_offsets_buffer->mutable_data() + child_offsets_buffer_offset - 4, + (char*)entry_child_offsets, (entry_offsets_len + 1) * 4); + + // compute data length + int32_t data_length = + *(entry_child_offsets + entry_offsets_len) - *(entry_child_offsets); + + // sets the size in the child offsets buffer. + status = data_buffer->Resize(data_buffer_offset + data_length, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new child offsets entry to child offsets buffer + memcpy(data_buffer->mutable_data() + data_buffer_offset, entry_buf, data_length); + + // update offsets buffer. + offsets[slot] = child_offsets_buffer_offset / 4 - 1; + offsets[slot + 1] = child_offsets_buffer_offset / 4 - 1 + entry_offsets_len; + return 0; +} + #define CRC_FUNCTION(TYPE) \ GANDIVA_EXPORT \ int64_t gdv_fn_crc_32_##TYPE(int64_t ctx, const char* input, int32_t input_len) { \ @@ -838,6 +926,8 @@ const char* gdv_mask_show_last_n_utf8_int32(int64_t context, const char* data, int32_t n_to_mask = num_of_chars - n_to_show; return gdv_mask_first_n_utf8_int32(context, data, data_len, n_to_mask, out_len); } + +#undef POPULATE_NUMERIC_LIST_TYPE_VECTOR } namespace gandiva { @@ -1174,6 +1264,34 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalyear_utf8)); + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_in_expr_lookup_utf8)); + +#define ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(LLVM_TYPE, DATA_TYPE) \ + args = {types->i64_type(), types->i8_ptr_type(), types->i32_ptr_type(), \ + types->i64_type(), types->LLVM_TYPE##_ptr_type(), types->i32_type(), types->i32_ptr_type()}; \ + engine->AddGlobalMappingForFunc( \ + "gdv_fn_populate_list_" #DATA_TYPE "_vector", types->i32_type() /*return_type*/, \ + args, reinterpret_cast(gdv_fn_populate_list_##DATA_TYPE##_vector)); + + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i32, int32_t) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i64, int64_t) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(float, float) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(double, double) + + // gdv_fn_populate_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_type()}; // int32_t entry__len + + engine->AddGlobalMappingForFunc("gdv_fn_populate_varlen_vector", + types->i32_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_populate_varlen_vector)); + // gdv_fn_cast_intervalyear_utf8_int32 args = { types->i64_type(), // context @@ -1190,6 +1308,26 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { "gdv_fn_cast_intervalyear_utf8_int32", types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalyear_utf8_int32)); + // gdv_fn_populate_list_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i32_ptr_type(), // int32_t* child offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_ptr_type(), // int32_t* entry child offsets ptr + types->i32_type(), // int32_t entry child offsets length + types->i32_ptr_type() // int32_t* entry child valid ptr + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_populate_list_varlen_vector", types->i32_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_populate_list_varlen_vector)); + + // gdv_fn_random + args = {types->i64_type()}; + engine->AddGlobalMappingForFunc("gdv_fn_random", types->double_type(), args, + reinterpret_cast(gdv_fn_random)); // to_utc_timezone_timestamp args = { types->i64_type(), // context @@ -1289,4 +1427,6 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("mask_utf8", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(mask_utf8)); } + +#undef ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 1615eece1f2c7..5e676d70251fa 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -30,7 +30,6 @@ #include "gandiva/lvalue.h" namespace gandiva { - #define ADD_TRACE(...) \ if (enable_ir_traces_) { \ AddTrace(__VA_ARGS__); \ @@ -94,7 +93,7 @@ Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode // Compile and inject into the process' memory the generated function. ARROW_RETURN_NOT_OK(engine_->FinalizeModule()); - + // setup the jit functions for each expression. for (auto& compiled_expr : compiled_exprs_) { auto fn_name = compiled_expr->GetFunctionName(mode); @@ -210,6 +209,14 @@ llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx, return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray"); } +/// Get reference to child offsets array at specified index in the args list. +llvm::Value* LLVMGenerator::GetChildOffsetsReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, types()->i64_type(), idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_coarray"); +} + /// Get reference to local bitmap array at specified index in the args list. llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) { llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, types()->i64_type(), idx, ""); @@ -350,6 +357,10 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, slice_offsets.push_back(offset); } + llvm::AllocaInst* validity_index_var = + new llvm::AllocaInst(types()->i64_type(), 0, "validity_index_var", loop_entry); + builder->CreateStore(types()->i64_constant(0), validity_index_var); + // Loop body builder->SetInsertPoint(loop_body); @@ -368,7 +379,7 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, // The visitor can add code to both the entry/loop blocks. Visitor visitor(this, fn, loop_entry, arg_addrs, arg_local_bitmaps, arg_holder_ptrs, - slice_offsets, arg_context_ptr, position_var); + slice_offsets, arg_context_ptr, position_var, validity_index_var); value_expr->Accept(visitor); LValuePtr output_value = visitor.result(); @@ -397,12 +408,46 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(), {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var, output_value->data(), output_value->length()}); + } else if (output_type_id == arrow::Type::STRUCT) { + auto slot_offset = builder->CreateGEP(types()->IRType(output_type_id), output_ref, loop_var); + builder->CreateStore(output_value->data(), slot_offset); + } else if (output_type_id == arrow::Type::LIST) { + auto output_list_internal_type = output->Type()->field(0)->type()->id(); + + if (arrow::is_binary_like(output_list_internal_type)) { + auto output_list_value = std::dynamic_pointer_cast(output_value); + llvm::Value* child_output_offset_ref = GetChildOffsetsReference( + arg_addrs, output->child_data_offsets_idx(), output->field()); + AddFunctionCall( + "gdv_fn_populate_list_varlen_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + child_output_offset_ref, loop_var, output_list_value->data(), + output_list_value->child_offsets(), output_list_value->offsets_length()}); + } else if (output_list_internal_type == arrow::Type::INT32) { + AddFunctionCall("gdv_fn_populate_list_int32_t_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else if (output_list_internal_type == arrow::Type::INT64) { + AddFunctionCall("gdv_fn_populate_list_int64_t_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else if (output_list_internal_type == arrow::Type::FLOAT) { + AddFunctionCall("gdv_fn_populate_list_float_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else if (output_list_internal_type == arrow::Type::DOUBLE) { + AddFunctionCall("gdv_fn_populate_list_double_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else { + return Status::NotImplemented("list internal type ", + output->Type()->field(0)->type()->ToString(), + " not supported"); + } } else { return Status::NotImplemented("output type ", output->Type()->ToString(), " not supported"); } - ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data()); - if (visitor.has_arena_allocs()) { // Reset allocations to avoid excessive memory usage. Once the result is copied to // the output vector (store instruction above), any memory allocations in this @@ -496,10 +541,10 @@ void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr, /// /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap. /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap. + LocalBitMapsHolder bit_map_holder(eval_batch->num_records(), 1); uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0); accumulator.ComputeResult(temp_bitmap); - auto num_out_records = selection_vector->GetNumSlots(); // the memset isn't required, doing it just for valgrind. memset(dst_bitmap, 0, arrow::bit_util::BytesForBits(num_out_records)); @@ -530,6 +575,7 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name, value = ir_builder()->CreateCall(fn, args); } else { value = ir_builder()->CreateCall(fn, args, full_name); + DCHECK(value->getType() == ret_type); } @@ -558,7 +604,8 @@ LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* functi llvm::Value* arg_local_bitmaps, llvm::Value* arg_holder_ptrs, std::vector slice_offsets, - llvm::Value* arg_context_ptr, llvm::Value* loop_var) + llvm::Value* arg_context_ptr, llvm::Value* loop_var, + llvm::Value* validity_index_var) : generator_(generator), function_(function), entry_block_(entry_block), @@ -568,11 +615,13 @@ LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* functi slice_offsets_(slice_offsets), arg_context_ptr_(arg_context_ptr), loop_var_(loop_var), + validity_index_var_(validity_index_var), has_arena_allocs_(false) { ADD_VISITOR_TRACE("Iteration %T", loop_var); } void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { + ADD_VISITOR_TRACE("VectorReadFixedLenValueDex"); llvm::IRBuilder<>* builder = ir_builder(); auto types = generator_->types(); llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); @@ -580,6 +629,7 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { llvm::Value* slot_value; std::shared_ptr lvalue; + ADD_VISITOR_TRACE("VectorReadFixedLenValueDex"); switch (dex.FieldType()->id()) { case arrow::Type::BOOL: slot_value = generator_->GetPackedBitValue(slot_ref, slot_index); @@ -606,11 +656,77 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { result_ = lvalue; } -void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { +void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueListDex& dex) { + ADD_VISITOR_TRACE("VectorReadFixedLenValueListDex"); llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot; auto types = generator_->types(); + auto type = types->IRType(dex.FieldType()->id()); + + auto dt = dex.FieldType(); + if (dt->id() == arrow::Type::LIST) { + type = types->IRType(dt->fields()[0]->type()->id() ); + } + + arrow::Type::type at32 = arrow::Type::INT32; + auto type32 = types->IRType(at32); + + // compute list len from the offsets array. + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + slot = builder->CreateGEP(type32, offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = builder->CreateLoad(type32, slot, "offset_start"); + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = builder->CreateGEP(type32, offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = builder->CreateLoad(type32, slot, "offset_end"); + + // => offsets_len_value = offset_end - offset_start + llvm::Value* list_len = builder->CreateSub(offset_end, offset_start, "offsets_len"); + + // get data array + llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + // do not forget slice offset + llvm::Value* offset_start_int64 = + builder->CreateIntCast(offset_start, generator_->types()->i64_type(), true); + llvm::Value* slot_index = + builder->CreateAdd(offset_start_int64, GetSliceOffset(dex.DataIdx())); + llvm::Value* data_list = builder->CreateGEP(type, slot_ref, slot_index); + + auto list_len_var = builder->CreateIntCast(list_len, types->i64_type(), true); + llvm::Value* vv_end = builder->CreateLoad(generator_->types()->i64_type(),validity_index_var_, "vv_end"); + +llvm::Value* updated_validity_index_var = builder->CreateAdd( + vv_end, list_len_var, "validity_index_var+offset"); + + builder->CreateStore(updated_validity_index_var, validity_index_var_); + llvm::Value* b_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx())); + llvm::Value* b_slot_ref = GetBufferReference(dex.ChildValidityIdx(), kBufferTypeValidity, dex.Field()); + llvm::Value* validity = builder->CreateGEP(type32, b_slot_ref, b_slot_index); + + std::string str3 = "validity:"; + if (validity) { + llvm::raw_string_ostream output3(str3); + validity->print(output3); + } + ADD_VISITOR_TRACE("visit fixed-len data list vector " + dex.FieldName() + " length %T", + list_len); + ADD_VISITOR_TRACE("visit fixed-len data list vector " + dex.FieldName() + " updated_validity_index_var %T", + updated_validity_index_var); + + result_.reset(new LValue(data_list, list_len, validity)); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot; + auto types = generator_->types(); + ADD_VISITOR_TRACE("VectorReadVarLenValueDex"); // compute len from the offsets array. llvm::Value* offsets_slot_ref = GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); @@ -641,7 +757,73 @@ void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { result_.reset(new LValue(data_value, len_value)); } +/* + * create list type field context for each loop + */ +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueListDex& dex) { + /* Example + * list_data: [["var_len_val11"], ["var_len_val211", "var_len_val22"], + * ["var_len_val3331"]] loop_var: 0, 1, 2 data_buffer: + * var_len_val11var_len_val211var_len_val22var_len_val3331 offsets_buffer: 0, 1, 3, 4 + * list_element_len = offsets[loop_var+1]-offsets[loop_var] => 1, 2, 1 + * child_offsets_buffer: 0, 13, 27, 40, 55 + * for i in list_element_len: + * data_buffer[child_offsets_buffer[offsets[i+1]] - child_offsets_buffer[offsets[i]]] + * => list_data[loop_var][i] + */ + ADD_VISITOR_TRACE("VectorReadVarLenValueListDex"); + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot; + auto types = generator_->types(); + auto type = types->IRType(dex.FieldType()->id()); + + arrow::Type::type at = arrow::Type::INT32; + type = types->IRType(at); + + // compute list length from the offsets array + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + + // => offset_start = offsets[loop_var] + slot = builder->CreateGEP(type, offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = builder->CreateLoad(type, slot, "offset_start"); + + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = builder->CreateGEP(type, offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = builder->CreateLoad(type, slot, "offset_end"); + + // => list_data_length = offset_end - offset_start + llvm::Value* list_data_length = + builder->CreateSub(offset_end, offset_start, "offsets_len"); + + // get the child offsets array from the child offsets array, + // start from offset 'offset_start' + llvm::Value* child_offset_slot_ref = + GetBufferReference(dex.ChildOffsetsIdx(), kBufferTypeChildOffsets, dex.Field()); + // do not forget slice offset + llvm::Value* offset_start_int64 = + builder->CreateIntCast(offset_start, generator_->types()->i64_type(), true); + llvm::Value* child_offset_slot_index = + builder->CreateAdd(offset_start_int64, GetSliceOffset(dex.ChildOffsetsIdx())); + llvm::Value* child_offsets = + builder->CreateGEP(type, child_offset_slot_ref, child_offset_slot_index); + llvm::Value* child_offset_start = + builder->CreateLoad(type, child_offsets, "child_offset_start"); + + // get the data array + llvm::Value* data_slot_ref = + GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + llvm::Value* data_value = builder->CreateGEP(type, data_slot_ref, child_offset_start); + + result_.reset(new ListLValue(data_value, child_offsets, list_data_length)); +} + void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) { + ADD_VISITOR_TRACE("VectorReadValidityDex"); llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot_ref = GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field()); @@ -654,6 +836,7 @@ void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) { } void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) { + ADD_VISITOR_TRACE("LocalBitMapValidityDex"); llvm::Value* slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx()); llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_); @@ -664,14 +847,17 @@ void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) { } void LLVMGenerator::Visitor::Visit(const TrueDex& dex) { + ADD_VISITOR_TRACE("TrueDex"); result_.reset(new LValue(generator_->types()->true_constant())); } void LLVMGenerator::Visitor::Visit(const FalseDex& dex) { + ADD_VISITOR_TRACE("FalseDex"); result_.reset(new LValue(generator_->types()->false_constant())); } void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { + ADD_VISITOR_TRACE("LiteralDex"); LLVMTypes* types = generator_->types(); llvm::Value* value = nullptr; llvm::Value* len = nullptr; @@ -716,7 +902,6 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { case arrow::Type::STRING: case arrow::Type::BINARY: { const std::string& str = std::get(dex.holder()); - value = ir_builder()->CreateGlobalStringPtr(str.c_str()); len = types->i32_constant(static_cast(str.length())); break; @@ -777,8 +962,7 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); LLVMTypes* types = generator_->types(); auto arrow_type_id = arrow_return_type->id(); - auto result_type = types->IRType(arrow_type_id); - + auto result_type = types->DataVecType(arrow_return_type); // Build combined validity of the args. llvm::Value* is_valid = types->true_constant(); for (auto& pair : dex.args()) { @@ -836,18 +1020,34 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) { auto params = BuildParams(dex.get_holder_idx(), dex.args(), true, native_function->NeedsContext()); + + + auto arrow_return_type = dex.func_descriptor()->return_type(); + + bool passLoopVars = false; + for (auto& p : dex.func_descriptor()->params()) { + if (p->id() == arrow::Type::LIST) { + passLoopVars = true; + break; + } + } + if (passLoopVars) + { + params.push_back(loop_var_); + auto valid_var = builder->CreateLoad(types->i64_type(), validity_index_var_, "loaded_var"); + params.push_back(valid_var); + } + // add an extra arg for validity (allocated on stack). llvm::AllocaInst* result_valid_ptr = new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_); params.push_back(result_valid_ptr); - auto arrow_return_type = dex.func_descriptor()->return_type(); result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms); // load the result validity and truncate to i1. auto result_valid_i8 = builder->CreateLoad(types->i8_type(), result_valid_ptr); llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type()); - // set validity bit in the local bitmap. ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid); } @@ -1125,25 +1325,31 @@ void LLVMGenerator::Visitor::VisitInExpression( } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } @@ -1151,6 +1357,7 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, std::function then_func, std::function else_func, DataTypePtr result_type) { + ADD_VISITOR_TRACE("BuildIfElse"); llvm::IRBuilder<>* builder = ir_builder(); llvm::LLVMContext* context = generator_->context(); LLVMTypes* types = generator_->types(); @@ -1180,7 +1387,7 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, // Emit the merge block. builder->SetInsertPoint(merge_bb); - auto llvm_type = types->IRType(result_type->id()); + auto llvm_type = types->DataVecType(result_type); llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value"); result_value->addIncoming(then_lvalue->data(), then_bb); result_value->addIncoming(else_lvalue->data(), else_bb); @@ -1226,7 +1433,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, std::vector* params) { auto types = generator_->types(); auto arrow_return_type_id = arrow_return_type->id(); - auto llvm_return_type = types->IRType(arrow_return_type_id); + auto llvm_return_type = types->DataVecType(arrow_return_type); DecimalIR decimalIR(generator_->engine_.get()); if (arrow_return_type_id == arrow::Type::DECIMAL) { @@ -1255,6 +1462,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, } // add extra arg for return length for variable len return types (allocated on stack). llvm::AllocaInst* result_len_ptr = nullptr; + llvm::AllocaInst* valid_ptr = nullptr; if (arrow::is_binary_like(arrow_return_type_id)) { result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0, "result_len", entry_block_); @@ -1262,6 +1470,17 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, has_arena_allocs_ = true; } + if (arrow_return_type_id == arrow::Type::LIST) { + + result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0, + "result_len", entry_block_); + params->push_back(result_len_ptr); + has_arena_allocs_ = true; + valid_ptr = new llvm::AllocaInst(generator_->types()->i32_ptr_type(), 0, + "valid_ptr", entry_block_); + params->push_back(valid_ptr); + } + // Make the function call llvm::IRBuilder<>* builder = ir_builder(); auto value = @@ -1272,7 +1491,11 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, (result_len_ptr == nullptr) ? nullptr : builder->CreateLoad(result_len_ptr->getAllocatedType(), result_len_ptr); - return std::make_shared(value, value_len); + auto validity = + (valid_ptr == nullptr) + ? nullptr + : builder->CreateLoad(generator_->types()->i32_ptr_type(), valid_ptr); + return std::make_shared(value, value_len, validity); } } @@ -1281,6 +1504,7 @@ std::vector LLVMGenerator::Visitor::BuildParams( bool with_context) { std::vector params; + ADD_VISITOR_TRACE("LLVMGenerator::Visitor::BuildParams"); // add context if required. if (with_context) { params.push_back(arg_context_ptr_); @@ -1311,6 +1535,7 @@ std::vector LLVMGenerator::Visitor::BuildParams( // append all the parameters corresponding to this LValue. result_ref.AppendFunctionParams(¶ms); + // build validity. if (with_validity) { llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs()); @@ -1356,6 +1581,10 @@ llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buff case kBufferTypeOffsets: slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field); break; + + case kBufferTypeChildOffsets: + slot_ref = generator_->GetChildOffsetsReference(arg_addrs_, idx, field); + break; } // Revert to the saved block. @@ -1384,6 +1613,7 @@ llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) { /// The local bitmap is pre-filled with 1s. Clear only if invalid. void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx, llvm::Value* is_valid) { + ADD_VISITOR_TRACE("ClearLocalBitMapIfNotValid"); llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx); generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid); } @@ -1454,5 +1684,4 @@ void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) { } AddFunctionCall(print_fn_name, types()->i32_type(), args); } - } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 04f9b854b1d29..e8c15bdf00744 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -98,11 +98,13 @@ class GANDIVA_EXPORT LLVMGenerator { llvm::BasicBlock* entry_block, llvm::Value* arg_addrs, llvm::Value* arg_local_bitmaps, llvm::Value* arg_holder_ptrs, std::vector slice_offsets, llvm::Value* arg_context_ptr, - llvm::Value* loop_var); + llvm::Value* loop_var, llvm::Value* validity_index); void Visit(const VectorReadValidityDex& dex) override; void Visit(const VectorReadFixedLenValueDex& dex) override; + void Visit(const VectorReadFixedLenValueListDex& dex) override; void Visit(const VectorReadVarLenValueDex& dex) override; + void Visit(const VectorReadVarLenValueListDex& dex) override; void Visit(const LocalBitMapValidityDex& dex) override; void Visit(const TrueDex& dex) override; void Visit(const FalseDex& dex) override; @@ -127,7 +129,12 @@ class GANDIVA_EXPORT LLVMGenerator { bool has_arena_allocs() { return has_arena_allocs_; } private: - enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets }; + enum BufferType { + kBufferTypeValidity = 0, + kBufferTypeData, + kBufferTypeOffsets, + kBufferTypeChildOffsets + }; llvm::IRBuilder<>* ir_builder() { return generator_->ir_builder(); } llvm::Module* module() { return generator_->module(); } @@ -175,6 +182,7 @@ class GANDIVA_EXPORT LLVMGenerator { std::vector slice_offsets_; llvm::Value* arg_context_ptr_; llvm::Value* loop_var_; + llvm::Value* validity_index_var_; bool has_arena_allocs_; }; @@ -195,6 +203,10 @@ class GANDIVA_EXPORT LLVMGenerator { /// Generate code to load the vector at specified index and cast it as offsets array. llvm::Value* GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as child offsets + /// array. + llvm::Value* GetChildOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as buffer pointer. llvm::Value* GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field); diff --git a/cpp/src/gandiva/llvm_types.cc b/cpp/src/gandiva/llvm_types.cc index de322a8c0fcb5..3eb49f39037f6 100644 --- a/cpp/src/gandiva/llvm_types.cc +++ b/cpp/src/gandiva/llvm_types.cc @@ -42,7 +42,8 @@ LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) { {arrow::Type::type::BINARY, i8_ptr_type()}, {arrow::Type::type::DECIMAL, i128_type()}, {arrow::Type::type::INTERVAL_MONTHS, i32_type()}, - {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}}; + {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}, + {arrow::Type::type::LIST, list_type()}}; } } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_types.h b/cpp/src/gandiva/llvm_types.h index d6f0952713efc..58b7c3008695f 100644 --- a/cpp/src/gandiva/llvm_types.h +++ b/cpp/src/gandiva/llvm_types.h @@ -46,6 +46,8 @@ class GANDIVA_EXPORT LLVMTypes { llvm::Type* i128_type() { return llvm::Type::getInt128Ty(context_); } + llvm::VectorType* list_type() { return llvm::ScalableVectorType::get(i8_type(), (unsigned int)0); } + llvm::StructType* i128_split_type() { // struct with high/low bits (see decimal_ops.cc:DecimalSplit) return llvm::StructType::get(context_, {i64_type(), i64_type()}, false); @@ -57,6 +59,8 @@ class GANDIVA_EXPORT LLVMTypes { llvm::PointerType* ptr_type(llvm::Type* type) { return type->getPointerTo(); } + llvm::PointerType* i1_ptr_type() { return ptr_type(i1_type()); } + llvm::PointerType* i8_ptr_type() { return ptr_type(i8_type()); } llvm::PointerType* i32_ptr_type() { return ptr_type(i32_type()); } @@ -65,6 +69,10 @@ class GANDIVA_EXPORT LLVMTypes { llvm::PointerType* i128_ptr_type() { return ptr_type(i128_type()); } + llvm::PointerType* float_ptr_type() { return ptr_type(float_type()); } + + llvm::PointerType* double_ptr_type() { return ptr_type(double_type()); } + template llvm::Constant* int_constant(ctype val) { return llvm::ConstantInt::get(context_, llvm::APInt(N, val)); @@ -87,6 +95,10 @@ class GANDIVA_EXPORT LLVMTypes { return llvm::ConstantFP::get(float_type(), val); } + llvm::LLVMContext* get_context() { + return &context_; + } + llvm::Constant* double_constant(double val) { return llvm::ConstantFP::get(double_type(), val); } @@ -104,6 +116,17 @@ class GANDIVA_EXPORT LLVMTypes { /// For a given data type, find the ir type used for the data vector slot. llvm::Type* DataVecType(const DataTypePtr& data_type) { + // support list type + // list type data is formed by base type buffer, wrapped with offsets buffer + // offsets buffer is to separate data into list + // not support nested list + if (data_type->id() == arrow::Type::LIST) { + //Nested lists aren't supported yet. + if (data_type->field(0)->type()->id() == arrow::Type::LIST) { + return NULL; + } + return IRType(data_type->field(0)->type()->id()); + } return IRType(data_type->id()); } diff --git a/cpp/src/gandiva/llvm_types_test.cc b/cpp/src/gandiva/llvm_types_test.cc index 6669683061825..665a82d133fad 100644 --- a/cpp/src/gandiva/llvm_types_test.cc +++ b/cpp/src/gandiva/llvm_types_test.cc @@ -50,12 +50,22 @@ TEST_F(TestLLVMTypes, TestFound) { types_->i64_type()); EXPECT_EQ(types_->DataVecType(arrow::timestamp(arrow::TimeUnit::MILLI)), types_->i64_type()); + + EXPECT_EQ(types_->IRType(arrow::Type::STRING), types_->i8_ptr_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::boolean())), types_->i1_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::int32())), types_->i32_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::int64())), types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::float32())), types_->float_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::float64())), types_->double_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::utf8())), types_->i8_ptr_type()); } TEST_F(TestLLVMTypes, TestNotFound) { EXPECT_EQ(types_->IRType(arrow::Type::SPARSE_UNION), nullptr); EXPECT_EQ(types_->IRType(arrow::Type::DENSE_UNION), nullptr); EXPECT_EQ(types_->DataVecType(arrow::null()), nullptr); + // not support nested list type + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::list(arrow::utf8()))), nullptr); } } // namespace gandiva diff --git a/cpp/src/gandiva/lvalue.h b/cpp/src/gandiva/lvalue.h index df292855b69af..04862dc9d18c8 100644 --- a/cpp/src/gandiva/lvalue.h +++ b/cpp/src/gandiva/lvalue.h @@ -46,9 +46,36 @@ class GANDIVA_EXPORT LValue { if (length_ != NULLPTR) { params->push_back(length_); } + if (validity_ != NULLPTR) { + params->push_back(validity_); + } } - private: + virtual std::string to_string() { + std::string s = "Base LValue"; + + std::string str1 = "data:"; + if (data_) { + llvm::raw_string_ostream output1(str1); + data_->print(output1); + } + + std::string str2 = "length:"; + if (length_) { + llvm::raw_string_ostream output2(str2); + length_->print(output2); + } + + std::string str3 = "validity:"; + if (validity_) { + llvm::raw_string_ostream output3(str3); + validity_->print(output3); + } + + return s + "\n" + str1 + "\n" + str2 + "\n" + str3; + } + + protected: llvm::Value* data_; llvm::Value* length_; llvm::Value* validity_; @@ -74,4 +101,48 @@ class GANDIVA_EXPORT DecimalLValue : public LValue { llvm::Value* scale_; }; +class GANDIVA_EXPORT ListLValue : public LValue { + public: + ListLValue(llvm::Value* data, llvm::Value* child_offsets, llvm::Value* offsets_length, + llvm::Value* validity = NULLPTR) + : LValue(data, NULLPTR, validity), + child_offsets_(child_offsets), + offsets_length_(offsets_length) { + } + + llvm::Value* child_offsets() { return child_offsets_; } + + llvm::Value* offsets_length() { return offsets_length_; } + + void AppendFunctionParams(std::vector* params) override { + LValue::AppendFunctionParams(params); + params->push_back(child_offsets_); + params->push_back(offsets_length_); + params->push_back(validity_); + } + + virtual std::string to_string() override { + std::string s = "List LValue"; + s += " " + LValue::to_string(); + + std::string str1 = "child_offsets_:"; + if (child_offsets_) { + llvm::raw_string_ostream output1(str1); + child_offsets_->print(output1); + } + + std::string str2 = "offsets_length_:"; + if (offsets_length_) { + llvm::raw_string_ostream output2(str2); + offsets_length_->print(output2); + } + + return s + "\n" + str1 + "\n" + str2; + } + + private: + llvm::Value* child_offsets_; + llvm::Value* offsets_length_; +}; + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 83bbdee208562..117b27b2808dd 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -19,6 +19,8 @@ #include + +#include "gandiva/array_ops.h" #include "gandiva/gdv_function_stubs.h" // Use the same names as in arrow data types. Makes it easy to write pre-processor macros. diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc index 54de03963f7e7..0181ece3b2607 100644 --- a/cpp/src/gandiva/projector.cc +++ b/cpp/src/gandiva/projector.cc @@ -184,7 +184,10 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, ValidateArrayDataCapacity(*array_data, *(output_fields_[idx]), num_rows)); ++idx; } - return llvm_generator_->Execute(batch, selection_vector, output_data_vecs); + ARROW_RETURN_NOT_OK( + llvm_generator_->Execute(batch, selection_vector, output_data_vecs)); + + return Status::OK(); } Status Projector::Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* pool, @@ -215,14 +218,45 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, llvm_generator_->Execute(batch, selection_vector, output_data_vecs)); // Create and return array arrays. + int const child_data_buffer_index = 1; + int const int_data_size = 4; + int const double_data_size = 8; output->clear(); for (auto& array_data : output_data_vecs) { + if (array_data->type->id() == arrow::Type::LIST) { + auto child_data = array_data->child_data[0]; + int64_t child_data_size = 1; + if (arrow::is_binary_like(child_data->type->id())) { + /* when allocate array data, child data length is an initialized value, + * after calculating, child data offsets buffer has been resized for results, + * but array data length is unchanged. + * We should recalculate child data length and make ArrayData with new length + * + * Otherwise, child data offsets buffer length is data length + 1 + * and offset data is int32_t, need use buffer->size()/4 - 1 + */ + child_data_size = child_data->buffers[child_data_buffer_index]->size() / int_data_size - 1; + } else if (child_data->type->id() == arrow::Type::INT32) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / int_data_size; + } else if (child_data->type->id() == arrow::Type::INT64) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / double_data_size; + } else if (child_data->type->id() == arrow::Type::FLOAT) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / int_data_size; + } else if (child_data->type->id() == arrow::Type::DOUBLE) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / double_data_size; + } + auto new_child_data = arrow::ArrayData::Make( + child_data->type, child_data_size, child_data->buffers, child_data->offset); + array_data = arrow::ArrayData::Make(array_data->type, array_data->length, + array_data->buffers, {new_child_data}, + array_data->null_count, array_data->offset); + } + output->push_back(arrow::MakeArray(array_data)); } return Status::OK(); } -// TODO : handle complex vectors (list/map/..) Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, arrow::MemoryPool* pool, ArrayDataPtr* array_data) const { @@ -243,6 +277,23 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, buffers.push_back(std::move(offsets_buffer)); } + if (type_id == arrow::Type::LIST) { + auto offsets_len = arrow::bit_util::BytesForBits((num_records + 1) * 32); + + ARROW_ASSIGN_OR_RAISE(auto offsets_buffer, arrow::AllocateBuffer(offsets_len, pool)); + buffers.push_back(std::move(offsets_buffer)); + + if (arrow::is_binary_like(type->field(0)->type()->id())) { + // child offsets length is internal data length + 1 + // offsets element is int32 + // so here i just allocate extra 32 bit for extra 1 length + ARROW_ASSIGN_OR_RAISE( + auto child_offsets_buffer, + arrow::AllocateResizableBuffer(arrow::bit_util::BytesForBits(32), pool)); + buffers.push_back(std::move(child_offsets_buffer)); + } + } + // The output vector always has a data array. int64_t data_len; if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { @@ -251,6 +302,8 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } else if (arrow::is_binary_like(type_id)) { // we don't know the expected size for varlen output vectors. data_len = 0; + } else if (type_id == arrow::Type::LIST) { + data_len = 0; } else { return Status::Invalid("Unsupported output data type " + type->ToString()); } @@ -263,7 +316,27 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } buffers.push_back(std::move(data_buffer)); - *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + ARROW_ASSIGN_OR_RAISE(auto data_valid_buffer, arrow::AllocateResizableBuffer(data_len, pool)); + + if (type->id() == arrow::Type::LIST) { + auto internal_type = type->field(0)->type(); + ArrayDataPtr child_data; + if (arrow::is_primitive(internal_type->id())) { + child_data = arrow::ArrayData::Make(internal_type, 0 /*initialize length*/, + {std::move(data_valid_buffer), std::move(buffers[2])}, 0); + } + if (arrow::is_binary_like(internal_type->id())) { + child_data = arrow::ArrayData::Make( + internal_type, 0 /*initialize length*/, + {nullptr, std::move(buffers[2]), std::move(buffers[3])}, 0); + } + *array_data = arrow::ArrayData::Make( + type, num_records, {std::move(buffers[0]), std::move(buffers[1])}, {child_data}); + + } else { + *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + } + return Status::OK(); } @@ -312,7 +385,10 @@ Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, int64_t data_len = array_data.buffers[1]->capacity(); ARROW_RETURN_IF(data_len < min_data_len, Status::Invalid("Data buffer too small for ", field.name())); - } else { + } else if (type_id == arrow::Type::LIST) { + return Status::OK(); + } + else { return Status::Invalid("Unsupported output data type " + field.type()->ToString()); } @@ -339,4 +415,5 @@ std::shared_ptr Projector::GetSecondaryCacheKey(std::string prima return arrow::Buffer::FromString(key); } + } // namespace gandiva diff --git a/cpp/src/gandiva/projector.h b/cpp/src/gandiva/projector.h index 24ec11e3eab59..53d0ef6d62431 100644 --- a/cpp/src/gandiva/projector.h +++ b/cpp/src/gandiva/projector.h @@ -154,14 +154,14 @@ class GANDIVA_EXPORT Projector { bool GetBuiltFromCache(); void Clear(); + /// Allocate an ArrowData of length 'length'. + Status AllocArrayData(const DataTypePtr& type, int64_t num_records, + arrow::MemoryPool* pool, ArrayDataPtr* array_data) const; private: Projector(std::unique_ptr llvm_generator, SchemaPtr schema, const FieldVector& output_fields, std::shared_ptr); - /// Allocate an ArrowData of length 'length'. - Status AllocArrayData(const DataTypePtr& type, int64_t num_records, - arrow::MemoryPool* pool, ArrayDataPtr* array_data) const; /// Validate that the ArrayData has sufficient capacity to accommodate 'num_records'. Status ValidateArrayDataCapacity(const arrow::ArrayData& array_data, diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt index b89c0ac225209..bc607702126af 100644 --- a/cpp/src/gandiva/tests/CMakeLists.txt +++ b/cpp/src/gandiva/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_gandiva_test(binary_test) add_gandiva_test(date_time_test) add_gandiva_test(to_string_test) add_gandiva_test(utf8_test) +add_gandiva_test(list_test) add_gandiva_test(hash_test) add_gandiva_test(in_expr_test) add_gandiva_test(null_validity_test) diff --git a/cpp/src/gandiva/tests/list_test.cc b/cpp/src/gandiva/tests/list_test.cc new file mode 100644 index 0000000000000..abc7b5d7091b8 --- /dev/null +++ b/cpp/src/gandiva/tests/list_test.cc @@ -0,0 +1,397 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; +using arrow::int64; +using arrow::utf8; +using std::string; +using std::vector; + +class TestList : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +template +void _build_list_array(const vector& values, const vector& length, + const vector& validity, arrow::MemoryPool* pool, + ArrayPtr* array, const vector& innerValidity = {}) { + size_t sum = 0; + for (auto& len : length) { + sum += len; + } + EXPECT_TRUE(values.size() == sum); + EXPECT_TRUE(length.size() == validity.size()); + + auto value_builder = std::make_shared(pool); + auto builder = std::make_shared(pool, value_builder); + int i = 0; + for (size_t l = 0; l < length.size(); l++) { + if (validity[l]) { + auto status = builder->Append(); + for (int j = 0; j < length[l]; j++) { + if (innerValidity.size() > (size_t)j && innerValidity[j] == false) { + auto v = value_builder->AppendNull(); + } else { + ASSERT_OK(value_builder->Append(values[i])); + } + i++; + } + } else { + ASSERT_OK(builder->AppendNull()); + for (int j = 0; j < length[l]; j++) { + i++; + } + } + } + ASSERT_OK(builder->Finish(array)); +} + +template +void _build_list_array2(const vector& values, const vector& length, + const vector& validity, const vector& innerValidity, arrow::MemoryPool* pool, + ArrayPtr* array) { + return _build_list_array(values, length, validity, pool, array); + } + +/* + * expression: + * input: a + * output: res + * typeof(a) can be list / list / list + */ +void _test_list_type_field_alias(DataTypePtr type, ArrayPtr array, + arrow::MemoryPool* pool, int num_records = 5) { + auto field_a = field("a", type); + auto schema = arrow::schema({field_a}); + auto result = field("res", type); + + std::cout << array->ToString() << std::endl; + assert(array->length() == num_records); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + + // Make expression + std::cout << "Make expression" << std::endl; + auto field_a_node = TreeExprBuilder::MakeField(field_a); + auto expr = TreeExprBuilder::MakeExpression(field_a_node, result); + + std::cout << "Build a projector for the expressions." << std::endl; + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + std::cout << "status message: " << status.message() << std::endl; + EXPECT_TRUE(status.ok()) << status.message(); + + std::cout << "Evaluate expression" << std::endl; + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + std::cout << "Check results" << std::endl; + EXPECT_ARROW_ARRAY_EQUALS(array, outputs[0]); + // EXPECT_ARROW_ARRAY_EQUALS will not check the length of child data, but + // ArrayData::Slice method will check length. ArrayData::ToString method will call + // ArrayData::Slice method + EXPECT_TRUE(array->ToString() == outputs[0]->ToString()); + EXPECT_TRUE(array->null_count() == outputs[0]->null_count()); +} + +/* +TEST_F(TestList, TestArrayRemove) { + // schema for input fields + auto field_b = field("b", int32()); + + auto field_a = field("a", list(int32())); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", list(int32())); + + // Create a row-batch with some sample data + int num_records = 2; + auto array_b = + MakeArrowArrayInt32({42, 42}, {true, true}); + + ArrayPtr array_a; + _build_list_array2( + {10, 42, 30, 42, 70, 80}, + {3, 3}, {true, true}, {true, true, true, true, true, true}, pool_, &array_a); + + // expected output + ArrayPtr exp1; + _build_list_array2( + {10, 30, 70, 80}, + {2, 2}, {true, true}, {true, true, true, true}, pool_, &exp1); + + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + auto expr = TreeExprBuilder::MakeExpression("array_remove", {field_a, field_b}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + std::cout << "LR Test 2 " << std::endl; + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp1, outputs.at(0)); + + //Try the second method. + arrow::ArrayDataVector outputs2; + std::shared_ptr listDt = std::make_shared(); + std::shared_ptr dt = std::make_shared(listDt); + + + int num_records2 = 5; + std::vector> buffers; + + int64_t size = 20; + auto bitmap_buffer = arrow::AllocateBuffer(size, pool_); + buffers.push_back(*std::move(bitmap_buffer)); + auto offsets_len = arrow::bit_util::BytesForBits((num_records2 + 1) * 32); + + auto offsets_buffer = arrow::AllocateBuffer(offsets_len*10, pool_); + buffers.push_back(*std::move(offsets_buffer)); + +std::vector> buffers2; +auto bitmap_buffer2 = arrow::AllocateBuffer(size, pool_); + buffers2.push_back(*std::move(bitmap_buffer2)); + + auto offsets_buffer2 = arrow::AllocateBuffer(offsets_len, pool_); + buffers2.push_back(*std::move(offsets_buffer2)); +std::shared_ptr dt2 = std::make_shared(); + + auto array_data_child = arrow::ArrayData::Make(dt2, num_records2, buffers2, 0, 0); + array_data_child->buffers = std::move(buffers2); + + std::vector> kids; + kids.push_back(array_data_child); + + +auto array_data = arrow::ArrayData::Make(dt, num_records2, buffers, kids, 0, 0); +array_data->buffers = std::move(buffers); +outputs2.push_back(array_data); + + + status = projector->Evaluate(*(in_batch.get()), outputs2); + EXPECT_TRUE(status.ok()) << status.message(); + arrow::ArrayData ad = *outputs2.at(0); + arrow::ArraySpan sp(*ad.child_data.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp1, sp.ToArray()); + + + + +for (auto& array_data : outputs2) { + auto child_data = array_data->child_data[0]; + int64_t child_data_size = 1; + if (arrow::is_binary_like(child_data->type->id())) { + child_data_size = child_data->buffers[1]->size() / 4 - 1; + } else if (child_data->type->id() == arrow::Type::INT32) { + child_data_size = child_data->buffers[1]->size() / 4; + } else if (child_data->type->id() == arrow::Type::INT64) { + child_data_size = child_data->buffers[1]->size() / 8; + } else if (child_data->type->id() == arrow::Type::FLOAT) { + child_data_size = child_data->buffers[1]->size() / 4; + } else if (child_data->type->id() == arrow::Type::DOUBLE) { + child_data_size = child_data->buffers[1]->size() / 8; + } + auto new_child_data = arrow::ArrayData::Make( + child_data->type, child_data_size, child_data->buffers, child_data->offset); + array_data = arrow::ArrayData::Make(array_data->type, array_data->length, + array_data->buffers, {new_child_data}, + array_data->null_count, array_data->offset); + + + auto newArray = arrow::MakeArray(array_data); + //arrow::ArraySpan sp(newArray); + EXPECT_ARROW_ARRAY_EQUALS(exp1, newArray); +} + + +{ + std::shared_ptr listDt = std::make_shared(); + std::shared_ptr dt = std::make_shared(listDt); + +ArrayDataPtr output_data; + auto s = projector->AllocArrayData(dt, num_records2, pool_, &output_data); + ArrayDataVector output_data_vecs; + output_data_vecs.push_back(output_data); + + status = projector->Evaluate(*(in_batch.get()), output_data_vecs); + EXPECT_TRUE(status.ok()) << status.message(); + arrow::ArraySpan sp(*output_data_vecs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp1, sp.ToArray()); + } +} + +TEST_F(TestList, TestListArrayInt32) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t data[] = {11, 2, 23, 42}; + int32_t entry_offsets_len = 4; + int32_t contains_data = 42; + + EXPECT_EQ( + array_int32_contains_int32(ctx_ptr, data, entry_offsets_len, + contains_data), + true); +} + + +TEST_F(TestList, TestListInt32LiteralContains) { + // schema for input fields + auto field_a = field("a", list(int32())); + auto field_b = field("b", int32()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", boolean()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {1, 5, 19, 42, 57}, + {1, 1, 1, 1, 1}, {true, true, true, true, true}, pool_, &array_a); + + auto array_b = + MakeArrowArrayInt32({42, 42, 42, 42, 42}); + + // expected output + auto exp = MakeArrowArrayBool({false, false, false, true, false}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + std::vector field_nodes; + auto node = TreeExprBuilder::MakeField(field_a); + field_nodes.push_back(node); + + auto node2 = TreeExprBuilder::MakeLiteral(42); + field_nodes.push_back(node2); + + auto func_node = TreeExprBuilder::MakeFunction("array_contains", field_nodes, res->type()); + auto expr = TreeExprBuilder::MakeExpression(func_node, res); + //////// + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestList, TestListInt32Contains) { + // schema for input fields + auto field_a = field("a", list(int32())); + auto field_b = field("b", int32()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", boolean()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {1, 5, 19, 42, 57}, + {1, 1, 1, 1, 1}, {true, true, true, true, true}, pool_, &array_a); + + auto array_b = + MakeArrowArrayInt32({42, 42, 42, 42, 42}); + + // expected output + auto exp = MakeArrowArrayBool({false, false, false, true, false}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // build expressions. + // array_contains(a, b) + auto expr = TreeExprBuilder::MakeExpression("array_contains", {field_a, field_b}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestList, TestListFloat32) { + ArrayPtr array; + _build_list_array( + {1.1f, 11.1f, 22.2f, 111.1f, 222.2f, 333.3f, 1111.1f, 2222.2f, 3333.3f, 4444.4f, + 11111.1f, 22222.2f, 33333.3f, 44444.4f, 55555.5f}, + {1, 2, 3, 4, 5}, {true, true, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(float32()), array, pool_); +} + +TEST_F(TestList, TestListFloat64) { + ArrayPtr array; + _build_list_array( + {1.1, 1.11, 2.22, 1.111, 2.222, 3.333, 1.1111, 2.2222, 3.3333, 4.4444, 1.11111, + 2.22222, 3.33333, 4.44444, 5.55555}, + {1, 2, 4, 3, 5}, {true, false, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(float64()), array, pool_); +}*/ + +} // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_build_validation_test.cc b/cpp/src/gandiva/tests/projector_build_validation_test.cc index 5b86844f940bf..82b59ef19ad75 100644 --- a/cpp/src/gandiva/tests/projector_build_validation_test.cc +++ b/cpp/src/gandiva/tests/projector_build_validation_test.cc @@ -26,6 +26,7 @@ namespace gandiva { using arrow::boolean; using arrow::float32; using arrow::int32; +using arrow::utf8; class TestProjector : public ::testing::Test { public: @@ -80,7 +81,7 @@ TEST_F(TestProjector, TestNotMatchingDataType) { TEST_F(TestProjector, TestNotSupportedDataType) { // schema for input fields - auto field0 = field("f0", list(int32())); + auto field0 = field("f0", map(utf8(), int32())); auto schema = arrow::schema({field0}); // output fields @@ -94,7 +95,7 @@ TEST_F(TestProjector, TestNotSupportedDataType) { std::shared_ptr projector; auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); EXPECT_TRUE(status.IsExpressionValidationError()); - std::string expected_error = "Field f0 has unsupported data type list"; + std::string expected_error = "Field f0 has unsupported data type map"; EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } diff --git a/java/gandiva/CMakeLists.txt b/java/gandiva/CMakeLists.txt index 629ab2fb347d8..60762f6307c06 100644 --- a/java/gandiva/CMakeLists.txt +++ b/java/gandiva/CMakeLists.txt @@ -38,7 +38,7 @@ set(GANDIVA_PROTO_DIR ${CMAKE_CURRENT_SOURCE_DIR}/proto) get_filename_component(GANDIVA_PROTO_FILE_ABSOLUTE ${GANDIVA_PROTO_DIR}/Types.proto ABSOLUTE) -find_package(Protobuf REQUIRED) +find_package(Protobuf CONFIG REQUIRED) add_custom_command(OUTPUT ${GANDIVA_PROTO_OUTPUT_FILES} COMMAND protobuf::protoc --proto_path ${GANDIVA_PROTO_DIR} --cpp_out ${GANDIVA_PROTO_OUTPUT_DIR} ${GANDIVA_PROTO_FILE_ABSOLUTE} diff --git a/java/gandiva/pom.xml b/java/gandiva/pom.xml index bed66b427e625..d2df653d5e1fe 100644 --- a/java/gandiva/pom.xml +++ b/java/gandiva/pom.xml @@ -30,6 +30,11 @@ ../../../cpp/release-build + + org.apache.arrow + arrow-format + ${project.version} + org.apache.arrow arrow-memory-core diff --git a/java/gandiva/proto/Types.proto b/java/gandiva/proto/Types.proto index eb0d996b92e63..a5c4df474db37 100644 --- a/java/gandiva/proto/Types.proto +++ b/java/gandiva/proto/Types.proto @@ -85,6 +85,7 @@ message ExtGandivaType { optional TimeUnit timeUnit = 6; // used by TIME32/TIME64 optional string timeZone = 7; // used by TIMESTAMP optional IntervalType intervalType = 8; // used by INTERVAL + optional GandivaType listType = 9; //used by LIST } message Field { diff --git a/java/gandiva/src/main/cpp/expression_registry_helper.cc b/java/gandiva/src/main/cpp/expression_registry_helper.cc index 6765df3b9727f..cc1ed04194861 100644 --- a/java/gandiva/src/main/cpp/expression_registry_helper.cc +++ b/java/gandiva/src/main/cpp/expression_registry_helper.cc @@ -136,6 +136,18 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) gandiva_data_type->set_type(types::GandivaType::INTERVAL); gandiva_data_type->set_intervaltype(types::IntervalType::DAY_TIME); break; + case arrow::Type::LIST: { + gandiva_data_type->set_type(types::GandivaType::LIST); + if (type->num_fields() <= 0) { + break; + } + if (type->fields()[0]->type()->id() != arrow::Type::LIST) { + types::ExtGandivaType gt; + ArrowToProtobuf(type->fields()[0]->type(), >); + gandiva_data_type->set_listtype(gt.type()); + } + break; + } default: // un-supported types. test ensures that // when one of these are added build breaks. diff --git a/java/gandiva/src/main/cpp/jni_common.cc b/java/gandiva/src/main/cpp/jni_common.cc index d5e54f38e3692..7a631ad856c47 100644 --- a/java/gandiva/src/main/cpp/jni_common.cc +++ b/java/gandiva/src/main/cpp/jni_common.cc @@ -82,10 +82,16 @@ jclass configuration_builder_class_; // refs for self. static jclass gandiva_exception_; static jclass vector_expander_class_; +static jclass listvector_expander_class_; static jclass vector_expander_ret_class_; +static jclass list_expander_ret_class_; static jmethodID vector_expander_method_; +static jmethodID listvector_expander_method_; static jfieldID vector_expander_ret_address_; static jfieldID vector_expander_ret_capacity_; +static jfieldID list_expander_ret_address_; +static jfieldID list_expander_valid_address_; +static jfieldID list_expander_ret_capacity_; static jclass secondary_cache_class_; static jmethodID cache_get_method_; @@ -125,16 +131,37 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { vector_expander_class_, "expandOutputVectorAtIndex", "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;"); + jclass local_listexpander_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/ListVectorExpander"); + listvector_expander_class_ = (jclass)env->NewGlobalRef(local_listexpander_class); + env->DeleteLocalRef(local_listexpander_class); + + listvector_expander_method_ = env->GetMethodID( + listvector_expander_class_, "expandOutputVectorAtIndex", + "(IJ)Lorg/apache/arrow/gandiva/evaluator/ListVectorExpander$ExpandResult;"); + jclass local_expander_ret_class = env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult"); vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class); env->DeleteLocalRef(local_expander_ret_class); + jclass local_list_expander_ret_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/ListVectorExpander$ExpandResult"); + list_expander_ret_class_ = (jclass)env->NewGlobalRef(local_list_expander_ret_class); + env->DeleteLocalRef(local_list_expander_ret_class); + vector_expander_ret_address_ = env->GetFieldID(vector_expander_ret_class_, "address", "J"); vector_expander_ret_capacity_ = env->GetFieldID(vector_expander_ret_class_, "capacity", "J"); + list_expander_ret_address_ = + env->GetFieldID(list_expander_ret_class_, "address", "J"); + list_expander_ret_capacity_ = + env->GetFieldID(list_expander_ret_class_, "capacity", "J"); + list_expander_valid_address_ = + env->GetFieldID(list_expander_ret_class_, "validityaddress", "J"); + jclass local_cache_class = env->FindClass("org/apache/arrow/gandiva/evaluator/JavaSecondaryCacheInterface"); secondary_cache_class_ = (jclass)env->NewGlobalRef(local_cache_class); @@ -164,11 +191,15 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(configuration_builder_class_); env->DeleteGlobalRef(gandiva_exception_); env->DeleteGlobalRef(vector_expander_class_); + env->DeleteGlobalRef(listvector_expander_class_); env->DeleteGlobalRef(vector_expander_ret_class_); + env->DeleteGlobalRef(list_expander_ret_class_); env->DeleteGlobalRef(secondary_cache_class_); env->DeleteGlobalRef(cache_buf_ret_class_); } +DataTypePtr SimpleProtoTypeToDataType(const types::GandivaType& gandiva_type); + DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) { switch (ext_type.timeunit()) { case types::SEC: @@ -221,8 +252,13 @@ DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) { } } -DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { - switch (ext_type.type()) { +DataTypePtr ProtoTypeToList(const types::ExtGandivaType& ext_type) { + DataTypePtr childType = SimpleProtoTypeToDataType(ext_type.listtype()); + return arrow::list(childType); +} + +DataTypePtr SimpleProtoTypeToDataType(const types::GandivaType& gandiva_type) { + switch (gandiva_type) { case types::NONE: return arrow::null(); case types::BOOL: @@ -257,6 +293,16 @@ DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { return arrow::date32(); case types::DATE64: return arrow::date64(); + default: + std::cerr << "Unknown data type: " << gandiva_type << "\n"; + return nullptr; + } +} + + + +DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { + switch (ext_type.type()) { case types::DECIMAL: // TODO: error handling return arrow::decimal(ext_type.precision(), ext_type.scale()); @@ -268,24 +314,36 @@ DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { return ProtoTypeToTimestamp(ext_type); case types::INTERVAL: return ProtoTypeToInterval(ext_type); - case types::FIXED_SIZE_BINARY: case types::LIST: - case types::STRUCT: + return ProtoTypeToList(ext_type); + case types::FIXED_SIZE_BINARY: case types::UNION: case types::DICTIONARY: case types::MAP: std::cerr << "Unhandled data type: " << ext_type.type() << "\n"; return nullptr; - default: - std::cerr << "Unknown data type: " << ext_type.type() << "\n"; + return SimpleProtoTypeToDataType(ext_type.type()); + } +} + +DataTypePtr ProtoTypeToDataType(const types::Field& f) { + const types::ExtGandivaType& ext_type = f.type(); + if (ext_type.type() == types::LIST) { + if (f.children().size() > 0 && f.children()[0].type().type() != types::LIST) { + DataTypePtr childType = ProtoTypeToDataType(f.children()[0].type()); + return arrow::list(childType); + } + std::cerr << "Unhandled list data type: " << ext_type.type() << "\n"; return nullptr; + } else { + return ProtoTypeToDataType(ext_type); } } FieldPtr ProtoTypeToField(const types::Field& f) { const std::string& name = f.name(); - DataTypePtr type = ProtoTypeToDataType(f.type()); + DataTypePtr type = ProtoTypeToDataType(f); bool nullable = true; if (f.has_nullable()) { nullable = f.nullable(); @@ -319,7 +377,7 @@ NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) { children.push_back(n); } - + DataTypePtr return_type = ProtoTypeToDataType(node.returntype()); if (return_type == nullptr) { std::cerr << "Unknown return type for function: " << name << "\n"; @@ -602,7 +660,6 @@ Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, auto validity = std::shared_ptr( new arrow::Buffer(reinterpret_cast(validity_addr), validity_size)); buffers.push_back(validity); - if (buf_idx >= in_bufs_len) { return Status::Invalid("insufficient number of in_buf_addrs"); } @@ -625,8 +682,61 @@ Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, buffers.push_back(offsets); } - auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers)); + + + +auto type = field->type(); +auto type_id = type->id(); + if (type_id == arrow::Type::LIST) { + + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + + // add offsets buffer for variable-len fields. + jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + auto offsets = std::shared_ptr( + new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + buffers.push_back(offsets); + if (arrow::is_binary_like(type->field(0)->type()->id())) { + // child offsets length is internal data length + 1 + // offsets element is int32 + // so here i just allocate extra 32 bit for extra 1 length + jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + + auto child_offsets_buffer = std::shared_ptr( new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + + buffers.push_back(std::move(child_offsets_buffer)); + } + } + + if (type->id() == arrow::Type::LIST) { + jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + auto data_buffer = std::shared_ptr( new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + auto internal_type = type->field(0)->type(); + std::shared_ptr child_data; + if (arrow::is_primitive(internal_type->id())) { + child_data = arrow::ArrayData::Make(internal_type, 0, + {std::move(buffers[2]), std::move(data_buffer)}); + } + if (arrow::is_binary_like(internal_type->id())) { + //LR TODO need this for strings I think. + //std::cout << "LR New ArrayData List NYI 2" << std::endl; + //child_data = arrow::ArrayData::Make( + // internal_type, 0, + // {nullptr, std::move(data_buffer), std::move(child_data)}, 0); + } + + auto array_data = arrow::ArrayData::Make(type, num_rows, {std::move(buffers[0]), std::move(buffers[1])}, {child_data}); columns.push_back(array_data); + + } else { + auto array_data = arrow::ArrayData::Make(type, num_rows, std::move(buffers)); + columns.push_back(array_data); + } } *batch = arrow::RecordBatch::Make(schema, num_rows, columns); return Status::OK(); @@ -797,12 +907,15 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_build /// class JavaResizableBuffer : public arrow::ResizableBuffer { public: - JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer, - int32_t len) + JavaResizableBuffer(JNIEnv* env, jobject jexpander, jmethodID jmethod, int32_t vector_idx, uint8_t* buffer, + int32_t len, bool isListVec = false) : ResizableBuffer(buffer, len), env_(env), jexpander_(jexpander), - vector_idx_(vector_idx) { + vector_idx_(vector_idx), + method_(jmethod), + isList(isListVec) + { size_ = 0; } @@ -810,27 +923,44 @@ class JavaResizableBuffer : public arrow::ResizableBuffer { Status Reserve(const int64_t new_capacity) override; - private: + public: JNIEnv* env_; jobject jexpander_; + jmethodID method_; int32_t vector_idx_; + bool isList; }; Status JavaResizableBuffer::Reserve(const int64_t new_capacity) { // callback into java to expand the buffer - jobject ret = env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, + jobject ret = env_->CallObjectMethod(jexpander_, method_, vector_idx_, new_capacity); if (env_->ExceptionCheck()) { env_->ExceptionDescribe(); env_->ExceptionClear(); - return Status::OutOfMemory("buffer expand failed in java"); + std::cout << "Buffer expand failed. New capacity is " << new_capacity << + " vector id " << vector_idx_ << " expander method " << method_ << + " jexpander_ " << jexpander_ << std::endl; + return Status::OutOfMemory("buffer expand failed in java."); } - jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_); - jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_); - data_ = reinterpret_cast(ret_address); - capacity_ = ret_capacity; + if (isList) { + jlong ret_address = env_->GetLongField(ret, list_expander_ret_address_); + jlong ret_capacity = env_->GetLongField(ret, list_expander_ret_capacity_); + jlong valid_address = env_->GetLongField(ret, list_expander_valid_address_); + + data_ = reinterpret_cast(ret_address); + capacity_ = ret_capacity; + validityBuffer = reinterpret_cast(valid_address); + } else { + jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_); + jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_); + + data_ = reinterpret_cast(ret_address); + capacity_ = ret_capacity; + } + return Status::OK(); } @@ -859,7 +989,7 @@ Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) { JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( - JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows, + JNIEnv* env, jobject object, jobject jexpander, jobject jListExpander, jlong module_id, jint num_rows, jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows, jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs, jlongArray out_buf_sizes) { @@ -898,7 +1028,6 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( if (!status.ok()) { break; } - std::shared_ptr selection_vector; auto selection_buffer = std::make_shared( reinterpret_cast(sel_vec_addr), sel_vec_size); @@ -925,6 +1054,7 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( break; } + std::shared_ptr outBufJava = nullptr; auto ret_types = holder->rettypes(); ArrayDataVector output; int buf_idx = 0; @@ -956,22 +1086,72 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( "null"); break; } - buffers.push_back(std::make_shared( - env, jexpander, output_vector_idx, value_buf, data_sz)); + + buffers.push_back(std::make_shared( + env, jexpander, vector_expander_method_, output_vector_idx, value_buf, data_sz)); + } else if (field->type()->id() == arrow::Type::LIST) { + buffers.push_back(std::make_shared( + env, jexpander, vector_expander_method_, output_vector_idx, value_buf, data_sz)); } else { buffers.push_back(std::make_shared(value_buf, data_sz)); } + + + if (field->type()->id() == arrow::Type::LIST) { + std::vector> child_buffers; + + if (jListExpander == nullptr) { + status = Status::Invalid( + "expression has variable len output columns, but the jListExpander object is " + "null"); + break; + } + + data_sz = out_sizes[sz_idx++]; + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* child_offset_buf = reinterpret_cast(out_bufs[buf_idx++]); + child_buffers.push_back(std::make_shared( + env, jListExpander, listvector_expander_method_, output_vector_idx, child_offset_buf, data_sz)); + + data_sz = out_sizes[sz_idx++]; + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* child_data_buf = reinterpret_cast(out_bufs[buf_idx++]); + + outBufJava = std::make_shared( + env, jListExpander, listvector_expander_method_, output_vector_idx, child_data_buf, data_sz, true); + outBufJava->offsetBuffer = reinterpret_cast(out_bufs[1]); + outBufJava->offsetCapacity = out_sizes[1]; + outBufJava->validityBuffer = reinterpret_cast(out_bufs[2]); + child_buffers.push_back(outBufJava); + + std::shared_ptr dt2 = std::make_shared(); + if (field->type()->id() == arrow::Type::LIST && field->type()->num_fields() > 0) { + dt2 = field->type()->fields()[0]->type(); + } + + auto array_data_child = arrow::ArrayData::Make(dt2, output_row_count, child_buffers); + std::vector> kids; + kids.push_back(array_data_child); + auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers, kids); + array_data->child_data = std::move(kids); + output.push_back(array_data); + ++output_vector_idx; + } else { auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers); output.push_back(array_data); ++output_vector_idx; + } + } if (!status.ok()) { break; } + status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output); } while (0); + env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT); env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT); env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT); @@ -1061,7 +1241,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_build // good to invoke the filter builder now status = Filter::Make(schema_ptr, condition_ptr, config, sec_cache, &filter); if (!status.ok()) { - ss << "Failed to make LLVM module due to " << status.message() << "\n"; + ss << "Failed to make LLVM module [2] due to " << status.message() << "\n"; releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); goto err_out; } diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java index 0155af08234ad..6abc6719d63e6 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java @@ -17,9 +17,11 @@ package org.apache.arrow.gandiva.evaluator; +import java.util.ArrayList; import java.util.List; import java.util.Set; +import org.apache.arrow.flatbuf.Type; import org.apache.arrow.gandiva.exceptions.GandivaException; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.gandiva.ipc.GandivaTypes.ExtGandivaType; @@ -32,7 +34,6 @@ import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.google.protobuf.InvalidProtocolBufferException; @@ -116,12 +117,20 @@ private static Set getSupportedFunctionsFromGandiva() throws String functionName = protoFunctionSignature.getName(); ArrowType returnType = getArrowType(protoFunctionSignature.getReturnType()); - List paramTypes = Lists.newArrayList(); + ArrowType returnListType = getArrowTypeSimple(protoFunctionSignature.getReturnType().getListType()); + List> paramTypes = new ArrayList>(); for (ExtGandivaType type : protoFunctionSignature.getParamTypesList()) { - paramTypes.add(getArrowType(type)); + ArrowType paramType = getArrowType(type); + ArrowType paramListType = getArrowTypeSimple(type.getListType()); + List paramArrowList = new ArrayList(); + paramArrowList.add(paramType); + if (paramType.getTypeID().getFlatbufID() == Type.List) { + paramArrowList.add(paramListType); + } + paramTypes.add(paramArrowList); } FunctionSignature functionSignature = new FunctionSignature(functionName, - returnType, paramTypes); + returnType, returnListType, paramTypes); supportedTypes.add(functionSignature); } } catch (InvalidProtocolBufferException invalidProtException) { @@ -130,8 +139,8 @@ private static Set getSupportedFunctionsFromGandiva() throws return supportedTypes; } - private static ArrowType getArrowType(ExtGandivaType type) { - switch (type.getType().getNumber()) { + private static ArrowType getArrowTypeSimple(GandivaType type) { + switch (type.getNumber()) { case GandivaType.BOOL_VALUE: return ArrowType.Bool.INSTANCE; case GandivaType.UINT8_VALUE: @@ -164,25 +173,15 @@ private static ArrowType getArrowType(ExtGandivaType type) { return new ArrowType.Date(DateUnit.DAY); case GandivaType.DATE64_VALUE: return new ArrowType.Date(DateUnit.MILLISECOND); - case GandivaType.TIMESTAMP_VALUE: - return new ArrowType.Timestamp(mapArrowTimeUnit(type.getTimeUnit()), null); - case GandivaType.TIME32_VALUE: - return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), - BIT_WIDTH_32); - case GandivaType.TIME64_VALUE: - return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), - BIT_WIDTH_64); case GandivaType.NONE_VALUE: return new ArrowType.Null(); case GandivaType.DECIMAL_VALUE: return new ArrowType.Decimal(0, 0, 128); - case GandivaType.INTERVAL_VALUE: - return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); + case GandivaType.LIST_VALUE: + return new ArrowType.List(); case GandivaType.FIXED_SIZE_BINARY_VALUE: case GandivaType.MAP_VALUE: case GandivaType.DICTIONARY_VALUE: - case GandivaType.LIST_VALUE: - case GandivaType.STRUCT_VALUE: case GandivaType.UNION_VALUE: default: assert false; @@ -190,6 +189,23 @@ private static ArrowType getArrowType(ExtGandivaType type) { return null; } + private static ArrowType getArrowType(ExtGandivaType type) { + switch (type.getType().getNumber()) { + case GandivaType.TIMESTAMP_VALUE: + return new ArrowType.Timestamp(mapArrowTimeUnit(type.getTimeUnit()), null); + case GandivaType.TIME32_VALUE: + return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), + BIT_WIDTH_32); + case GandivaType.TIME64_VALUE: + return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), + BIT_WIDTH_64); + case GandivaType.INTERVAL_VALUE: + return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); + default: + return getArrowTypeSimple(type.getType()); + } + } + private static TimeUnit mapArrowTimeUnit(GandivaTypes.TimeUnit timeUnit) { switch (timeUnit.getNumber()) { case GandivaTypes.TimeUnit.MICROSEC_VALUE: diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java index d01881843de47..c5c6aeb5372b8 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java @@ -17,6 +17,7 @@ package org.apache.arrow.gandiva.evaluator; +import java.util.ArrayList; import java.util.List; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -30,13 +31,18 @@ public class FunctionSignature { private final String name; private final ArrowType returnType; - private final List paramTypes; + private final ArrowType returnListType; + private final List> paramTypes; public ArrowType getReturnType() { return returnType; } - public List getParamTypes() { + public ArrowType getReturnListType() { + return returnListType; + } + + public List> getParamTypes() { return paramTypes; } @@ -48,14 +54,36 @@ public String getName() { * Ctor. * @param name - name of the function. * @param returnType - data type of return + * @param returnListType optional list type * @param paramTypes - data type of input args. */ - public FunctionSignature(String name, ArrowType returnType, List paramTypes) { + public FunctionSignature(String name, ArrowType returnType, ArrowType returnListType, + List> paramTypes) { this.name = name; this.returnType = returnType; + this.returnListType = returnListType; this.paramTypes = paramTypes; } + /** + * Ctor. + * @param name - name of the function. + * @param returnType - data type of return + * @param paramTypes - data type of input args. + */ + public FunctionSignature(String name, ArrowType returnType, List paramTypes) { + this.name = name; + this.returnType = returnType; + this.returnListType = ArrowType.Null.INSTANCE; + this.paramTypes = new ArrayList>(); + for (ArrowType paramType : paramTypes) { + List paramArrowList = new ArrayList(); + paramArrowList.add(paramType); + this.paramTypes.add(paramArrowList); + } + + } + /** * Override equals. * @param signature - signature to compare @@ -71,12 +99,13 @@ public boolean equals(Object signature) { final FunctionSignature other = (FunctionSignature) signature; return this.name.equalsIgnoreCase(other.name) && Objects.equal(this.returnType, other.returnType) && + Objects.equal(this.returnListType, other.returnListType) && Objects.equal(this.paramTypes, other.paramTypes); } @Override public int hashCode() { - return Objects.hashCode(this.name.toLowerCase(), this.returnType, this.paramTypes); + return Objects.hashCode(this.name.toLowerCase(), this.returnType, this.returnListType, this.paramTypes); } @Override @@ -84,6 +113,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("name ", name) .add("return type ", returnType) + .add("return list type", returnListType) .add("param types ", paramTypes) .toString(); diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java index 293d51a87a5fd..f883ed7081547 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java @@ -50,6 +50,7 @@ native long buildProjector(Object cache, byte[] schemaBuf, byte[] exprListBuf, * and store the output in ValueVectors. Throws an exception in case of errors * * @param expander VectorExpander object. Used for callbacks from cpp. + * @param listExpander ListVectorExpander object. Used for callbacks from cpp. * @param moduleId moduleId representing expressions. Created using a call to * buildNativeCode * @param numRows Number of rows in the record batch @@ -63,7 +64,7 @@ native long buildProjector(Object cache, byte[] schemaBuf, byte[] exprListBuf, * @param outSizes The allocated size of the output buffers. On successful evaluation, * the result is stored in the output buffers */ - native void evaluateProjector(Object expander, long moduleId, int numRows, + native void evaluateProjector(Object expander, Object listExpander, long moduleId, int numRows, long[] bufAddrs, long[] bufSizes, int selectionVectorType, int selectionVectorSize, long selectionVectorBufferAddr, long selectionVectorBufferSize, diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java new file mode 100644 index 0000000000000..1d02f38a4d591 --- /dev/null +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.vector.complex.ListVector; + +/** + * This class provides the functionality to expand output ListVectors using a callback mechanism from + * gandiva. + */ +public class ListVectorExpander { + private final ListVector[] bufferVectors; + public static final int valueBufferIndex = 1; + public static final int validityBufferIndex = 0; + + public ListVectorExpander(ListVector[] bufferVectors) { + this.bufferVectors = bufferVectors; + } + + /** + * Result of ListVector expansion. + */ + public static class ExpandResult { + public long address; + public long capacity; + public long validityaddress; + + /** + * Result of expanding the buffer. + * @param address Data buffer address + * @param capacity Capacity + * @param validAdd Validity buffer address + * + */ + public ExpandResult(long address, long capacity, long validAdd) { + this.address = address; + this.capacity = capacity; + this.validityaddress = validAdd; + } + } + + /** + * Expand vector at specified index. This is used as a back call from jni, and is only + * relevant for ListVectors. + * + * @param index index of buffer in the list passed to jni. + * @param toCapacity the size to which the buffer should be expanded to. + * + * @return address and size of the buffer after expansion. + */ + public ExpandResult expandOutputVectorAtIndex(int index, long toCapacity) { + if (index >= bufferVectors.length || bufferVectors[index] == null) { + throw new IllegalArgumentException("invalid index " + index); + } + + ListVector vector = bufferVectors[index]; + while (vector.getDataVector().getFieldBuffers().get(ListVectorExpander.valueBufferIndex).capacity() < toCapacity) { + //Just realloc the data vector. + vector.getDataVector().reAlloc(); + } + + return new ExpandResult( + vector.getDataVector().getFieldBuffers().get(ListVectorExpander.valueBufferIndex).memoryAddress(), + vector.getDataVector().getFieldBuffers().get(ListVectorExpander.valueBufferIndex).capacity(), + vector.getDataVector().getFieldBuffers().get(ListVectorExpander.validityBufferIndex).memoryAddress()); + } + +} diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java index c146fce26c150..686539a169f57 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java @@ -22,16 +22,15 @@ import org.apache.arrow.gandiva.exceptions.EvaluatorClosedException; import org.apache.arrow.gandiva.exceptions.GandivaException; -import org.apache.arrow.gandiva.exceptions.UnsupportedTypeException; import org.apache.arrow.gandiva.expression.ArrowTypeHelper; import org.apache.arrow.gandiva.expression.ExpressionTree; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.vector.BaseVariableWidthVector; -import org.apache.arrow.vector.FixedWidthVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VariableWidthVector; +import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.ipc.message.ArrowBuffer; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; @@ -352,18 +351,21 @@ private void evaluate(int numRows, List buffers, List buf boolean hasVariableWidthColumns = false; BaseVariableWidthVector[] resizableVectors = new BaseVariableWidthVector[outColumns.size()]; + ListVector[] resizableListVectors = new ListVector[outColumns.size()]; + long[] outAddrs = new long[3 * outColumns.size()]; long[] outSizes = new long[3 * outColumns.size()]; + idx = 0; int outColumnIdx = 0; + final int listVectorBufferCount = 5; for (ValueVector valueVector : outColumns) { - boolean isFixedWith = valueVector instanceof FixedWidthVector; - boolean isVarWidth = valueVector instanceof VariableWidthVector; - if (!isFixedWith && !isVarWidth) { - throw new UnsupportedTypeException( - "Unsupported value vector type " + valueVector.getField().getFieldType()); + if (valueVector instanceof ListVector) { + outAddrs = new long[listVectorBufferCount * outColumns.size()]; + outSizes = new long[listVectorBufferCount * outColumns.size()]; } + boolean isVarWidth = valueVector instanceof VariableWidthVector; outAddrs[idx] = valueVector.getValidityBuffer().memoryAddress(); outSizes[idx++] = valueVector.getValidityBuffer().capacity(); if (isVarWidth) { @@ -374,19 +376,45 @@ private void evaluate(int numRows, List buffers, List buf // save vector to allow for resizing. resizableVectors[outColumnIdx] = (BaseVariableWidthVector) valueVector; } - outAddrs[idx] = valueVector.getDataBuffer().memoryAddress(); - outSizes[idx++] = valueVector.getDataBuffer().capacity(); + if (valueVector instanceof ListVector) { + hasVariableWidthColumns = true; + resizableListVectors[outColumnIdx] = (ListVector) valueVector; + List fieldBufs = ((ListVector) valueVector).getDataVector().getFieldBuffers(); + outAddrs[idx] = valueVector.getOffsetBuffer().memoryAddress(); + outSizes[idx++] = valueVector.getOffsetBuffer().capacity(); + + //vector valid + outAddrs[idx] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.validityBufferIndex).memoryAddress(); + outSizes[idx++] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.validityBufferIndex).capacity(); + + //vector offset + outAddrs[idx] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.valueBufferIndex).memoryAddress(); + outSizes[idx++] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.valueBufferIndex).capacity(); + } else { + outAddrs[idx] = valueVector.getDataBuffer().memoryAddress(); + outSizes[idx++] = valueVector.getDataBuffer().capacity(); + } valueVector.setValueCount(selectionVectorRecordCount); outColumnIdx++; } - wrapper.evaluateProjector( hasVariableWidthColumns ? new VectorExpander(resizableVectors) : null, + hasVariableWidthColumns ? new ListVectorExpander(resizableListVectors) : null, this.moduleId, numRows, bufAddrs, bufSizes, selectionVectorType, selectionVectorRecordCount, selectionVectorAddr, selectionVectorSize, outAddrs, outSizes); + + for (ValueVector valueVector : outColumns) { + if (valueVector instanceof ListVector) { + ((ListVector) valueVector).setLastSet(selectionVectorRecordCount - 1); + } + } } /** diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java index 90f8684b455a8..fd1be362b8404 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java @@ -143,6 +143,15 @@ private static void initArrowTypeDate(ArrowType.Date dateType, } } + private static void initArrowTypeList(ArrowType.List listType, + ArrowType subType, + GandivaTypes.ExtGandivaType.Builder builder) throws GandivaException { + if (subType != null) { + builder.setListType(arrowTypeToProtobuf(subType).getType()); + } + builder.setType(GandivaTypes.GandivaType.LIST); + } + private static void initArrowTypeTime(ArrowType.Time timeType, GandivaTypes.ExtGandivaType.Builder builder) { short timeUnit = timeType.getUnit().getFlatbufID(); @@ -227,11 +236,13 @@ private static void initArrowTypeInterval(ArrowType.Interval interval, * Converts an arrow type into a protobuf. * * @param arrowType Arrow type to be converted + * @param subType optional arrow type for list/complex types + * @param builder the builder to use * @return Protobuf representing the arrow type */ - public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType) + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType, ArrowType subType, + GandivaTypes.ExtGandivaType.Builder builder) throws GandivaException { - GandivaTypes.ExtGandivaType.Builder builder = GandivaTypes.ExtGandivaType.newBuilder(); byte typeId = arrowType.getTypeID().getFlatbufID(); switch (typeId) { @@ -284,6 +295,7 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp break; } case Type.List: { // 12 + ArrowTypeHelper.initArrowTypeList((ArrowType.List) arrowType, subType, builder); break; } case Type.Struct_: { // 13 @@ -315,6 +327,31 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp return builder.build(); } + + /** + * Converts an arrow type into a protobuf. + * + * @param arrowType Arrow type to be converted + * @param f field optional for list/complex types + * @return Protobuf representing the arrow type + */ + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType, ArrowType f) + throws GandivaException { + GandivaTypes.ExtGandivaType.Builder builder = GandivaTypes.ExtGandivaType.newBuilder(); + return arrowTypeToProtobuf(arrowType, f, builder); + } + + /** + * Converts an arrow type into a protobuf. + * + * @param arrowType Arrow type to be converted + * @return Protobuf representing the arrow type + */ + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType) + throws GandivaException { + return arrowTypeToProtobuf(arrowType, null); + } + /** * Converts an arrow field object to a protobuf. * @param field Arrow field to be converted @@ -323,12 +360,21 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp public static GandivaTypes.Field arrowFieldToProtobuf(Field field) throws GandivaException { GandivaTypes.Field.Builder builder = GandivaTypes.Field.newBuilder(); builder.setName(field.getName()); - builder.setType(ArrowTypeHelper.arrowTypeToProtobuf(field.getType())); builder.setNullable(field.isNullable()); + ArrowType subType = null; + if (field.getChildren().size() > 0 && field.getChildren().get(0) + .getType().getTypeID().getFlatbufID() != Type.List) { + subType = field.getChildren().get(0).getType(); + } + + builder.setType(ArrowTypeHelper.arrowTypeToProtobuf(field.getType(), subType)); for (Field child : field.getChildren()) { - builder.addChildren(ArrowTypeHelper.arrowFieldToProtobuf(child)); + if (child.getType() != ArrowType.Null.INSTANCE) { + builder.addChildren(ArrowTypeHelper.arrowFieldToProtobuf(child)); + } } + return builder.build(); } diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java index ead1e146d5d8c..e092facfd69ba 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java @@ -19,9 +19,12 @@ import java.util.List; +import org.apache.arrow.flatbuf.Type; import org.apache.arrow.gandiva.exceptions.GandivaException; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + /** * Node representing an arbitrary function in an expression. @@ -30,18 +33,40 @@ class FunctionNode implements TreeNode { private final String function; private final List children; private final ArrowType retType; + private final ArrowType retListType; - FunctionNode(String function, List children, ArrowType retType) { + FunctionNode(String function, List children, Field inField) { + this.function = function; + this.children = children; + this.retType = inField.getType(); + if (inField.getChildren().size() > 0 && inField.getChildren().get(0) + .getType().getTypeID().getFlatbufID() != Type.List) { + this.retListType = inField.getChildren().get(0).getType(); + } else { + this.retListType = null; + } + + } + + FunctionNode(String function, List children, ArrowType inType) { + this.function = function; + this.children = children; + this.retType = inType; + this.retListType = null; + } + + FunctionNode(String function, List children, ArrowType inType, ArrowType listType) { this.function = function; this.children = children; - this.retType = retType; + this.retType = inType; + this.retListType = listType; } @Override public GandivaTypes.TreeNode toProtobuf() throws GandivaException { GandivaTypes.FunctionNode.Builder fnNode = GandivaTypes.FunctionNode.newBuilder(); fnNode.setFunctionName(function); - fnNode.setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType)); + fnNode.setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType, retListType)); for (TreeNode arg : children) { fnNode.addInArgs(arg.toProtobuf()); diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java index 8656e886aae24..f8337a25f8377 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java @@ -97,6 +97,35 @@ public static TreeNode makeFunction(String function, return new FunctionNode(function, children, retType); } + /** + * Invoke this function to create a node representing a function. + * + * @param function Name of the function, e.g. add + * @param children The arguments to the function + * @param retType The type of the return value of the operator + * @param listType The type of the list return value of the operator + * @return Node representing a function + */ + public static TreeNode makeFunction(String function, + List children, + ArrowType retType, ArrowType listType) { + return new FunctionNode(function, children, retType, listType); + } + + /** + * Invoke this function to create a node representing a function. + * + * @param function Name of the function, e.g. add + * @param children The arguments to the function + * @param retType The field of the return value of the operator, could be a complex type. + * @return Node representing a function + */ + public static TreeNode makeFunction(String function, + List children, + Field retType) { + return new FunctionNode(function, children, retType); + } + /** * Invoke this function to create a node representing an if-clause. * @@ -161,7 +190,7 @@ public static ExpressionTree makeExpression(String function, children.add(makeField(field)); } - TreeNode root = makeFunction(function, children, resultField.getType()); + TreeNode root = makeFunction(function, children, resultField); return makeExpression(root, resultField); }