From a225426e1400a583bcf4dbd90dfd45563ff78a46 Mon Sep 17 00:00:00 2001 From: lriggs Date: Mon, 20 Nov 2023 11:13:32 -0800 Subject: [PATCH] DX-64328 Array types for Gandiva (#58) Add List input and output types for Gandiva functions. Add new reference implementations for array_contains and array_remove, tested via integration with Dremio. int32, int64, double and float list types have been tested. Support List types in function specification and llvm code generation. Pass back function type information through the expression registry. See 1p here: https://docs.google.com/document/d/1exwXdUUnk5FqZLzVZyTdhqgwxTk0u9bL54aLVNM5Tas/edit --- cpp/src/arrow/buffer.h | 21 +- cpp/src/gandiva/CMakeLists.txt | 5 +- cpp/src/gandiva/annotator.cc | 86 +++- cpp/src/gandiva/array_ops.cc | 357 ++++++++++++++++ cpp/src/gandiva/array_ops.h | 86 ++++ cpp/src/gandiva/array_ops_test.cc | 41 ++ cpp/src/gandiva/dex.h | 32 ++ cpp/src/gandiva/dex_visitor.h | 6 + cpp/src/gandiva/exported_funcs.h | 6 + cpp/src/gandiva/expr_decomposer.cc | 13 +- cpp/src/gandiva/expr_validator.cc | 4 +- cpp/src/gandiva/expression_registry.cc | 7 + cpp/src/gandiva/field_descriptor.h | 20 +- cpp/src/gandiva/function_registry.cc | 14 +- cpp/src/gandiva/function_registry_array.cc | 54 +++ cpp/src/gandiva/function_registry_array.h | 28 ++ cpp/src/gandiva/gdv_function_stubs.cc | 144 ++++++- cpp/src/gandiva/llvm_generator.cc | 263 +++++++++++- cpp/src/gandiva/llvm_generator.h | 16 +- cpp/src/gandiva/llvm_types.cc | 3 +- cpp/src/gandiva/llvm_types.h | 23 + cpp/src/gandiva/llvm_types_test.cc | 10 + cpp/src/gandiva/lvalue.h | 73 +++- cpp/src/gandiva/precompiled/types.h | 2 + cpp/src/gandiva/projector.cc | 85 +++- cpp/src/gandiva/projector.h | 6 +- cpp/src/gandiva/tests/CMakeLists.txt | 1 + cpp/src/gandiva/tests/list_test.cc | 397 ++++++++++++++++++ .../tests/projector_build_validation_test.cc | 5 +- java/gandiva/CMakeLists.txt | 2 +- java/gandiva/pom.xml | 5 + java/gandiva/proto/Types.proto | 1 + .../main/cpp/expression_registry_helper.cc | 12 + java/gandiva/src/main/cpp/jni_common.cc | 230 ++++++++-- .../gandiva/evaluator/ExpressionRegistry.java | 52 ++- .../gandiva/evaluator/FunctionSignature.java | 38 +- .../arrow/gandiva/evaluator/JniWrapper.java | 3 +- .../gandiva/evaluator/ListVectorExpander.java | 83 ++++ .../arrow/gandiva/evaluator/Projector.java | 48 ++- .../gandiva/expression/ArrowTypeHelper.java | 54 ++- .../gandiva/expression/FunctionNode.java | 31 +- .../arrow/gandiva/expression/TreeBuilder.java | 31 +- 42 files changed, 2278 insertions(+), 120 deletions(-) create mode 100644 cpp/src/gandiva/array_ops.cc create mode 100644 cpp/src/gandiva/array_ops.h create mode 100644 cpp/src/gandiva/array_ops_test.cc create mode 100644 cpp/src/gandiva/function_registry_array.cc create mode 100644 cpp/src/gandiva/function_registry_array.h create mode 100644 cpp/src/gandiva/tests/list_test.cc create mode 100644 java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h index 9270c4dea3fb6..8daa8bafaaf39 100644 --- a/cpp/src/arrow/buffer.h +++ b/cpp/src/arrow/buffer.h @@ -444,10 +444,27 @@ class ARROW_EXPORT ResizableBuffer : public MutableBuffer { return Reserve(sizeof(T) * new_nb_elements); } + public: + uint8_t* offsetBuffer; + int64_t offsetCapacity; + uint8_t* validityBuffer; + uint8_t* outerValidityBuffer; + protected: - ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {} + ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) { + offsetBuffer = nullptr; + offsetCapacity = 0; + validityBuffer = nullptr; + outerValidityBuffer = nullptr; + + } ResizableBuffer(uint8_t* data, int64_t size, std::shared_ptr mm) - : MutableBuffer(data, size, std::move(mm)) {} + : MutableBuffer(data, size, std::move(mm)) { + offsetBuffer = nullptr; + offsetCapacity = 0; + validityBuffer = nullptr; + outerValidityBuffer = nullptr; + } }; /// \defgroup buffer-allocation-functions Functions for allocating buffers diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 6a92224e9113d..dc0c427f48d23 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -61,9 +61,11 @@ set(SRC_FILES expression_registry.cc exported_funcs_registry.cc filter.cc + array_ops.cc function_ir_builder.cc function_registry.cc function_registry_arithmetic.cc + function_registry_array.cc function_registry_datetime.cc function_registry_hash.cc function_registry_math_ops.cc @@ -249,7 +251,8 @@ add_gandiva_test(internals-test random_generator_holder_test.cc hash_utils_test.cc gdv_function_stubs_test.cc - interval_holder_test.cc) + interval_holder_test.cc + array_ops_test.cc) add_subdirectory(precompiled) add_subdirectory(tests) diff --git a/cpp/src/gandiva/annotator.cc b/cpp/src/gandiva/annotator.cc index b341fdde3a3f4..abd5ba6b1a4bf 100644 --- a/cpp/src/gandiva/annotator.cc +++ b/cpp/src/gandiva/annotator.cc @@ -46,15 +46,27 @@ FieldDescriptorPtr Annotator::MakeDesc(FieldPtr field, bool is_output) { int data_idx = buffer_count_++; int validity_idx = buffer_count_++; int offsets_idx = FieldDescriptor::kInvalidIdx; + int child_offsets_idx = FieldDescriptor::kInvalidIdx; if (arrow::is_binary_like(field->type()->id())) { offsets_idx = buffer_count_++; } + + if (field->type()->id() == arrow::Type::LIST) { + offsets_idx = buffer_count_++; + if (arrow::is_binary_like(field->type()->field(0)->type()->id())) { + child_offsets_idx = buffer_count_++; + } + } int data_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; if (is_output) { data_buffer_ptr_idx = buffer_count_++; } + int child_valid_buffer_ptr_idx = FieldDescriptor::kInvalidIdx; + if (field->type()->id() == arrow::Type::LIST) { + child_valid_buffer_ptr_idx = buffer_count_++; + } return std::make_shared(field, data_idx, validity_idx, offsets_idx, - data_buffer_ptr_idx); + data_buffer_ptr_idx, child_offsets_idx, child_valid_buffer_ptr_idx); } int Annotator::AddHolderPointer(void* holder) { @@ -80,17 +92,76 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc, if (desc.HasOffsetsIdx()) { uint8_t* offsets_buf = const_cast(array_data.buffers[buffer_idx]->data()); eval_batch->SetBuffer(desc.offsets_idx(), offsets_buf, array_data.offset); - ++buffer_idx; + + if (desc.HasChildOffsetsIdx()) { + if (is_output) { + // if list field is output field, we should put buffer pointer into eval batch + // for resizing + uint8_t* child_offsets_buf = reinterpret_cast( + array_data.child_data.at(0)->buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_offsets_buf, + array_data.child_data.at(0)->offset); + + uint8_t* child_valid_buf = reinterpret_cast( + array_data.child_data.at(0)->buffers[0].get()); + eval_batch->SetBuffer(desc.child_data_validity_idx(), child_valid_buf, + array_data.child_data.at(0)->offset); + + } else { + // if list field is input field, just put buffer data into eval batch + uint8_t* child_offsets_buf = const_cast( + array_data.child_data.at(0)->buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_offsets_buf, + array_data.child_data.at(0)->offset); + + uint8_t* child_valid_buf = const_cast( + array_data.child_data.at(0)->buffers[0]->data()); + eval_batch->SetBuffer(desc.child_data_offsets_idx(), child_valid_buf, + array_data.child_data.at(0)->offset); + } + } + if (array_data.type->id() != arrow::Type::LIST || + arrow::is_binary_like(array_data.type->field(0)->type()->id())) { + // primitive type list data buffer index is 1 + // binary like type list data buffer index is 2 + ++buffer_idx; + } + } + + int const childDataIndex = 0; + if (array_data.type->id() != arrow::Type::LIST) { + uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); + } else { + uint8_t* data_buf = + const_cast(array_data.child_data.at(childDataIndex)->buffers[buffer_idx]->data()); + eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.child_data.at(0)->offset); + + int const childDataBufferIndex = 0; + if (array_data.child_data.at(childDataIndex)->buffers[childDataBufferIndex] ) { + uint8_t* child_valid_buf = const_cast( + array_data.child_data.at(childDataIndex)->buffers[childDataBufferIndex]->data()); + eval_batch->SetBuffer(desc.child_data_validity_idx(), child_valid_buf, 0); + } + } - uint8_t* data_buf = const_cast(array_data.buffers[buffer_idx]->data()); - eval_batch->SetBuffer(desc.data_idx(), data_buf, array_data.offset); if (is_output) { // pass in the Buffer object for output data buffers. Can be used for resizing. - uint8_t* data_buf_ptr = - reinterpret_cast(array_data.buffers[buffer_idx].get()); - eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + + if (array_data.type->id() != arrow::Type::LIST) { + uint8_t* data_buf_ptr = + reinterpret_cast(array_data.buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset); + } else { + // list data buffer is in child data buffer + uint8_t* data_buf_ptr = reinterpret_cast( + array_data.child_data.at(0)->buffers[buffer_idx].get()); + eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, + array_data.child_data.at(0)->offset); + } } + } EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, @@ -106,7 +177,6 @@ EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch, // skip columns not involved in the expression. continue; } - PrepareBuffersForField(*(found->second), *(record_batch.column_data(i)), eval_batch.get(), false /*is_output*/); } diff --git a/cpp/src/gandiva/array_ops.cc b/cpp/src/gandiva/array_ops.cc new file mode 100644 index 0000000000000..7170534342085 --- /dev/null +++ b/cpp/src/gandiva/array_ops.cc @@ -0,0 +1,357 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/array_ops.h" + +#include +#include +#include + +#include "arrow/util/value_parsing.h" + +#include "gandiva/gdv_function_stubs.h" +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" + +/// Stub functions that can be accessed from LLVM or the pre-compiled library. + +template +Type* array_remove_template(int64_t context_ptr, const Type* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + Type remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr) +{ + std::vector newInts; + + const int32_t* entry_validityAdjusted = entry_validity - (loop_var ); + int64_t validityBitIndex = 0; + //The validity index already has the current row length added to it, so decrement. + validityBitIndex = validity_index_var - entry_len; + std::vector outValid; + for (int i = 0; i < entry_len; i++) { + Type entry_item = *(entry_buf + i); + if (remove_data_valid && entry_item == remove_data) { + //Do not add the item to remove. + } else if (!arrow::bit_util::GetBit(reinterpret_cast(entry_validityAdjusted), validityBitIndex + i)) { + outValid.push_back(false); + newInts.push_back(0); + } else { + outValid.push_back(true); + newInts.push_back(entry_item); + } + } + + *out_len = (int)newInts.size(); + + //Since this function can remove values we don't know the length ahead of time. + //A fast way to compute Math.ceil(input / 8.0). + int validByteSize = (unsigned int)((*out_len) + 7) >> 3; + + uint8_t* validRet = gdv_fn_context_arena_malloc(context_ptr, validByteSize); + for (size_t i = 0; i < outValid.size(); i++) { + arrow::bit_util::SetBitTo(validRet, i, outValid[i]); + } + + int32_t outBufferLength = (int)*out_len * sizeof(Type); + //length is number of items, but buffers must account for byte size. + uint8_t* ret = gdv_fn_context_arena_malloc(context_ptr, outBufferLength); + memcpy(ret, newInts.data(), outBufferLength); + *valid_row = true; + + //Return null if the input array is null or the data to remove is null. + if (!combined_row_validity || !remove_data_valid) { + *out_len = 0; + *valid_row = false; //this one is what works for the top level validity. + } + + *valid_ptr = reinterpret_cast(validRet); + return reinterpret_cast(ret); +} + +template +bool array_contains_template(const Type* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + Type contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + if (!combined_row_validity || !contains_data_valid) { + *valid_row = false; + return false; + } + *valid_row = true; + + const int32_t* entry_validityAdjusted = entry_validity - (loop_var ); + int64_t validityBitIndex = validity_index_var - entry_len; + + bool found_null_in_data = false; + for (int i = 0; i < entry_len; i++) { + if (!arrow::bit_util::GetBit(reinterpret_cast(entry_validityAdjusted), validityBitIndex + i)) { + found_null_in_data = true; + continue; + } + Type entry_item = *(entry_buf + i); + if (contains_data_valid && entry_item == contains_data) { + return true; + } + } + //If there is null in the input and the item is not found the result is null. + if (found_null_in_data) { + *valid_row = false; + } + return false; +} + +extern "C" { + +bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + +bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + +bool array_float32_contains_float32(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + +bool array_float64_contains_float64(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double contains_data, bool contains_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row) { + return array_contains_template(entry_buf, entry_len, entry_validity, + combined_row_validity, contains_data, contains_data_valid, + loop_var, validity_index_var, valid_row); +} + + + +int32_t* array_int32_remove(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr) { + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} + +int64_t* array_int64_remove(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr){ + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} + +float* array_float32_remove(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr){ + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} + + +double* array_float64_remove(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double remove_data, bool remove_data_valid, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr){ + return array_remove_template(context_ptr, entry_buf, + entry_len, entry_validity, combined_row_validity, + remove_data, remove_data_valid, + loop_var, validity_index_var, + valid_row, out_len, valid_ptr); +} +} + +namespace gandiva { +void ExportedArrayFunctions::AddMappings(Engine* engine) const { + std::vector args; + auto types = engine->types(); + + args = {types->i64_type(), // int64_t execution_context + types->i64_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i32_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_int32_contains_int32", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_int32_contains_int32)); + + args = {types->i64_type(), // int64_t execution_context + types->i64_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i64_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_int64_contains_int64", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_int64_contains_int64)); + + args = {types->i64_type(), // int64_t execution_context + types->float_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->float_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_float32_contains_float32", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_float32_contains_float32)); + + args = {types->i64_type(), // int64_t execution_context + types->double_ptr_type(), // int8_t* data ptr + types->i32_type(), // int32_t data length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->double_type(), // int32_t value to check for + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type() //output validity for the row + }; + + engine->AddGlobalMappingForFunc("array_float64_contains_float64", + types->i1_type() /*return_type*/, args, + reinterpret_cast(array_float64_contains_float64)); + //Array remove. + args = {types->i64_type(), // int64_t execution_context + types->i32_ptr_type(), // int8_t* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i32_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + engine->AddGlobalMappingForFunc("array_int32_remove", + types->i32_ptr_type(), args, + reinterpret_cast(array_int32_remove)); + + args = {types->i64_type(), // int64_t execution_context + types->i64_ptr_type(), // int8_t* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->i64_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + + engine->AddGlobalMappingForFunc("array_int64_remove", + types->i64_ptr_type(), args, + reinterpret_cast(array_int64_remove)); + + args = {types->i64_type(), // int64_t execution_context + types->float_ptr_type(), // float* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->float_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + + engine->AddGlobalMappingForFunc("array_float32_remove", + types->float_ptr_type(), args, + reinterpret_cast(array_float32_remove)); + + args = {types->i64_type(), // int64_t execution_context + types->double_ptr_type(), // int8_t* input data ptr + types->i32_type(), // int32_t input length + types->i32_ptr_type(), // input validity buffer + types->i1_type(), // bool input row validity + types->double_type(), //value to remove from input + types->i1_type(), // bool validity --Needed? + types->i64_type(), //in loop var --Needed? + types->i64_type(), //in validity_index_var index into the valdity vector for the current row. + types->i1_ptr_type(), //output validity for the row + types->i32_ptr_type(), // output array length + types->i32_ptr_type() //output pointer to new validity buffer + + }; + + engine->AddGlobalMappingForFunc("array_float64_remove", + types->double_ptr_type(), args, + reinterpret_cast(array_float64_remove)); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/array_ops.h b/cpp/src/gandiva/array_ops.h new file mode 100644 index 0000000000000..c0de72a39472b --- /dev/null +++ b/cpp/src/gandiva/array_ops.h @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "gandiva/visibility.h" + +namespace llvm { +class VectorType; +} + +/// Array functions that can be accessed from LLVM. +extern "C" { + +GANDIVA_EXPORT +bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); +GANDIVA_EXPORT +bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); + +GANDIVA_EXPORT +bool array_float32_contains_float32(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); + +GANDIVA_EXPORT +bool array_float64_contains_float64(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double contains_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_buf); + +GANDIVA_EXPORT +int32_t* array_int32_remove(int64_t context_ptr, const int32_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int32_t remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +GANDIVA_EXPORT +int64_t* array_int64_remove(int64_t context_ptr, const int64_t* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + int64_t remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +GANDIVA_EXPORT +float* array_float32_remove(int64_t context_ptr, const float* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + float remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +GANDIVA_EXPORT +double* array_float64_remove(int64_t context_ptr, const double* entry_buf, + int32_t entry_len, const int32_t* entry_validity, bool combined_row_validity, + double remove_data, bool entry_validWhat, + int64_t loop_var, int64_t validity_index_var, + bool* valid_row, int32_t* out_len, int32_t** valid_ptr); + +} diff --git a/cpp/src/gandiva/array_ops_test.cc b/cpp/src/gandiva/array_ops_test.cc new file mode 100644 index 0000000000000..bf01c1fe0a091 --- /dev/null +++ b/cpp/src/gandiva/array_ops_test.cc @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" + +namespace gandiva { + +TEST(TestArrayOps, TestInt32ContainsInt32) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t data[] = {1, 2, 3, 4}; + int32_t entry_offsets_len = 3; + int32_t contains_data = 2; + int32_t entry_validity = 15; + bool valid = false; + + EXPECT_EQ( + array_int32_contains_int32(ctx_ptr, data, entry_offsets_len, &entry_validity, + true, contains_data, true, 0, 3, &valid), + true); +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/dex.h b/cpp/src/gandiva/dex.h index 2998c2131769a..95053ddabfb75 100644 --- a/cpp/src/gandiva/dex.h +++ b/cpp/src/gandiva/dex.h @@ -80,6 +80,23 @@ class GANDIVA_EXPORT VectorReadFixedLenValueDex : public VectorReadBaseDex { void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } }; +/// value component of a fixed-len list ValueVector +class GANDIVA_EXPORT VectorReadFixedLenValueListDex : public VectorReadBaseDex { + public: + explicit VectorReadFixedLenValueListDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + int ValidityIdx() const { return field_desc_->validity_idx(); } + + int ChildValidityIdx() const { return field_desc_->child_data_validity_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// value component of a variable-len ValueVector class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { public: @@ -93,6 +110,21 @@ class GANDIVA_EXPORT VectorReadVarLenValueDex : public VectorReadBaseDex { void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } }; +/// value component of a variable-len list ValueVector +class GANDIVA_EXPORT VectorReadVarLenValueListDex : public VectorReadBaseDex { + public: + explicit VectorReadVarLenValueListDex(FieldDescriptorPtr field_desc) + : VectorReadBaseDex(field_desc) {} + + int DataIdx() const { return field_desc_->data_idx(); } + + int OffsetsIdx() const { return field_desc_->offsets_idx(); } + + int ChildOffsetsIdx() const { return field_desc_->child_data_offsets_idx(); } + + void Accept(DexVisitor& visitor) override { visitor.Visit(*this); } +}; + /// validity based on a local bitmap. class GANDIVA_EXPORT LocalBitMapValidityDex : public Dex { public: diff --git a/cpp/src/gandiva/dex_visitor.h b/cpp/src/gandiva/dex_visitor.h index 5d160bb22ca68..4a03b9c21fc8a 100644 --- a/cpp/src/gandiva/dex_visitor.h +++ b/cpp/src/gandiva/dex_visitor.h @@ -28,7 +28,9 @@ namespace gandiva { class VectorReadValidityDex; class VectorReadFixedLenValueDex; +class VectorReadFixedLenValueListDex; class VectorReadVarLenValueDex; +class VectorReadVarLenValueListDex; class LocalBitMapValidityDex; class LiteralDex; class TrueDex; @@ -49,7 +51,9 @@ class GANDIVA_EXPORT DexVisitor { virtual void Visit(const VectorReadValidityDex& dex) = 0; virtual void Visit(const VectorReadFixedLenValueDex& dex) = 0; + virtual void Visit(const VectorReadFixedLenValueListDex& dex) = 0; virtual void Visit(const VectorReadVarLenValueDex& dex) = 0; + virtual void Visit(const VectorReadVarLenValueListDex& dex) = 0; virtual void Visit(const LocalBitMapValidityDex& dex) = 0; virtual void Visit(const TrueDex& dex) = 0; virtual void Visit(const FalseDex& dex) = 0; @@ -75,7 +79,9 @@ class GANDIVA_EXPORT DexVisitor { class GANDIVA_EXPORT DexDefaultVisitor : public DexVisitor { VISIT_DCHECK(VectorReadValidityDex) VISIT_DCHECK(VectorReadFixedLenValueDex) + VISIT_DCHECK(VectorReadFixedLenValueListDex) VISIT_DCHECK(VectorReadVarLenValueDex) + VISIT_DCHECK(VectorReadVarLenValueListDex) VISIT_DCHECK(LocalBitMapValidityDex) VISIT_DCHECK(TrueDex) VISIT_DCHECK(FalseDex) diff --git a/cpp/src/gandiva/exported_funcs.h b/cpp/src/gandiva/exported_funcs.h index 5a14c52162156..55145b301e78c 100644 --- a/cpp/src/gandiva/exported_funcs.h +++ b/cpp/src/gandiva/exported_funcs.h @@ -32,6 +32,12 @@ class ExportedFuncsBase { virtual void AddMappings(Engine* engine) const = 0; }; +// Class for exporting Array functions +class ExportedArrayFunctions : public ExportedFuncsBase { + void AddMappings(Engine* engine) const override; +}; +REGISTER_EXPORTED_FUNCS(ExportedArrayFunctions); + // Class for exporting Stub functions class ExportedStubFunctions : public ExportedFuncsBase { void AddMappings(Engine* engine) const override; diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc index 957d9d046bd57..719d4006e65ae 100644 --- a/cpp/src/gandiva/expr_decomposer.cc +++ b/cpp/src/gandiva/expr_decomposer.cc @@ -39,8 +39,17 @@ Status ExprDecomposer::Visit(const FieldNode& node) { DexPtr validity_dex = std::make_shared(desc); DexPtr value_dex; - if (desc->HasOffsetsIdx()) { - value_dex = std::make_shared(desc); + if (desc->HasChildOffsetsIdx()) { + // handle list type + value_dex = std::make_shared(desc); + } else if (desc->HasOffsetsIdx()) { + if (desc->field()->type()->id() == arrow::Type::LIST) { + // handle list type + auto p = std::make_shared(desc); + value_dex = p; + } else { + value_dex = std::make_shared(desc); + } } else { value_dex = std::make_shared(desc); } diff --git a/cpp/src/gandiva/expr_validator.cc b/cpp/src/gandiva/expr_validator.cc index 35a13494523d0..265f2c119cd0e 100644 --- a/cpp/src/gandiva/expr_validator.cc +++ b/cpp/src/gandiva/expr_validator.cc @@ -67,7 +67,7 @@ Status ExprValidator::Validate(const ExpressionPtr& expr) { } Status ExprValidator::Visit(const FieldNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); + auto llvm_type = types_->DataVecType(node.return_type()); ARROW_RETURN_IF(llvm_type == nullptr, Status::ExpressionValidationError("Field ", node.field()->name(), " has unsupported data type ", @@ -136,7 +136,7 @@ Status ExprValidator::Visit(const IfNode& node) { } Status ExprValidator::Visit(const LiteralNode& node) { - auto llvm_type = types_->IRType(node.return_type()->id()); + auto llvm_type = types_->DataVecType(node.return_type()); ARROW_RETURN_IF(llvm_type == nullptr, Status::ExpressionValidationError("Value ", ToString(node.holder()), " has unsupported data type ", diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc index 9bff97f5ad269..12ac0d0b154e8 100644 --- a/cpp/src/gandiva/expression_registry.cc +++ b/cpp/src/gandiva/expression_registry.cc @@ -166,6 +166,13 @@ static void AddArrowTypesToVector(arrow::Type::type type, DataTypeVector& vector case arrow::Type::type::INTERVAL_DAY_TIME: vector.push_back(arrow::day_time_interval()); break; + case arrow::Type::type::LIST: + vector.push_back(arrow::list(arrow::int32())); + vector.push_back(arrow::list(arrow::int64())); + vector.push_back(arrow::list(arrow::float32())); + vector.push_back(arrow::list(arrow::float64())); + vector.push_back(arrow::list(arrow::utf8())); + break; default: // Unsupported types. test ensures that // when one of these are added build breaks. diff --git a/cpp/src/gandiva/field_descriptor.h b/cpp/src/gandiva/field_descriptor.h index 0fe6fe37f4dd3..dfcf6872d501d 100644 --- a/cpp/src/gandiva/field_descriptor.h +++ b/cpp/src/gandiva/field_descriptor.h @@ -30,12 +30,16 @@ class FieldDescriptor { static const int kInvalidIdx = -1; FieldDescriptor(FieldPtr field, int data_idx, int validity_idx = kInvalidIdx, - int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx) + int offsets_idx = kInvalidIdx, int data_buffer_ptr_idx = kInvalidIdx, + int child_offsets_idx = kInvalidIdx, int child_validity_idx = kInvalidIdx) : field_(field), data_idx_(data_idx), validity_idx_(validity_idx), offsets_idx_(offsets_idx), - data_buffer_ptr_idx_(data_buffer_ptr_idx) {} + data_buffer_ptr_idx_(data_buffer_ptr_idx), + child_offsets_idx_(child_offsets_idx), + child_validity_idx_(child_validity_idx) { + } /// Index of validity array in the array-of-buffers int validity_idx() const { return validity_idx_; } @@ -49,6 +53,12 @@ class FieldDescriptor { /// Index of data buffer pointer in the array-of-buffers int data_buffer_ptr_idx() const { return data_buffer_ptr_idx_; } + /// Index of list type child data offsets + int child_data_offsets_idx() const { return child_offsets_idx_; } + int child_data_validity_idx() const { return child_validity_idx_; } + void set_child_data_validity_idx(int val) { + child_validity_idx_ = val; + } FieldPtr field() const { return field_; } const std::string& Name() const { return field_->name(); } @@ -58,12 +68,18 @@ class FieldDescriptor { bool HasDataBufferPtrIdx() const { return data_buffer_ptr_idx_ != kInvalidIdx; } + bool HasChildOffsetsIdx() const { return child_offsets_idx_ != kInvalidIdx; } + + bool HasChildValidityIdx() const { return child_validity_idx_ != kInvalidIdx; } + private: FieldPtr field_; int data_idx_; int validity_idx_; int offsets_idx_; int data_buffer_ptr_idx_; + int child_offsets_idx_; + int child_validity_idx_; }; } // namespace gandiva diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index 67b7b404b325c..9180e8c33ca33 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -16,17 +16,19 @@ // under the License. #include "gandiva/function_registry.h" + +#include +#include +#include + #include "gandiva/function_registry_arithmetic.h" +#include "gandiva/function_registry_array.h" #include "gandiva/function_registry_datetime.h" #include "gandiva/function_registry_hash.h" #include "gandiva/function_registry_math_ops.h" #include "gandiva/function_registry_string.h" #include "gandiva/function_registry_timestamp_arithmetic.h" -#include -#include -#include - namespace gandiva { FunctionRegistry::iterator FunctionRegistry::begin() const { @@ -64,6 +66,10 @@ SignatureMap FunctionRegistry::InitPCMap() { auto v6 = GetDateTimeArithmeticFunctionRegistry(); pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); + + auto v7 = GetArrayFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v7.begin(), v7.end()); + for (auto& elem : pc_registry_) { for (auto& func_signature : elem.signatures()) { map.insert(std::make_pair(&(func_signature), &elem)); diff --git a/cpp/src/gandiva/function_registry_array.cc b/cpp/src/gandiva/function_registry_array.cc new file mode 100644 index 0000000000000..893ba6e3d2b04 --- /dev/null +++ b/cpp/src/gandiva/function_registry_array.cc @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry_array.h" + +#include "gandiva/function_registry_common.h" + +namespace gandiva { +std::vector GetArrayFunctionRegistry() { + static std::vector array_fn_registry_ = { + NativeFunction("array_contains", {}, DataTypeVector{list(int32()), int32()}, + boolean(), kResultNullInternal, "array_int32_contains_int32", + NativeFunction::kNeedsContext), + NativeFunction("array_contains", {}, DataTypeVector{list(int64()), int64()}, + boolean(), kResultNullInternal, "array_int64_contains_int64", + NativeFunction::kNeedsContext), + NativeFunction("array_contains", {}, DataTypeVector{list(float32()), float32()}, + boolean(), kResultNullInternal, "array_float32_contains_float32", + NativeFunction::kNeedsContext), + NativeFunction("array_contains", {}, DataTypeVector{list(float64()), float64()}, + boolean(), kResultNullInternal, "array_float64_contains_float64", + NativeFunction::kNeedsContext), + + NativeFunction("array_remove", {}, DataTypeVector{list(int32()), int32()}, + list(int32()), kResultNullInternal, "array_int32_remove", + NativeFunction::kNeedsContext), + NativeFunction("array_remove", {}, DataTypeVector{list(int64()), int64()}, + list(int64()), kResultNullInternal, "array_int64_remove", + NativeFunction::kNeedsContext), + NativeFunction("array_remove", {}, DataTypeVector{list(float32()), float32()}, + list(float32()), kResultNullInternal, "array_float32_remove", + NativeFunction::kNeedsContext), + NativeFunction("array_remove", {}, DataTypeVector{list(float64()), float64()}, + list(float64()), kResultNullInternal, "array_float64_remove", + NativeFunction::kNeedsContext), + }; + return array_fn_registry_; +} + +} // namespace gandiva diff --git a/cpp/src/gandiva/function_registry_array.h b/cpp/src/gandiva/function_registry_array.h new file mode 100644 index 0000000000000..9b8e4553702a8 --- /dev/null +++ b/cpp/src/gandiva/function_registry_array.h @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "gandiva/native_function.h" + +namespace gandiva { + +std::vector GetArrayFunctionRegistry(); + +} // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 5146f7fa1990a..2ca9529fa846b 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -36,8 +36,6 @@ #include "gandiva/random_generator_holder.h" #include "gandiva/to_date_holder.h" -/// Stub functions that can be accessed from LLVM or the pre-compiled library. - extern "C" { static char mask_array[256] = { @@ -161,6 +159,96 @@ int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr, return 0; } +/// Stub functions that can be accessed from LLVM or the pre-compiled library. +#define POPULATE_NUMERIC_LIST_TYPE_VECTOR(TYPE, SCALE) \ + int32_t gdv_fn_populate_list_##TYPE##_vector(int64_t context_ptr, int8_t* data_ptr, \ + int32_t* offsets, int64_t slot, \ + TYPE* entry_buf, int32_t entry_len, int32_t** valid_ptr) { \ + auto buffer = reinterpret_cast(data_ptr); \ + int32_t offset = static_cast(buffer->size()); \ + auto status = buffer->Resize(offset + entry_len * SCALE, false /*shrink*/); \ + if (!status.ok()) { \ + gandiva::ExecutionContext* context = \ + reinterpret_cast(context_ptr); \ + context->set_error_msg(status.message().c_str()); \ + return -1; \ + } \ + memcpy(buffer->mutable_data() + offset, (char*)entry_buf, entry_len * SCALE); \ + int validbitIndex = offset / SCALE; \ + for (int i = 0; i < entry_len; i++) { \ + arrow::bit_util::SetBitTo(buffer->validityBuffer, validbitIndex + i, arrow::bit_util::GetBit(reinterpret_cast(valid_ptr), i)); \ + } \ + offsets = reinterpret_cast(buffer->offsetBuffer); \ + offsets[slot] = offset / SCALE; \ + offsets[slot + 1] = offset / SCALE + entry_len; \ + return 0; \ + }\ + +POPULATE_NUMERIC_LIST_TYPE_VECTOR(int32_t, 4) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(int64_t, 8) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(float, 4) +POPULATE_NUMERIC_LIST_TYPE_VECTOR(double, 8) + +int32_t gdv_fn_populate_list_varlen_vector(int64_t context_ptr, int8_t* data_ptr, + int32_t* offsets, int32_t* child_offsets, + int64_t slot, const char* entry_buf, + int32_t* entry_child_offsets, + int32_t entry_offsets_len) { + // we should calculate varlen list type varlen offset + // copy from entry child offsets + // it should be noted that, + // buffer size unit is byte(8 bit), + // offset element unit is int32(32 bit) + auto child_offsets_buffer = reinterpret_cast(child_offsets); + int32_t child_offsets_buffer_offset = + static_cast(child_offsets_buffer->size()); + + // data buffer elelment is char(8 bit) + auto data_buffer = reinterpret_cast(data_ptr); + int32_t data_buffer_offset = static_cast(data_buffer->size()); + + // sets the size in the child offsets buffer + // offsets element is int32, we should resize buffer by extra offsets_len * 4 + auto status = child_offsets_buffer->Resize( + child_offsets_buffer_offset + entry_offsets_len * 4, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new child offsets entry to child offsets buffer + // offsets buffer last offset number indicating data length + // we should take this extra offset into consider + // so the initialize child_offsets_buffer length is 1(int32) + memcpy(child_offsets_buffer->mutable_data() + child_offsets_buffer_offset - 4, + (char*)entry_child_offsets, (entry_offsets_len + 1) * 4); + + // compute data length + int32_t data_length = + *(entry_child_offsets + entry_offsets_len) - *(entry_child_offsets); + + // sets the size in the child offsets buffer. + status = data_buffer->Resize(data_buffer_offset + data_length, false /*shrink*/); + if (!status.ok()) { + gandiva::ExecutionContext* context = + reinterpret_cast(context_ptr); + + context->set_error_msg(status.message().c_str()); + return -1; + } + + // append the new child offsets entry to child offsets buffer + memcpy(data_buffer->mutable_data() + data_buffer_offset, entry_buf, data_length); + + // update offsets buffer. + offsets[slot] = child_offsets_buffer_offset / 4 - 1; + offsets[slot + 1] = child_offsets_buffer_offset / 4 - 1 + entry_offsets_len; + return 0; +} + #define CRC_FUNCTION(TYPE) \ GANDIVA_EXPORT \ int64_t gdv_fn_crc_32_##TYPE(int64_t ctx, const char* input, int32_t input_len) { \ @@ -838,6 +926,8 @@ const char* gdv_mask_show_last_n_utf8_int32(int64_t context, const char* data, int32_t n_to_mask = num_of_chars - n_to_show; return gdv_mask_first_n_utf8_int32(context, data, data_len, n_to_mask, out_len); } + +#undef POPULATE_NUMERIC_LIST_TYPE_VECTOR } namespace gandiva { @@ -1174,6 +1264,34 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalyear_utf8)); + engine->AddGlobalMappingForFunc("gdv_fn_in_expr_lookup_utf8", + types->i1_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_in_expr_lookup_utf8)); + +#define ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(LLVM_TYPE, DATA_TYPE) \ + args = {types->i64_type(), types->i8_ptr_type(), types->i32_ptr_type(), \ + types->i64_type(), types->LLVM_TYPE##_ptr_type(), types->i32_type(), types->i32_ptr_type()}; \ + engine->AddGlobalMappingForFunc( \ + "gdv_fn_populate_list_" #DATA_TYPE "_vector", types->i32_type() /*return_type*/, \ + args, reinterpret_cast(gdv_fn_populate_list_##DATA_TYPE##_vector)); + + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i32, int32_t) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(i64, int64_t) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(float, float) + ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION(double, double) + + // gdv_fn_populate_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_type()}; // int32_t entry__len + + engine->AddGlobalMappingForFunc("gdv_fn_populate_varlen_vector", + types->i32_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_populate_varlen_vector)); + // gdv_fn_cast_intervalyear_utf8_int32 args = { types->i64_type(), // context @@ -1190,6 +1308,26 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { "gdv_fn_cast_intervalyear_utf8_int32", types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalyear_utf8_int32)); + // gdv_fn_populate_list_varlen_vector + args = {types->i64_type(), // int64_t execution_context + types->i8_ptr_type(), // int8_t* data ptr + types->i32_ptr_type(), // int32_t* offsets ptr + types->i32_ptr_type(), // int32_t* child offsets ptr + types->i64_type(), // int64_t slot + types->i8_ptr_type(), // const char* entry_buf + types->i32_ptr_type(), // int32_t* entry child offsets ptr + types->i32_type(), // int32_t entry child offsets length + types->i32_ptr_type() // int32_t* entry child valid ptr + }; + + engine->AddGlobalMappingForFunc( + "gdv_fn_populate_list_varlen_vector", types->i32_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_populate_list_varlen_vector)); + + // gdv_fn_random + args = {types->i64_type()}; + engine->AddGlobalMappingForFunc("gdv_fn_random", types->double_type(), args, + reinterpret_cast(gdv_fn_random)); // to_utc_timezone_timestamp args = { types->i64_type(), // context @@ -1289,4 +1427,6 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("mask_utf8", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(mask_utf8)); } + +#undef ADD_MAPPING_FOR_NUMERIC_LIST_TYPE_POPULATE_FUNCTION } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 1615eece1f2c7..5e676d70251fa 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -30,7 +30,6 @@ #include "gandiva/lvalue.h" namespace gandiva { - #define ADD_TRACE(...) \ if (enable_ir_traces_) { \ AddTrace(__VA_ARGS__); \ @@ -94,7 +93,7 @@ Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode // Compile and inject into the process' memory the generated function. ARROW_RETURN_NOT_OK(engine_->FinalizeModule()); - + // setup the jit functions for each expression. for (auto& compiled_expr : compiled_exprs_) { auto fn_name = compiled_expr->GetFunctionName(mode); @@ -210,6 +209,14 @@ llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx, return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray"); } +/// Get reference to child offsets array at specified index in the args list. +llvm::Value* LLVMGenerator::GetChildOffsetsReference(llvm::Value* arg_addrs, int idx, + FieldPtr field) { + const std::string& name = field->name(); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, types()->i64_type(), idx, name); + return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_coarray"); +} + /// Get reference to local bitmap array at specified index in the args list. llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) { llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, types()->i64_type(), idx, ""); @@ -350,6 +357,10 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, slice_offsets.push_back(offset); } + llvm::AllocaInst* validity_index_var = + new llvm::AllocaInst(types()->i64_type(), 0, "validity_index_var", loop_entry); + builder->CreateStore(types()->i64_constant(0), validity_index_var); + // Loop body builder->SetInsertPoint(loop_body); @@ -368,7 +379,7 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, // The visitor can add code to both the entry/loop blocks. Visitor visitor(this, fn, loop_entry, arg_addrs, arg_local_bitmaps, arg_holder_ptrs, - slice_offsets, arg_context_ptr, position_var); + slice_offsets, arg_context_ptr, position_var, validity_index_var); value_expr->Accept(visitor); LValuePtr output_value = visitor.result(); @@ -397,12 +408,46 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, AddFunctionCall("gdv_fn_populate_varlen_vector", types()->i32_type(), {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, loop_var, output_value->data(), output_value->length()}); + } else if (output_type_id == arrow::Type::STRUCT) { + auto slot_offset = builder->CreateGEP(types()->IRType(output_type_id), output_ref, loop_var); + builder->CreateStore(output_value->data(), slot_offset); + } else if (output_type_id == arrow::Type::LIST) { + auto output_list_internal_type = output->Type()->field(0)->type()->id(); + + if (arrow::is_binary_like(output_list_internal_type)) { + auto output_list_value = std::dynamic_pointer_cast(output_value); + llvm::Value* child_output_offset_ref = GetChildOffsetsReference( + arg_addrs, output->child_data_offsets_idx(), output->field()); + AddFunctionCall( + "gdv_fn_populate_list_varlen_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + child_output_offset_ref, loop_var, output_list_value->data(), + output_list_value->child_offsets(), output_list_value->offsets_length()}); + } else if (output_list_internal_type == arrow::Type::INT32) { + AddFunctionCall("gdv_fn_populate_list_int32_t_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else if (output_list_internal_type == arrow::Type::INT64) { + AddFunctionCall("gdv_fn_populate_list_int64_t_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else if (output_list_internal_type == arrow::Type::FLOAT) { + AddFunctionCall("gdv_fn_populate_list_float_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else if (output_list_internal_type == arrow::Type::DOUBLE) { + AddFunctionCall("gdv_fn_populate_list_double_vector", types()->i32_type(), + {arg_context_ptr, output_buffer_ptr_ref, output_offset_ref, + loop_var, output_value->data(), output_value->length(), output_value->validity()}); + } else { + return Status::NotImplemented("list internal type ", + output->Type()->field(0)->type()->ToString(), + " not supported"); + } } else { return Status::NotImplemented("output type ", output->Type()->ToString(), " not supported"); } - ADD_TRACE("saving result " + output->Name() + " value %T", output_value->data()); - if (visitor.has_arena_allocs()) { // Reset allocations to avoid excessive memory usage. Once the result is copied to // the output vector (store instruction above), any memory allocations in this @@ -496,10 +541,10 @@ void LLVMGenerator::ComputeBitMapsForExpr(const CompiledExpr& compiled_expr, /// /// 1. Do the intersection of input/local bitmaps to generate a temporary bitmap. /// 2. copy just the relevant bits from the temporary bitmap to the output bitmap. + LocalBitMapsHolder bit_map_holder(eval_batch->num_records(), 1); uint8_t* temp_bitmap = bit_map_holder.GetLocalBitMap(0); accumulator.ComputeResult(temp_bitmap); - auto num_out_records = selection_vector->GetNumSlots(); // the memset isn't required, doing it just for valgrind. memset(dst_bitmap, 0, arrow::bit_util::BytesForBits(num_out_records)); @@ -530,6 +575,7 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name, value = ir_builder()->CreateCall(fn, args); } else { value = ir_builder()->CreateCall(fn, args, full_name); + DCHECK(value->getType() == ret_type); } @@ -558,7 +604,8 @@ LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* functi llvm::Value* arg_local_bitmaps, llvm::Value* arg_holder_ptrs, std::vector slice_offsets, - llvm::Value* arg_context_ptr, llvm::Value* loop_var) + llvm::Value* arg_context_ptr, llvm::Value* loop_var, + llvm::Value* validity_index_var) : generator_(generator), function_(function), entry_block_(entry_block), @@ -568,11 +615,13 @@ LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* functi slice_offsets_(slice_offsets), arg_context_ptr_(arg_context_ptr), loop_var_(loop_var), + validity_index_var_(validity_index_var), has_arena_allocs_(false) { ADD_VISITOR_TRACE("Iteration %T", loop_var); } void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { + ADD_VISITOR_TRACE("VectorReadFixedLenValueDex"); llvm::IRBuilder<>* builder = ir_builder(); auto types = generator_->types(); llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); @@ -580,6 +629,7 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { llvm::Value* slot_value; std::shared_ptr lvalue; + ADD_VISITOR_TRACE("VectorReadFixedLenValueDex"); switch (dex.FieldType()->id()) { case arrow::Type::BOOL: slot_value = generator_->GetPackedBitValue(slot_ref, slot_index); @@ -606,11 +656,77 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { result_ = lvalue; } -void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { +void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueListDex& dex) { + ADD_VISITOR_TRACE("VectorReadFixedLenValueListDex"); llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot; auto types = generator_->types(); + auto type = types->IRType(dex.FieldType()->id()); + + auto dt = dex.FieldType(); + if (dt->id() == arrow::Type::LIST) { + type = types->IRType(dt->fields()[0]->type()->id() ); + } + + arrow::Type::type at32 = arrow::Type::INT32; + auto type32 = types->IRType(at32); + + // compute list len from the offsets array. + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + slot = builder->CreateGEP(type32, offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = builder->CreateLoad(type32, slot, "offset_start"); + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = builder->CreateGEP(type32, offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = builder->CreateLoad(type32, slot, "offset_end"); + + // => offsets_len_value = offset_end - offset_start + llvm::Value* list_len = builder->CreateSub(offset_end, offset_start, "offsets_len"); + + // get data array + llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + // do not forget slice offset + llvm::Value* offset_start_int64 = + builder->CreateIntCast(offset_start, generator_->types()->i64_type(), true); + llvm::Value* slot_index = + builder->CreateAdd(offset_start_int64, GetSliceOffset(dex.DataIdx())); + llvm::Value* data_list = builder->CreateGEP(type, slot_ref, slot_index); + + auto list_len_var = builder->CreateIntCast(list_len, types->i64_type(), true); + llvm::Value* vv_end = builder->CreateLoad(generator_->types()->i64_type(),validity_index_var_, "vv_end"); + +llvm::Value* updated_validity_index_var = builder->CreateAdd( + vv_end, list_len_var, "validity_index_var+offset"); + + builder->CreateStore(updated_validity_index_var, validity_index_var_); + llvm::Value* b_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.ValidityIdx())); + llvm::Value* b_slot_ref = GetBufferReference(dex.ChildValidityIdx(), kBufferTypeValidity, dex.Field()); + llvm::Value* validity = builder->CreateGEP(type32, b_slot_ref, b_slot_index); + + std::string str3 = "validity:"; + if (validity) { + llvm::raw_string_ostream output3(str3); + validity->print(output3); + } + ADD_VISITOR_TRACE("visit fixed-len data list vector " + dex.FieldName() + " length %T", + list_len); + ADD_VISITOR_TRACE("visit fixed-len data list vector " + dex.FieldName() + " updated_validity_index_var %T", + updated_validity_index_var); + + result_.reset(new LValue(data_list, list_len, validity)); +} + +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot; + auto types = generator_->types(); + ADD_VISITOR_TRACE("VectorReadVarLenValueDex"); // compute len from the offsets array. llvm::Value* offsets_slot_ref = GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); @@ -641,7 +757,73 @@ void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { result_.reset(new LValue(data_value, len_value)); } +/* + * create list type field context for each loop + */ +void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueListDex& dex) { + /* Example + * list_data: [["var_len_val11"], ["var_len_val211", "var_len_val22"], + * ["var_len_val3331"]] loop_var: 0, 1, 2 data_buffer: + * var_len_val11var_len_val211var_len_val22var_len_val3331 offsets_buffer: 0, 1, 3, 4 + * list_element_len = offsets[loop_var+1]-offsets[loop_var] => 1, 2, 1 + * child_offsets_buffer: 0, 13, 27, 40, 55 + * for i in list_element_len: + * data_buffer[child_offsets_buffer[offsets[i+1]] - child_offsets_buffer[offsets[i]]] + * => list_data[loop_var][i] + */ + ADD_VISITOR_TRACE("VectorReadVarLenValueListDex"); + llvm::IRBuilder<>* builder = ir_builder(); + llvm::Value* slot; + auto types = generator_->types(); + auto type = types->IRType(dex.FieldType()->id()); + + arrow::Type::type at = arrow::Type::INT32; + type = types->IRType(at); + + // compute list length from the offsets array + llvm::Value* offsets_slot_ref = + GetBufferReference(dex.OffsetsIdx(), kBufferTypeOffsets, dex.Field()); + llvm::Value* offsets_slot_index = + builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); + + // => offset_start = offsets[loop_var] + slot = builder->CreateGEP(type, offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = builder->CreateLoad(type, slot, "offset_start"); + + // => offset_end = offsets[loop_var + 1] + llvm::Value* offsets_slot_index_next = builder->CreateAdd( + offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); + slot = builder->CreateGEP(type, offsets_slot_ref, offsets_slot_index_next); + llvm::Value* offset_end = builder->CreateLoad(type, slot, "offset_end"); + + // => list_data_length = offset_end - offset_start + llvm::Value* list_data_length = + builder->CreateSub(offset_end, offset_start, "offsets_len"); + + // get the child offsets array from the child offsets array, + // start from offset 'offset_start' + llvm::Value* child_offset_slot_ref = + GetBufferReference(dex.ChildOffsetsIdx(), kBufferTypeChildOffsets, dex.Field()); + // do not forget slice offset + llvm::Value* offset_start_int64 = + builder->CreateIntCast(offset_start, generator_->types()->i64_type(), true); + llvm::Value* child_offset_slot_index = + builder->CreateAdd(offset_start_int64, GetSliceOffset(dex.ChildOffsetsIdx())); + llvm::Value* child_offsets = + builder->CreateGEP(type, child_offset_slot_ref, child_offset_slot_index); + llvm::Value* child_offset_start = + builder->CreateLoad(type, child_offsets, "child_offset_start"); + + // get the data array + llvm::Value* data_slot_ref = + GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); + llvm::Value* data_value = builder->CreateGEP(type, data_slot_ref, child_offset_start); + + result_.reset(new ListLValue(data_value, child_offsets, list_data_length)); +} + void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) { + ADD_VISITOR_TRACE("VectorReadValidityDex"); llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot_ref = GetBufferReference(dex.ValidityIdx(), kBufferTypeValidity, dex.Field()); @@ -654,6 +836,7 @@ void LLVMGenerator::Visitor::Visit(const VectorReadValidityDex& dex) { } void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) { + ADD_VISITOR_TRACE("LocalBitMapValidityDex"); llvm::Value* slot_ref = GetLocalBitMapReference(dex.local_bitmap_idx()); llvm::Value* validity = generator_->GetPackedBitValue(slot_ref, loop_var_); @@ -664,14 +847,17 @@ void LLVMGenerator::Visitor::Visit(const LocalBitMapValidityDex& dex) { } void LLVMGenerator::Visitor::Visit(const TrueDex& dex) { + ADD_VISITOR_TRACE("TrueDex"); result_.reset(new LValue(generator_->types()->true_constant())); } void LLVMGenerator::Visitor::Visit(const FalseDex& dex) { + ADD_VISITOR_TRACE("FalseDex"); result_.reset(new LValue(generator_->types()->false_constant())); } void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { + ADD_VISITOR_TRACE("LiteralDex"); LLVMTypes* types = generator_->types(); llvm::Value* value = nullptr; llvm::Value* len = nullptr; @@ -716,7 +902,6 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { case arrow::Type::STRING: case arrow::Type::BINARY: { const std::string& str = std::get(dex.holder()); - value = ir_builder()->CreateGlobalStringPtr(str.c_str()); len = types->i32_constant(static_cast(str.length())); break; @@ -777,8 +962,7 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); LLVMTypes* types = generator_->types(); auto arrow_type_id = arrow_return_type->id(); - auto result_type = types->IRType(arrow_type_id); - + auto result_type = types->DataVecType(arrow_return_type); // Build combined validity of the args. llvm::Value* is_valid = types->true_constant(); for (auto& pair : dex.args()) { @@ -836,18 +1020,34 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) { auto params = BuildParams(dex.get_holder_idx(), dex.args(), true, native_function->NeedsContext()); + + + auto arrow_return_type = dex.func_descriptor()->return_type(); + + bool passLoopVars = false; + for (auto& p : dex.func_descriptor()->params()) { + if (p->id() == arrow::Type::LIST) { + passLoopVars = true; + break; + } + } + if (passLoopVars) + { + params.push_back(loop_var_); + auto valid_var = builder->CreateLoad(types->i64_type(), validity_index_var_, "loaded_var"); + params.push_back(valid_var); + } + // add an extra arg for validity (allocated on stack). llvm::AllocaInst* result_valid_ptr = new llvm::AllocaInst(types->i8_type(), 0, "result_valid", entry_block_); params.push_back(result_valid_ptr); - auto arrow_return_type = dex.func_descriptor()->return_type(); result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms); // load the result validity and truncate to i1. auto result_valid_i8 = builder->CreateLoad(types->i8_type(), result_valid_ptr); llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type()); - // set validity bit in the local bitmap. ClearLocalBitMapIfNotValid(dex.local_bitmap_idx(), result_valid); } @@ -1125,25 +1325,31 @@ void LLVMGenerator::Visitor::VisitInExpression( } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } void LLVMGenerator::Visitor::Visit(const InExprDexBase& dex) { + ADD_VISITOR_TRACE("InExprDexBase&"); VisitInExpression(dex); } @@ -1151,6 +1357,7 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, std::function then_func, std::function else_func, DataTypePtr result_type) { + ADD_VISITOR_TRACE("BuildIfElse"); llvm::IRBuilder<>* builder = ir_builder(); llvm::LLVMContext* context = generator_->context(); LLVMTypes* types = generator_->types(); @@ -1180,7 +1387,7 @@ LValuePtr LLVMGenerator::Visitor::BuildIfElse(llvm::Value* condition, // Emit the merge block. builder->SetInsertPoint(merge_bb); - auto llvm_type = types->IRType(result_type->id()); + auto llvm_type = types->DataVecType(result_type); llvm::PHINode* result_value = builder->CreatePHI(llvm_type, 2, "res_value"); result_value->addIncoming(then_lvalue->data(), then_bb); result_value->addIncoming(else_lvalue->data(), else_bb); @@ -1226,7 +1433,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, std::vector* params) { auto types = generator_->types(); auto arrow_return_type_id = arrow_return_type->id(); - auto llvm_return_type = types->IRType(arrow_return_type_id); + auto llvm_return_type = types->DataVecType(arrow_return_type); DecimalIR decimalIR(generator_->engine_.get()); if (arrow_return_type_id == arrow::Type::DECIMAL) { @@ -1255,6 +1462,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, } // add extra arg for return length for variable len return types (allocated on stack). llvm::AllocaInst* result_len_ptr = nullptr; + llvm::AllocaInst* valid_ptr = nullptr; if (arrow::is_binary_like(arrow_return_type_id)) { result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0, "result_len", entry_block_); @@ -1262,6 +1470,17 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, has_arena_allocs_ = true; } + if (arrow_return_type_id == arrow::Type::LIST) { + + result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0, + "result_len", entry_block_); + params->push_back(result_len_ptr); + has_arena_allocs_ = true; + valid_ptr = new llvm::AllocaInst(generator_->types()->i32_ptr_type(), 0, + "valid_ptr", entry_block_); + params->push_back(valid_ptr); + } + // Make the function call llvm::IRBuilder<>* builder = ir_builder(); auto value = @@ -1272,7 +1491,11 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, (result_len_ptr == nullptr) ? nullptr : builder->CreateLoad(result_len_ptr->getAllocatedType(), result_len_ptr); - return std::make_shared(value, value_len); + auto validity = + (valid_ptr == nullptr) + ? nullptr + : builder->CreateLoad(generator_->types()->i32_ptr_type(), valid_ptr); + return std::make_shared(value, value_len, validity); } } @@ -1281,6 +1504,7 @@ std::vector LLVMGenerator::Visitor::BuildParams( bool with_context) { std::vector params; + ADD_VISITOR_TRACE("LLVMGenerator::Visitor::BuildParams"); // add context if required. if (with_context) { params.push_back(arg_context_ptr_); @@ -1311,6 +1535,7 @@ std::vector LLVMGenerator::Visitor::BuildParams( // append all the parameters corresponding to this LValue. result_ref.AppendFunctionParams(¶ms); + // build validity. if (with_validity) { llvm::Value* validity_expr = BuildCombinedValidity(pair->validity_exprs()); @@ -1356,6 +1581,10 @@ llvm::Value* LLVMGenerator::Visitor::GetBufferReference(int idx, BufferType buff case kBufferTypeOffsets: slot_ref = generator_->GetOffsetsReference(arg_addrs_, idx, field); break; + + case kBufferTypeChildOffsets: + slot_ref = generator_->GetChildOffsetsReference(arg_addrs_, idx, field); + break; } // Revert to the saved block. @@ -1384,6 +1613,7 @@ llvm::Value* LLVMGenerator::Visitor::GetLocalBitMapReference(int idx) { /// The local bitmap is pre-filled with 1s. Clear only if invalid. void LLVMGenerator::Visitor::ClearLocalBitMapIfNotValid(int local_bitmap_idx, llvm::Value* is_valid) { + ADD_VISITOR_TRACE("ClearLocalBitMapIfNotValid"); llvm::Value* slot_ref = GetLocalBitMapReference(local_bitmap_idx); generator_->ClearPackedBitValueIfFalse(slot_ref, loop_var_, is_valid); } @@ -1454,5 +1684,4 @@ void LLVMGenerator::AddTrace(const std::string& msg, llvm::Value* value) { } AddFunctionCall(print_fn_name, types()->i32_type(), args); } - } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 04f9b854b1d29..e8c15bdf00744 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -98,11 +98,13 @@ class GANDIVA_EXPORT LLVMGenerator { llvm::BasicBlock* entry_block, llvm::Value* arg_addrs, llvm::Value* arg_local_bitmaps, llvm::Value* arg_holder_ptrs, std::vector slice_offsets, llvm::Value* arg_context_ptr, - llvm::Value* loop_var); + llvm::Value* loop_var, llvm::Value* validity_index); void Visit(const VectorReadValidityDex& dex) override; void Visit(const VectorReadFixedLenValueDex& dex) override; + void Visit(const VectorReadFixedLenValueListDex& dex) override; void Visit(const VectorReadVarLenValueDex& dex) override; + void Visit(const VectorReadVarLenValueListDex& dex) override; void Visit(const LocalBitMapValidityDex& dex) override; void Visit(const TrueDex& dex) override; void Visit(const FalseDex& dex) override; @@ -127,7 +129,12 @@ class GANDIVA_EXPORT LLVMGenerator { bool has_arena_allocs() { return has_arena_allocs_; } private: - enum BufferType { kBufferTypeValidity = 0, kBufferTypeData, kBufferTypeOffsets }; + enum BufferType { + kBufferTypeValidity = 0, + kBufferTypeData, + kBufferTypeOffsets, + kBufferTypeChildOffsets + }; llvm::IRBuilder<>* ir_builder() { return generator_->ir_builder(); } llvm::Module* module() { return generator_->module(); } @@ -175,6 +182,7 @@ class GANDIVA_EXPORT LLVMGenerator { std::vector slice_offsets_; llvm::Value* arg_context_ptr_; llvm::Value* loop_var_; + llvm::Value* validity_index_var_; bool has_arena_allocs_; }; @@ -195,6 +203,10 @@ class GANDIVA_EXPORT LLVMGenerator { /// Generate code to load the vector at specified index and cast it as offsets array. llvm::Value* GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as child offsets + /// array. + llvm::Value* GetChildOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field); + /// Generate code to load the vector at specified index and cast it as buffer pointer. llvm::Value* GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field); diff --git a/cpp/src/gandiva/llvm_types.cc b/cpp/src/gandiva/llvm_types.cc index de322a8c0fcb5..3eb49f39037f6 100644 --- a/cpp/src/gandiva/llvm_types.cc +++ b/cpp/src/gandiva/llvm_types.cc @@ -42,7 +42,8 @@ LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) { {arrow::Type::type::BINARY, i8_ptr_type()}, {arrow::Type::type::DECIMAL, i128_type()}, {arrow::Type::type::INTERVAL_MONTHS, i32_type()}, - {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}}; + {arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}, + {arrow::Type::type::LIST, list_type()}}; } } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_types.h b/cpp/src/gandiva/llvm_types.h index d6f0952713efc..58b7c3008695f 100644 --- a/cpp/src/gandiva/llvm_types.h +++ b/cpp/src/gandiva/llvm_types.h @@ -46,6 +46,8 @@ class GANDIVA_EXPORT LLVMTypes { llvm::Type* i128_type() { return llvm::Type::getInt128Ty(context_); } + llvm::VectorType* list_type() { return llvm::ScalableVectorType::get(i8_type(), (unsigned int)0); } + llvm::StructType* i128_split_type() { // struct with high/low bits (see decimal_ops.cc:DecimalSplit) return llvm::StructType::get(context_, {i64_type(), i64_type()}, false); @@ -57,6 +59,8 @@ class GANDIVA_EXPORT LLVMTypes { llvm::PointerType* ptr_type(llvm::Type* type) { return type->getPointerTo(); } + llvm::PointerType* i1_ptr_type() { return ptr_type(i1_type()); } + llvm::PointerType* i8_ptr_type() { return ptr_type(i8_type()); } llvm::PointerType* i32_ptr_type() { return ptr_type(i32_type()); } @@ -65,6 +69,10 @@ class GANDIVA_EXPORT LLVMTypes { llvm::PointerType* i128_ptr_type() { return ptr_type(i128_type()); } + llvm::PointerType* float_ptr_type() { return ptr_type(float_type()); } + + llvm::PointerType* double_ptr_type() { return ptr_type(double_type()); } + template llvm::Constant* int_constant(ctype val) { return llvm::ConstantInt::get(context_, llvm::APInt(N, val)); @@ -87,6 +95,10 @@ class GANDIVA_EXPORT LLVMTypes { return llvm::ConstantFP::get(float_type(), val); } + llvm::LLVMContext* get_context() { + return &context_; + } + llvm::Constant* double_constant(double val) { return llvm::ConstantFP::get(double_type(), val); } @@ -104,6 +116,17 @@ class GANDIVA_EXPORT LLVMTypes { /// For a given data type, find the ir type used for the data vector slot. llvm::Type* DataVecType(const DataTypePtr& data_type) { + // support list type + // list type data is formed by base type buffer, wrapped with offsets buffer + // offsets buffer is to separate data into list + // not support nested list + if (data_type->id() == arrow::Type::LIST) { + //Nested lists aren't supported yet. + if (data_type->field(0)->type()->id() == arrow::Type::LIST) { + return NULL; + } + return IRType(data_type->field(0)->type()->id()); + } return IRType(data_type->id()); } diff --git a/cpp/src/gandiva/llvm_types_test.cc b/cpp/src/gandiva/llvm_types_test.cc index 6669683061825..665a82d133fad 100644 --- a/cpp/src/gandiva/llvm_types_test.cc +++ b/cpp/src/gandiva/llvm_types_test.cc @@ -50,12 +50,22 @@ TEST_F(TestLLVMTypes, TestFound) { types_->i64_type()); EXPECT_EQ(types_->DataVecType(arrow::timestamp(arrow::TimeUnit::MILLI)), types_->i64_type()); + + EXPECT_EQ(types_->IRType(arrow::Type::STRING), types_->i8_ptr_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::boolean())), types_->i1_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::int32())), types_->i32_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::int64())), types_->i64_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::float32())), types_->float_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::float64())), types_->double_type()); + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::utf8())), types_->i8_ptr_type()); } TEST_F(TestLLVMTypes, TestNotFound) { EXPECT_EQ(types_->IRType(arrow::Type::SPARSE_UNION), nullptr); EXPECT_EQ(types_->IRType(arrow::Type::DENSE_UNION), nullptr); EXPECT_EQ(types_->DataVecType(arrow::null()), nullptr); + // not support nested list type + EXPECT_EQ(types_->DataVecType(arrow::list(arrow::list(arrow::utf8()))), nullptr); } } // namespace gandiva diff --git a/cpp/src/gandiva/lvalue.h b/cpp/src/gandiva/lvalue.h index df292855b69af..04862dc9d18c8 100644 --- a/cpp/src/gandiva/lvalue.h +++ b/cpp/src/gandiva/lvalue.h @@ -46,9 +46,36 @@ class GANDIVA_EXPORT LValue { if (length_ != NULLPTR) { params->push_back(length_); } + if (validity_ != NULLPTR) { + params->push_back(validity_); + } } - private: + virtual std::string to_string() { + std::string s = "Base LValue"; + + std::string str1 = "data:"; + if (data_) { + llvm::raw_string_ostream output1(str1); + data_->print(output1); + } + + std::string str2 = "length:"; + if (length_) { + llvm::raw_string_ostream output2(str2); + length_->print(output2); + } + + std::string str3 = "validity:"; + if (validity_) { + llvm::raw_string_ostream output3(str3); + validity_->print(output3); + } + + return s + "\n" + str1 + "\n" + str2 + "\n" + str3; + } + + protected: llvm::Value* data_; llvm::Value* length_; llvm::Value* validity_; @@ -74,4 +101,48 @@ class GANDIVA_EXPORT DecimalLValue : public LValue { llvm::Value* scale_; }; +class GANDIVA_EXPORT ListLValue : public LValue { + public: + ListLValue(llvm::Value* data, llvm::Value* child_offsets, llvm::Value* offsets_length, + llvm::Value* validity = NULLPTR) + : LValue(data, NULLPTR, validity), + child_offsets_(child_offsets), + offsets_length_(offsets_length) { + } + + llvm::Value* child_offsets() { return child_offsets_; } + + llvm::Value* offsets_length() { return offsets_length_; } + + void AppendFunctionParams(std::vector* params) override { + LValue::AppendFunctionParams(params); + params->push_back(child_offsets_); + params->push_back(offsets_length_); + params->push_back(validity_); + } + + virtual std::string to_string() override { + std::string s = "List LValue"; + s += " " + LValue::to_string(); + + std::string str1 = "child_offsets_:"; + if (child_offsets_) { + llvm::raw_string_ostream output1(str1); + child_offsets_->print(output1); + } + + std::string str2 = "offsets_length_:"; + if (offsets_length_) { + llvm::raw_string_ostream output2(str2); + offsets_length_->print(output2); + } + + return s + "\n" + str1 + "\n" + str2; + } + + private: + llvm::Value* child_offsets_; + llvm::Value* offsets_length_; +}; + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 83bbdee208562..117b27b2808dd 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -19,6 +19,8 @@ #include + +#include "gandiva/array_ops.h" #include "gandiva/gdv_function_stubs.h" // Use the same names as in arrow data types. Makes it easy to write pre-processor macros. diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc index 54de03963f7e7..0181ece3b2607 100644 --- a/cpp/src/gandiva/projector.cc +++ b/cpp/src/gandiva/projector.cc @@ -184,7 +184,10 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, ValidateArrayDataCapacity(*array_data, *(output_fields_[idx]), num_rows)); ++idx; } - return llvm_generator_->Execute(batch, selection_vector, output_data_vecs); + ARROW_RETURN_NOT_OK( + llvm_generator_->Execute(batch, selection_vector, output_data_vecs)); + + return Status::OK(); } Status Projector::Evaluate(const arrow::RecordBatch& batch, arrow::MemoryPool* pool, @@ -215,14 +218,45 @@ Status Projector::Evaluate(const arrow::RecordBatch& batch, llvm_generator_->Execute(batch, selection_vector, output_data_vecs)); // Create and return array arrays. + int const child_data_buffer_index = 1; + int const int_data_size = 4; + int const double_data_size = 8; output->clear(); for (auto& array_data : output_data_vecs) { + if (array_data->type->id() == arrow::Type::LIST) { + auto child_data = array_data->child_data[0]; + int64_t child_data_size = 1; + if (arrow::is_binary_like(child_data->type->id())) { + /* when allocate array data, child data length is an initialized value, + * after calculating, child data offsets buffer has been resized for results, + * but array data length is unchanged. + * We should recalculate child data length and make ArrayData with new length + * + * Otherwise, child data offsets buffer length is data length + 1 + * and offset data is int32_t, need use buffer->size()/4 - 1 + */ + child_data_size = child_data->buffers[child_data_buffer_index]->size() / int_data_size - 1; + } else if (child_data->type->id() == arrow::Type::INT32) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / int_data_size; + } else if (child_data->type->id() == arrow::Type::INT64) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / double_data_size; + } else if (child_data->type->id() == arrow::Type::FLOAT) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / int_data_size; + } else if (child_data->type->id() == arrow::Type::DOUBLE) { + child_data_size = child_data->buffers[child_data_buffer_index]->size() / double_data_size; + } + auto new_child_data = arrow::ArrayData::Make( + child_data->type, child_data_size, child_data->buffers, child_data->offset); + array_data = arrow::ArrayData::Make(array_data->type, array_data->length, + array_data->buffers, {new_child_data}, + array_data->null_count, array_data->offset); + } + output->push_back(arrow::MakeArray(array_data)); } return Status::OK(); } -// TODO : handle complex vectors (list/map/..) Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, arrow::MemoryPool* pool, ArrayDataPtr* array_data) const { @@ -243,6 +277,23 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, buffers.push_back(std::move(offsets_buffer)); } + if (type_id == arrow::Type::LIST) { + auto offsets_len = arrow::bit_util::BytesForBits((num_records + 1) * 32); + + ARROW_ASSIGN_OR_RAISE(auto offsets_buffer, arrow::AllocateBuffer(offsets_len, pool)); + buffers.push_back(std::move(offsets_buffer)); + + if (arrow::is_binary_like(type->field(0)->type()->id())) { + // child offsets length is internal data length + 1 + // offsets element is int32 + // so here i just allocate extra 32 bit for extra 1 length + ARROW_ASSIGN_OR_RAISE( + auto child_offsets_buffer, + arrow::AllocateResizableBuffer(arrow::bit_util::BytesForBits(32), pool)); + buffers.push_back(std::move(child_offsets_buffer)); + } + } + // The output vector always has a data array. int64_t data_len; if (arrow::is_primitive(type_id) || type_id == arrow::Type::DECIMAL) { @@ -251,6 +302,8 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } else if (arrow::is_binary_like(type_id)) { // we don't know the expected size for varlen output vectors. data_len = 0; + } else if (type_id == arrow::Type::LIST) { + data_len = 0; } else { return Status::Invalid("Unsupported output data type " + type->ToString()); } @@ -263,7 +316,27 @@ Status Projector::AllocArrayData(const DataTypePtr& type, int64_t num_records, } buffers.push_back(std::move(data_buffer)); - *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + ARROW_ASSIGN_OR_RAISE(auto data_valid_buffer, arrow::AllocateResizableBuffer(data_len, pool)); + + if (type->id() == arrow::Type::LIST) { + auto internal_type = type->field(0)->type(); + ArrayDataPtr child_data; + if (arrow::is_primitive(internal_type->id())) { + child_data = arrow::ArrayData::Make(internal_type, 0 /*initialize length*/, + {std::move(data_valid_buffer), std::move(buffers[2])}, 0); + } + if (arrow::is_binary_like(internal_type->id())) { + child_data = arrow::ArrayData::Make( + internal_type, 0 /*initialize length*/, + {nullptr, std::move(buffers[2]), std::move(buffers[3])}, 0); + } + *array_data = arrow::ArrayData::Make( + type, num_records, {std::move(buffers[0]), std::move(buffers[1])}, {child_data}); + + } else { + *array_data = arrow::ArrayData::Make(type, num_records, std::move(buffers)); + } + return Status::OK(); } @@ -312,7 +385,10 @@ Status Projector::ValidateArrayDataCapacity(const arrow::ArrayData& array_data, int64_t data_len = array_data.buffers[1]->capacity(); ARROW_RETURN_IF(data_len < min_data_len, Status::Invalid("Data buffer too small for ", field.name())); - } else { + } else if (type_id == arrow::Type::LIST) { + return Status::OK(); + } + else { return Status::Invalid("Unsupported output data type " + field.type()->ToString()); } @@ -339,4 +415,5 @@ std::shared_ptr Projector::GetSecondaryCacheKey(std::string prima return arrow::Buffer::FromString(key); } + } // namespace gandiva diff --git a/cpp/src/gandiva/projector.h b/cpp/src/gandiva/projector.h index 24ec11e3eab59..53d0ef6d62431 100644 --- a/cpp/src/gandiva/projector.h +++ b/cpp/src/gandiva/projector.h @@ -154,14 +154,14 @@ class GANDIVA_EXPORT Projector { bool GetBuiltFromCache(); void Clear(); + /// Allocate an ArrowData of length 'length'. + Status AllocArrayData(const DataTypePtr& type, int64_t num_records, + arrow::MemoryPool* pool, ArrayDataPtr* array_data) const; private: Projector(std::unique_ptr llvm_generator, SchemaPtr schema, const FieldVector& output_fields, std::shared_ptr); - /// Allocate an ArrowData of length 'length'. - Status AllocArrayData(const DataTypePtr& type, int64_t num_records, - arrow::MemoryPool* pool, ArrayDataPtr* array_data) const; /// Validate that the ArrayData has sufficient capacity to accommodate 'num_records'. Status ValidateArrayDataCapacity(const arrow::ArrayData& array_data, diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt index b89c0ac225209..bc607702126af 100644 --- a/cpp/src/gandiva/tests/CMakeLists.txt +++ b/cpp/src/gandiva/tests/CMakeLists.txt @@ -25,6 +25,7 @@ add_gandiva_test(binary_test) add_gandiva_test(date_time_test) add_gandiva_test(to_string_test) add_gandiva_test(utf8_test) +add_gandiva_test(list_test) add_gandiva_test(hash_test) add_gandiva_test(in_expr_test) add_gandiva_test(null_validity_test) diff --git a/cpp/src/gandiva/tests/list_test.cc b/cpp/src/gandiva/tests/list_test.cc new file mode 100644 index 0000000000000..abc7b5d7091b8 --- /dev/null +++ b/cpp/src/gandiva/tests/list_test.cc @@ -0,0 +1,397 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include + +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "gandiva/execution_context.h" +#include "gandiva/precompiled/types.h" +#include "gandiva/projector.h" +#include "gandiva/tests/test_util.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::boolean; +using arrow::float32; +using arrow::float64; +using arrow::int32; +using arrow::int64; +using arrow::utf8; +using std::string; +using std::vector; + +class TestList : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +template +void _build_list_array(const vector& values, const vector& length, + const vector& validity, arrow::MemoryPool* pool, + ArrayPtr* array, const vector& innerValidity = {}) { + size_t sum = 0; + for (auto& len : length) { + sum += len; + } + EXPECT_TRUE(values.size() == sum); + EXPECT_TRUE(length.size() == validity.size()); + + auto value_builder = std::make_shared(pool); + auto builder = std::make_shared(pool, value_builder); + int i = 0; + for (size_t l = 0; l < length.size(); l++) { + if (validity[l]) { + auto status = builder->Append(); + for (int j = 0; j < length[l]; j++) { + if (innerValidity.size() > (size_t)j && innerValidity[j] == false) { + auto v = value_builder->AppendNull(); + } else { + ASSERT_OK(value_builder->Append(values[i])); + } + i++; + } + } else { + ASSERT_OK(builder->AppendNull()); + for (int j = 0; j < length[l]; j++) { + i++; + } + } + } + ASSERT_OK(builder->Finish(array)); +} + +template +void _build_list_array2(const vector& values, const vector& length, + const vector& validity, const vector& innerValidity, arrow::MemoryPool* pool, + ArrayPtr* array) { + return _build_list_array(values, length, validity, pool, array); + } + +/* + * expression: + * input: a + * output: res + * typeof(a) can be list / list / list + */ +void _test_list_type_field_alias(DataTypePtr type, ArrayPtr array, + arrow::MemoryPool* pool, int num_records = 5) { + auto field_a = field("a", type); + auto schema = arrow::schema({field_a}); + auto result = field("res", type); + + std::cout << array->ToString() << std::endl; + assert(array->length() == num_records); + + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + + // Make expression + std::cout << "Make expression" << std::endl; + auto field_a_node = TreeExprBuilder::MakeField(field_a); + auto expr = TreeExprBuilder::MakeExpression(field_a_node, result); + + std::cout << "Build a projector for the expressions." << std::endl; + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + std::cout << "status message: " << status.message() << std::endl; + EXPECT_TRUE(status.ok()) << status.message(); + + std::cout << "Evaluate expression" << std::endl; + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + std::cout << "Check results" << std::endl; + EXPECT_ARROW_ARRAY_EQUALS(array, outputs[0]); + // EXPECT_ARROW_ARRAY_EQUALS will not check the length of child data, but + // ArrayData::Slice method will check length. ArrayData::ToString method will call + // ArrayData::Slice method + EXPECT_TRUE(array->ToString() == outputs[0]->ToString()); + EXPECT_TRUE(array->null_count() == outputs[0]->null_count()); +} + +/* +TEST_F(TestList, TestArrayRemove) { + // schema for input fields + auto field_b = field("b", int32()); + + auto field_a = field("a", list(int32())); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", list(int32())); + + // Create a row-batch with some sample data + int num_records = 2; + auto array_b = + MakeArrowArrayInt32({42, 42}, {true, true}); + + ArrayPtr array_a; + _build_list_array2( + {10, 42, 30, 42, 70, 80}, + {3, 3}, {true, true}, {true, true, true, true, true, true}, pool_, &array_a); + + // expected output + ArrayPtr exp1; + _build_list_array2( + {10, 30, 70, 80}, + {2, 2}, {true, true}, {true, true, true, true}, pool_, &exp1); + + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + auto expr = TreeExprBuilder::MakeExpression("array_remove", {field_a, field_b}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + std::cout << "LR Test 2 " << std::endl; + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp1, outputs.at(0)); + + //Try the second method. + arrow::ArrayDataVector outputs2; + std::shared_ptr listDt = std::make_shared(); + std::shared_ptr dt = std::make_shared(listDt); + + + int num_records2 = 5; + std::vector> buffers; + + int64_t size = 20; + auto bitmap_buffer = arrow::AllocateBuffer(size, pool_); + buffers.push_back(*std::move(bitmap_buffer)); + auto offsets_len = arrow::bit_util::BytesForBits((num_records2 + 1) * 32); + + auto offsets_buffer = arrow::AllocateBuffer(offsets_len*10, pool_); + buffers.push_back(*std::move(offsets_buffer)); + +std::vector> buffers2; +auto bitmap_buffer2 = arrow::AllocateBuffer(size, pool_); + buffers2.push_back(*std::move(bitmap_buffer2)); + + auto offsets_buffer2 = arrow::AllocateBuffer(offsets_len, pool_); + buffers2.push_back(*std::move(offsets_buffer2)); +std::shared_ptr dt2 = std::make_shared(); + + auto array_data_child = arrow::ArrayData::Make(dt2, num_records2, buffers2, 0, 0); + array_data_child->buffers = std::move(buffers2); + + std::vector> kids; + kids.push_back(array_data_child); + + +auto array_data = arrow::ArrayData::Make(dt, num_records2, buffers, kids, 0, 0); +array_data->buffers = std::move(buffers); +outputs2.push_back(array_data); + + + status = projector->Evaluate(*(in_batch.get()), outputs2); + EXPECT_TRUE(status.ok()) << status.message(); + arrow::ArrayData ad = *outputs2.at(0); + arrow::ArraySpan sp(*ad.child_data.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp1, sp.ToArray()); + + + + +for (auto& array_data : outputs2) { + auto child_data = array_data->child_data[0]; + int64_t child_data_size = 1; + if (arrow::is_binary_like(child_data->type->id())) { + child_data_size = child_data->buffers[1]->size() / 4 - 1; + } else if (child_data->type->id() == arrow::Type::INT32) { + child_data_size = child_data->buffers[1]->size() / 4; + } else if (child_data->type->id() == arrow::Type::INT64) { + child_data_size = child_data->buffers[1]->size() / 8; + } else if (child_data->type->id() == arrow::Type::FLOAT) { + child_data_size = child_data->buffers[1]->size() / 4; + } else if (child_data->type->id() == arrow::Type::DOUBLE) { + child_data_size = child_data->buffers[1]->size() / 8; + } + auto new_child_data = arrow::ArrayData::Make( + child_data->type, child_data_size, child_data->buffers, child_data->offset); + array_data = arrow::ArrayData::Make(array_data->type, array_data->length, + array_data->buffers, {new_child_data}, + array_data->null_count, array_data->offset); + + + auto newArray = arrow::MakeArray(array_data); + //arrow::ArraySpan sp(newArray); + EXPECT_ARROW_ARRAY_EQUALS(exp1, newArray); +} + + +{ + std::shared_ptr listDt = std::make_shared(); + std::shared_ptr dt = std::make_shared(listDt); + +ArrayDataPtr output_data; + auto s = projector->AllocArrayData(dt, num_records2, pool_, &output_data); + ArrayDataVector output_data_vecs; + output_data_vecs.push_back(output_data); + + status = projector->Evaluate(*(in_batch.get()), output_data_vecs); + EXPECT_TRUE(status.ok()) << status.message(); + arrow::ArraySpan sp(*output_data_vecs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp1, sp.ToArray()); + } +} + +TEST_F(TestList, TestListArrayInt32) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t data[] = {11, 2, 23, 42}; + int32_t entry_offsets_len = 4; + int32_t contains_data = 42; + + EXPECT_EQ( + array_int32_contains_int32(ctx_ptr, data, entry_offsets_len, + contains_data), + true); +} + + +TEST_F(TestList, TestListInt32LiteralContains) { + // schema for input fields + auto field_a = field("a", list(int32())); + auto field_b = field("b", int32()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", boolean()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {1, 5, 19, 42, 57}, + {1, 1, 1, 1, 1}, {true, true, true, true, true}, pool_, &array_a); + + auto array_b = + MakeArrowArrayInt32({42, 42, 42, 42, 42}); + + // expected output + auto exp = MakeArrowArrayBool({false, false, false, true, false}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + std::vector field_nodes; + auto node = TreeExprBuilder::MakeField(field_a); + field_nodes.push_back(node); + + auto node2 = TreeExprBuilder::MakeLiteral(42); + field_nodes.push_back(node2); + + auto func_node = TreeExprBuilder::MakeFunction("array_contains", field_nodes, res->type()); + auto expr = TreeExprBuilder::MakeExpression(func_node, res); + //////// + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestList, TestListInt32Contains) { + // schema for input fields + auto field_a = field("a", list(int32())); + auto field_b = field("b", int32()); + auto schema = arrow::schema({field_a, field_b}); + + // output fields + auto res = field("res", boolean()); + + // Create a row-batch with some sample data + int num_records = 5; + ArrayPtr array_a; + _build_list_array( + {1, 5, 19, 42, 57}, + {1, 1, 1, 1, 1}, {true, true, true, true, true}, pool_, &array_a); + + auto array_b = + MakeArrowArrayInt32({42, 42, 42, 42, 42}); + + // expected output + auto exp = MakeArrowArrayBool({false, false, false, true, false}, + {true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array_a, array_b}); + + // build expressions. + // array_contains(a, b) + auto expr = TreeExprBuilder::MakeExpression("array_contains", {field_a, field_b}, res); + + // Build a projector for the expressions. + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()) << status.message(); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()) << status.message(); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0)); +} + +TEST_F(TestList, TestListFloat32) { + ArrayPtr array; + _build_list_array( + {1.1f, 11.1f, 22.2f, 111.1f, 222.2f, 333.3f, 1111.1f, 2222.2f, 3333.3f, 4444.4f, + 11111.1f, 22222.2f, 33333.3f, 44444.4f, 55555.5f}, + {1, 2, 3, 4, 5}, {true, true, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(float32()), array, pool_); +} + +TEST_F(TestList, TestListFloat64) { + ArrayPtr array; + _build_list_array( + {1.1, 1.11, 2.22, 1.111, 2.222, 3.333, 1.1111, 2.2222, 3.3333, 4.4444, 1.11111, + 2.22222, 3.33333, 4.44444, 5.55555}, + {1, 2, 4, 3, 5}, {true, false, true, true, true}, pool_, &array); + _test_list_type_field_alias(list(float64()), array, pool_); +}*/ + +} // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_build_validation_test.cc b/cpp/src/gandiva/tests/projector_build_validation_test.cc index 5b86844f940bf..82b59ef19ad75 100644 --- a/cpp/src/gandiva/tests/projector_build_validation_test.cc +++ b/cpp/src/gandiva/tests/projector_build_validation_test.cc @@ -26,6 +26,7 @@ namespace gandiva { using arrow::boolean; using arrow::float32; using arrow::int32; +using arrow::utf8; class TestProjector : public ::testing::Test { public: @@ -80,7 +81,7 @@ TEST_F(TestProjector, TestNotMatchingDataType) { TEST_F(TestProjector, TestNotSupportedDataType) { // schema for input fields - auto field0 = field("f0", list(int32())); + auto field0 = field("f0", map(utf8(), int32())); auto schema = arrow::schema({field0}); // output fields @@ -94,7 +95,7 @@ TEST_F(TestProjector, TestNotSupportedDataType) { std::shared_ptr projector; auto status = Projector::Make(schema, {lt_expr}, TestConfiguration(), &projector); EXPECT_TRUE(status.IsExpressionValidationError()); - std::string expected_error = "Field f0 has unsupported data type list"; + std::string expected_error = "Field f0 has unsupported data type map"; EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } diff --git a/java/gandiva/CMakeLists.txt b/java/gandiva/CMakeLists.txt index 629ab2fb347d8..60762f6307c06 100644 --- a/java/gandiva/CMakeLists.txt +++ b/java/gandiva/CMakeLists.txt @@ -38,7 +38,7 @@ set(GANDIVA_PROTO_DIR ${CMAKE_CURRENT_SOURCE_DIR}/proto) get_filename_component(GANDIVA_PROTO_FILE_ABSOLUTE ${GANDIVA_PROTO_DIR}/Types.proto ABSOLUTE) -find_package(Protobuf REQUIRED) +find_package(Protobuf CONFIG REQUIRED) add_custom_command(OUTPUT ${GANDIVA_PROTO_OUTPUT_FILES} COMMAND protobuf::protoc --proto_path ${GANDIVA_PROTO_DIR} --cpp_out ${GANDIVA_PROTO_OUTPUT_DIR} ${GANDIVA_PROTO_FILE_ABSOLUTE} diff --git a/java/gandiva/pom.xml b/java/gandiva/pom.xml index bed66b427e625..d2df653d5e1fe 100644 --- a/java/gandiva/pom.xml +++ b/java/gandiva/pom.xml @@ -30,6 +30,11 @@ ../../../cpp/release-build + + org.apache.arrow + arrow-format + ${project.version} + org.apache.arrow arrow-memory-core diff --git a/java/gandiva/proto/Types.proto b/java/gandiva/proto/Types.proto index eb0d996b92e63..a5c4df474db37 100644 --- a/java/gandiva/proto/Types.proto +++ b/java/gandiva/proto/Types.proto @@ -85,6 +85,7 @@ message ExtGandivaType { optional TimeUnit timeUnit = 6; // used by TIME32/TIME64 optional string timeZone = 7; // used by TIMESTAMP optional IntervalType intervalType = 8; // used by INTERVAL + optional GandivaType listType = 9; //used by LIST } message Field { diff --git a/java/gandiva/src/main/cpp/expression_registry_helper.cc b/java/gandiva/src/main/cpp/expression_registry_helper.cc index 6765df3b9727f..cc1ed04194861 100644 --- a/java/gandiva/src/main/cpp/expression_registry_helper.cc +++ b/java/gandiva/src/main/cpp/expression_registry_helper.cc @@ -136,6 +136,18 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) gandiva_data_type->set_type(types::GandivaType::INTERVAL); gandiva_data_type->set_intervaltype(types::IntervalType::DAY_TIME); break; + case arrow::Type::LIST: { + gandiva_data_type->set_type(types::GandivaType::LIST); + if (type->num_fields() <= 0) { + break; + } + if (type->fields()[0]->type()->id() != arrow::Type::LIST) { + types::ExtGandivaType gt; + ArrowToProtobuf(type->fields()[0]->type(), >); + gandiva_data_type->set_listtype(gt.type()); + } + break; + } default: // un-supported types. test ensures that // when one of these are added build breaks. diff --git a/java/gandiva/src/main/cpp/jni_common.cc b/java/gandiva/src/main/cpp/jni_common.cc index d5e54f38e3692..7a631ad856c47 100644 --- a/java/gandiva/src/main/cpp/jni_common.cc +++ b/java/gandiva/src/main/cpp/jni_common.cc @@ -82,10 +82,16 @@ jclass configuration_builder_class_; // refs for self. static jclass gandiva_exception_; static jclass vector_expander_class_; +static jclass listvector_expander_class_; static jclass vector_expander_ret_class_; +static jclass list_expander_ret_class_; static jmethodID vector_expander_method_; +static jmethodID listvector_expander_method_; static jfieldID vector_expander_ret_address_; static jfieldID vector_expander_ret_capacity_; +static jfieldID list_expander_ret_address_; +static jfieldID list_expander_valid_address_; +static jfieldID list_expander_ret_capacity_; static jclass secondary_cache_class_; static jmethodID cache_get_method_; @@ -125,16 +131,37 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { vector_expander_class_, "expandOutputVectorAtIndex", "(IJ)Lorg/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult;"); + jclass local_listexpander_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/ListVectorExpander"); + listvector_expander_class_ = (jclass)env->NewGlobalRef(local_listexpander_class); + env->DeleteLocalRef(local_listexpander_class); + + listvector_expander_method_ = env->GetMethodID( + listvector_expander_class_, "expandOutputVectorAtIndex", + "(IJ)Lorg/apache/arrow/gandiva/evaluator/ListVectorExpander$ExpandResult;"); + jclass local_expander_ret_class = env->FindClass("org/apache/arrow/gandiva/evaluator/VectorExpander$ExpandResult"); vector_expander_ret_class_ = (jclass)env->NewGlobalRef(local_expander_ret_class); env->DeleteLocalRef(local_expander_ret_class); + jclass local_list_expander_ret_class = + env->FindClass("org/apache/arrow/gandiva/evaluator/ListVectorExpander$ExpandResult"); + list_expander_ret_class_ = (jclass)env->NewGlobalRef(local_list_expander_ret_class); + env->DeleteLocalRef(local_list_expander_ret_class); + vector_expander_ret_address_ = env->GetFieldID(vector_expander_ret_class_, "address", "J"); vector_expander_ret_capacity_ = env->GetFieldID(vector_expander_ret_class_, "capacity", "J"); + list_expander_ret_address_ = + env->GetFieldID(list_expander_ret_class_, "address", "J"); + list_expander_ret_capacity_ = + env->GetFieldID(list_expander_ret_class_, "capacity", "J"); + list_expander_valid_address_ = + env->GetFieldID(list_expander_ret_class_, "validityaddress", "J"); + jclass local_cache_class = env->FindClass("org/apache/arrow/gandiva/evaluator/JavaSecondaryCacheInterface"); secondary_cache_class_ = (jclass)env->NewGlobalRef(local_cache_class); @@ -164,11 +191,15 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(configuration_builder_class_); env->DeleteGlobalRef(gandiva_exception_); env->DeleteGlobalRef(vector_expander_class_); + env->DeleteGlobalRef(listvector_expander_class_); env->DeleteGlobalRef(vector_expander_ret_class_); + env->DeleteGlobalRef(list_expander_ret_class_); env->DeleteGlobalRef(secondary_cache_class_); env->DeleteGlobalRef(cache_buf_ret_class_); } +DataTypePtr SimpleProtoTypeToDataType(const types::GandivaType& gandiva_type); + DataTypePtr ProtoTypeToTime32(const types::ExtGandivaType& ext_type) { switch (ext_type.timeunit()) { case types::SEC: @@ -221,8 +252,13 @@ DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) { } } -DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { - switch (ext_type.type()) { +DataTypePtr ProtoTypeToList(const types::ExtGandivaType& ext_type) { + DataTypePtr childType = SimpleProtoTypeToDataType(ext_type.listtype()); + return arrow::list(childType); +} + +DataTypePtr SimpleProtoTypeToDataType(const types::GandivaType& gandiva_type) { + switch (gandiva_type) { case types::NONE: return arrow::null(); case types::BOOL: @@ -257,6 +293,16 @@ DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { return arrow::date32(); case types::DATE64: return arrow::date64(); + default: + std::cerr << "Unknown data type: " << gandiva_type << "\n"; + return nullptr; + } +} + + + +DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { + switch (ext_type.type()) { case types::DECIMAL: // TODO: error handling return arrow::decimal(ext_type.precision(), ext_type.scale()); @@ -268,24 +314,36 @@ DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) { return ProtoTypeToTimestamp(ext_type); case types::INTERVAL: return ProtoTypeToInterval(ext_type); - case types::FIXED_SIZE_BINARY: case types::LIST: - case types::STRUCT: + return ProtoTypeToList(ext_type); + case types::FIXED_SIZE_BINARY: case types::UNION: case types::DICTIONARY: case types::MAP: std::cerr << "Unhandled data type: " << ext_type.type() << "\n"; return nullptr; - default: - std::cerr << "Unknown data type: " << ext_type.type() << "\n"; + return SimpleProtoTypeToDataType(ext_type.type()); + } +} + +DataTypePtr ProtoTypeToDataType(const types::Field& f) { + const types::ExtGandivaType& ext_type = f.type(); + if (ext_type.type() == types::LIST) { + if (f.children().size() > 0 && f.children()[0].type().type() != types::LIST) { + DataTypePtr childType = ProtoTypeToDataType(f.children()[0].type()); + return arrow::list(childType); + } + std::cerr << "Unhandled list data type: " << ext_type.type() << "\n"; return nullptr; + } else { + return ProtoTypeToDataType(ext_type); } } FieldPtr ProtoTypeToField(const types::Field& f) { const std::string& name = f.name(); - DataTypePtr type = ProtoTypeToDataType(f.type()); + DataTypePtr type = ProtoTypeToDataType(f); bool nullable = true; if (f.has_nullable()) { nullable = f.nullable(); @@ -319,7 +377,7 @@ NodePtr ProtoTypeToFnNode(const types::FunctionNode& node) { children.push_back(n); } - + DataTypePtr return_type = ProtoTypeToDataType(node.returntype()); if (return_type == nullptr) { std::cerr << "Unknown return type for function: " << name << "\n"; @@ -602,7 +660,6 @@ Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, auto validity = std::shared_ptr( new arrow::Buffer(reinterpret_cast(validity_addr), validity_size)); buffers.push_back(validity); - if (buf_idx >= in_bufs_len) { return Status::Invalid("insufficient number of in_buf_addrs"); } @@ -625,8 +682,61 @@ Status make_record_batch_with_buf_addrs(SchemaPtr schema, int num_rows, buffers.push_back(offsets); } - auto array_data = arrow::ArrayData::Make(field->type(), num_rows, std::move(buffers)); + + + +auto type = field->type(); +auto type_id = type->id(); + if (type_id == arrow::Type::LIST) { + + if (buf_idx >= in_bufs_len) { + return Status::Invalid("insufficient number of in_buf_addrs"); + } + + // add offsets buffer for variable-len fields. + jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + auto offsets = std::shared_ptr( + new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + buffers.push_back(offsets); + if (arrow::is_binary_like(type->field(0)->type()->id())) { + // child offsets length is internal data length + 1 + // offsets element is int32 + // so here i just allocate extra 32 bit for extra 1 length + jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + + auto child_offsets_buffer = std::shared_ptr( new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + + buffers.push_back(std::move(child_offsets_buffer)); + } + } + + if (type->id() == arrow::Type::LIST) { + jlong offsets_addr = in_buf_addrs[buf_idx++]; + jlong offsets_size = in_buf_sizes[sz_idx++]; + auto data_buffer = std::shared_ptr( new arrow::Buffer(reinterpret_cast(offsets_addr), offsets_size)); + auto internal_type = type->field(0)->type(); + std::shared_ptr child_data; + if (arrow::is_primitive(internal_type->id())) { + child_data = arrow::ArrayData::Make(internal_type, 0, + {std::move(buffers[2]), std::move(data_buffer)}); + } + if (arrow::is_binary_like(internal_type->id())) { + //LR TODO need this for strings I think. + //std::cout << "LR New ArrayData List NYI 2" << std::endl; + //child_data = arrow::ArrayData::Make( + // internal_type, 0, + // {nullptr, std::move(data_buffer), std::move(child_data)}, 0); + } + + auto array_data = arrow::ArrayData::Make(type, num_rows, {std::move(buffers[0]), std::move(buffers[1])}, {child_data}); columns.push_back(array_data); + + } else { + auto array_data = arrow::ArrayData::Make(type, num_rows, std::move(buffers)); + columns.push_back(array_data); + } } *batch = arrow::RecordBatch::Make(schema, num_rows, columns); return Status::OK(); @@ -797,12 +907,15 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_build /// class JavaResizableBuffer : public arrow::ResizableBuffer { public: - JavaResizableBuffer(JNIEnv* env, jobject jexpander, int32_t vector_idx, uint8_t* buffer, - int32_t len) + JavaResizableBuffer(JNIEnv* env, jobject jexpander, jmethodID jmethod, int32_t vector_idx, uint8_t* buffer, + int32_t len, bool isListVec = false) : ResizableBuffer(buffer, len), env_(env), jexpander_(jexpander), - vector_idx_(vector_idx) { + vector_idx_(vector_idx), + method_(jmethod), + isList(isListVec) + { size_ = 0; } @@ -810,27 +923,44 @@ class JavaResizableBuffer : public arrow::ResizableBuffer { Status Reserve(const int64_t new_capacity) override; - private: + public: JNIEnv* env_; jobject jexpander_; + jmethodID method_; int32_t vector_idx_; + bool isList; }; Status JavaResizableBuffer::Reserve(const int64_t new_capacity) { // callback into java to expand the buffer - jobject ret = env_->CallObjectMethod(jexpander_, vector_expander_method_, vector_idx_, + jobject ret = env_->CallObjectMethod(jexpander_, method_, vector_idx_, new_capacity); if (env_->ExceptionCheck()) { env_->ExceptionDescribe(); env_->ExceptionClear(); - return Status::OutOfMemory("buffer expand failed in java"); + std::cout << "Buffer expand failed. New capacity is " << new_capacity << + " vector id " << vector_idx_ << " expander method " << method_ << + " jexpander_ " << jexpander_ << std::endl; + return Status::OutOfMemory("buffer expand failed in java."); } - jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_); - jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_); - data_ = reinterpret_cast(ret_address); - capacity_ = ret_capacity; + if (isList) { + jlong ret_address = env_->GetLongField(ret, list_expander_ret_address_); + jlong ret_capacity = env_->GetLongField(ret, list_expander_ret_capacity_); + jlong valid_address = env_->GetLongField(ret, list_expander_valid_address_); + + data_ = reinterpret_cast(ret_address); + capacity_ = ret_capacity; + validityBuffer = reinterpret_cast(valid_address); + } else { + jlong ret_address = env_->GetLongField(ret, vector_expander_ret_address_); + jlong ret_capacity = env_->GetLongField(ret, vector_expander_ret_capacity_); + + data_ = reinterpret_cast(ret_address); + capacity_ = ret_capacity; + } + return Status::OK(); } @@ -859,7 +989,7 @@ Status JavaResizableBuffer::Resize(const int64_t new_size, bool shrink_to_fit) { JNIEXPORT void JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( - JNIEnv* env, jobject object, jobject jexpander, jlong module_id, jint num_rows, + JNIEnv* env, jobject object, jobject jexpander, jobject jListExpander, jlong module_id, jint num_rows, jlongArray buf_addrs, jlongArray buf_sizes, jint sel_vec_type, jint sel_vec_rows, jlong sel_vec_addr, jlong sel_vec_size, jlongArray out_buf_addrs, jlongArray out_buf_sizes) { @@ -898,7 +1028,6 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( if (!status.ok()) { break; } - std::shared_ptr selection_vector; auto selection_buffer = std::make_shared( reinterpret_cast(sel_vec_addr), sel_vec_size); @@ -925,6 +1054,7 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( break; } + std::shared_ptr outBufJava = nullptr; auto ret_types = holder->rettypes(); ArrayDataVector output; int buf_idx = 0; @@ -956,22 +1086,72 @@ Java_org_apache_arrow_gandiva_evaluator_JniWrapper_evaluateProjector( "null"); break; } - buffers.push_back(std::make_shared( - env, jexpander, output_vector_idx, value_buf, data_sz)); + + buffers.push_back(std::make_shared( + env, jexpander, vector_expander_method_, output_vector_idx, value_buf, data_sz)); + } else if (field->type()->id() == arrow::Type::LIST) { + buffers.push_back(std::make_shared( + env, jexpander, vector_expander_method_, output_vector_idx, value_buf, data_sz)); } else { buffers.push_back(std::make_shared(value_buf, data_sz)); } + + + if (field->type()->id() == arrow::Type::LIST) { + std::vector> child_buffers; + + if (jListExpander == nullptr) { + status = Status::Invalid( + "expression has variable len output columns, but the jListExpander object is " + "null"); + break; + } + + data_sz = out_sizes[sz_idx++]; + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* child_offset_buf = reinterpret_cast(out_bufs[buf_idx++]); + child_buffers.push_back(std::make_shared( + env, jListExpander, listvector_expander_method_, output_vector_idx, child_offset_buf, data_sz)); + + data_sz = out_sizes[sz_idx++]; + CHECK_OUT_BUFFER_IDX_AND_BREAK(buf_idx, out_bufs_len); + uint8_t* child_data_buf = reinterpret_cast(out_bufs[buf_idx++]); + + outBufJava = std::make_shared( + env, jListExpander, listvector_expander_method_, output_vector_idx, child_data_buf, data_sz, true); + outBufJava->offsetBuffer = reinterpret_cast(out_bufs[1]); + outBufJava->offsetCapacity = out_sizes[1]; + outBufJava->validityBuffer = reinterpret_cast(out_bufs[2]); + child_buffers.push_back(outBufJava); + + std::shared_ptr dt2 = std::make_shared(); + if (field->type()->id() == arrow::Type::LIST && field->type()->num_fields() > 0) { + dt2 = field->type()->fields()[0]->type(); + } + + auto array_data_child = arrow::ArrayData::Make(dt2, output_row_count, child_buffers); + std::vector> kids; + kids.push_back(array_data_child); + auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers, kids); + array_data->child_data = std::move(kids); + output.push_back(array_data); + ++output_vector_idx; + } else { auto array_data = arrow::ArrayData::Make(field->type(), output_row_count, buffers); output.push_back(array_data); ++output_vector_idx; + } + } if (!status.ok()) { break; } + status = holder->projector()->Evaluate(*in_batch, selection_vector.get(), output); } while (0); + env->ReleaseLongArrayElements(buf_addrs, in_buf_addrs, JNI_ABORT); env->ReleaseLongArrayElements(buf_sizes, in_buf_sizes, JNI_ABORT); env->ReleaseLongArrayElements(out_buf_addrs, out_bufs, JNI_ABORT); @@ -1061,7 +1241,7 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_gandiva_evaluator_JniWrapper_build // good to invoke the filter builder now status = Filter::Make(schema_ptr, condition_ptr, config, sec_cache, &filter); if (!status.ok()) { - ss << "Failed to make LLVM module due to " << status.message() << "\n"; + ss << "Failed to make LLVM module [2] due to " << status.message() << "\n"; releaseFilterInput(schema_arr, schema_bytes, condition_arr, condition_bytes, env); goto err_out; } diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java index 0155af08234ad..6abc6719d63e6 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistry.java @@ -17,9 +17,11 @@ package org.apache.arrow.gandiva.evaluator; +import java.util.ArrayList; import java.util.List; import java.util.Set; +import org.apache.arrow.flatbuf.Type; import org.apache.arrow.gandiva.exceptions.GandivaException; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.gandiva.ipc.GandivaTypes.ExtGandivaType; @@ -32,7 +34,6 @@ import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.google.protobuf.InvalidProtocolBufferException; @@ -116,12 +117,20 @@ private static Set getSupportedFunctionsFromGandiva() throws String functionName = protoFunctionSignature.getName(); ArrowType returnType = getArrowType(protoFunctionSignature.getReturnType()); - List paramTypes = Lists.newArrayList(); + ArrowType returnListType = getArrowTypeSimple(protoFunctionSignature.getReturnType().getListType()); + List> paramTypes = new ArrayList>(); for (ExtGandivaType type : protoFunctionSignature.getParamTypesList()) { - paramTypes.add(getArrowType(type)); + ArrowType paramType = getArrowType(type); + ArrowType paramListType = getArrowTypeSimple(type.getListType()); + List paramArrowList = new ArrayList(); + paramArrowList.add(paramType); + if (paramType.getTypeID().getFlatbufID() == Type.List) { + paramArrowList.add(paramListType); + } + paramTypes.add(paramArrowList); } FunctionSignature functionSignature = new FunctionSignature(functionName, - returnType, paramTypes); + returnType, returnListType, paramTypes); supportedTypes.add(functionSignature); } } catch (InvalidProtocolBufferException invalidProtException) { @@ -130,8 +139,8 @@ private static Set getSupportedFunctionsFromGandiva() throws return supportedTypes; } - private static ArrowType getArrowType(ExtGandivaType type) { - switch (type.getType().getNumber()) { + private static ArrowType getArrowTypeSimple(GandivaType type) { + switch (type.getNumber()) { case GandivaType.BOOL_VALUE: return ArrowType.Bool.INSTANCE; case GandivaType.UINT8_VALUE: @@ -164,25 +173,15 @@ private static ArrowType getArrowType(ExtGandivaType type) { return new ArrowType.Date(DateUnit.DAY); case GandivaType.DATE64_VALUE: return new ArrowType.Date(DateUnit.MILLISECOND); - case GandivaType.TIMESTAMP_VALUE: - return new ArrowType.Timestamp(mapArrowTimeUnit(type.getTimeUnit()), null); - case GandivaType.TIME32_VALUE: - return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), - BIT_WIDTH_32); - case GandivaType.TIME64_VALUE: - return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), - BIT_WIDTH_64); case GandivaType.NONE_VALUE: return new ArrowType.Null(); case GandivaType.DECIMAL_VALUE: return new ArrowType.Decimal(0, 0, 128); - case GandivaType.INTERVAL_VALUE: - return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); + case GandivaType.LIST_VALUE: + return new ArrowType.List(); case GandivaType.FIXED_SIZE_BINARY_VALUE: case GandivaType.MAP_VALUE: case GandivaType.DICTIONARY_VALUE: - case GandivaType.LIST_VALUE: - case GandivaType.STRUCT_VALUE: case GandivaType.UNION_VALUE: default: assert false; @@ -190,6 +189,23 @@ private static ArrowType getArrowType(ExtGandivaType type) { return null; } + private static ArrowType getArrowType(ExtGandivaType type) { + switch (type.getType().getNumber()) { + case GandivaType.TIMESTAMP_VALUE: + return new ArrowType.Timestamp(mapArrowTimeUnit(type.getTimeUnit()), null); + case GandivaType.TIME32_VALUE: + return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), + BIT_WIDTH_32); + case GandivaType.TIME64_VALUE: + return new ArrowType.Time(mapArrowTimeUnit(type.getTimeUnit()), + BIT_WIDTH_64); + case GandivaType.INTERVAL_VALUE: + return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType())); + default: + return getArrowTypeSimple(type.getType()); + } + } + private static TimeUnit mapArrowTimeUnit(GandivaTypes.TimeUnit timeUnit) { switch (timeUnit.getNumber()) { case GandivaTypes.TimeUnit.MICROSEC_VALUE: diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java index d01881843de47..c5c6aeb5372b8 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/FunctionSignature.java @@ -17,6 +17,7 @@ package org.apache.arrow.gandiva.evaluator; +import java.util.ArrayList; import java.util.List; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -30,13 +31,18 @@ public class FunctionSignature { private final String name; private final ArrowType returnType; - private final List paramTypes; + private final ArrowType returnListType; + private final List> paramTypes; public ArrowType getReturnType() { return returnType; } - public List getParamTypes() { + public ArrowType getReturnListType() { + return returnListType; + } + + public List> getParamTypes() { return paramTypes; } @@ -48,14 +54,36 @@ public String getName() { * Ctor. * @param name - name of the function. * @param returnType - data type of return + * @param returnListType optional list type * @param paramTypes - data type of input args. */ - public FunctionSignature(String name, ArrowType returnType, List paramTypes) { + public FunctionSignature(String name, ArrowType returnType, ArrowType returnListType, + List> paramTypes) { this.name = name; this.returnType = returnType; + this.returnListType = returnListType; this.paramTypes = paramTypes; } + /** + * Ctor. + * @param name - name of the function. + * @param returnType - data type of return + * @param paramTypes - data type of input args. + */ + public FunctionSignature(String name, ArrowType returnType, List paramTypes) { + this.name = name; + this.returnType = returnType; + this.returnListType = ArrowType.Null.INSTANCE; + this.paramTypes = new ArrayList>(); + for (ArrowType paramType : paramTypes) { + List paramArrowList = new ArrayList(); + paramArrowList.add(paramType); + this.paramTypes.add(paramArrowList); + } + + } + /** * Override equals. * @param signature - signature to compare @@ -71,12 +99,13 @@ public boolean equals(Object signature) { final FunctionSignature other = (FunctionSignature) signature; return this.name.equalsIgnoreCase(other.name) && Objects.equal(this.returnType, other.returnType) && + Objects.equal(this.returnListType, other.returnListType) && Objects.equal(this.paramTypes, other.paramTypes); } @Override public int hashCode() { - return Objects.hashCode(this.name.toLowerCase(), this.returnType, this.paramTypes); + return Objects.hashCode(this.name.toLowerCase(), this.returnType, this.returnListType, this.paramTypes); } @Override @@ -84,6 +113,7 @@ public String toString() { return MoreObjects.toStringHelper(this) .add("name ", name) .add("return type ", returnType) + .add("return list type", returnListType) .add("param types ", paramTypes) .toString(); diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java index 293d51a87a5fd..f883ed7081547 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java @@ -50,6 +50,7 @@ native long buildProjector(Object cache, byte[] schemaBuf, byte[] exprListBuf, * and store the output in ValueVectors. Throws an exception in case of errors * * @param expander VectorExpander object. Used for callbacks from cpp. + * @param listExpander ListVectorExpander object. Used for callbacks from cpp. * @param moduleId moduleId representing expressions. Created using a call to * buildNativeCode * @param numRows Number of rows in the record batch @@ -63,7 +64,7 @@ native long buildProjector(Object cache, byte[] schemaBuf, byte[] exprListBuf, * @param outSizes The allocated size of the output buffers. On successful evaluation, * the result is stored in the output buffers */ - native void evaluateProjector(Object expander, long moduleId, int numRows, + native void evaluateProjector(Object expander, Object listExpander, long moduleId, int numRows, long[] bufAddrs, long[] bufSizes, int selectionVectorType, int selectionVectorSize, long selectionVectorBufferAddr, long selectionVectorBufferSize, diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java new file mode 100644 index 0000000000000..1d02f38a4d591 --- /dev/null +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/ListVectorExpander.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.gandiva.evaluator; + +import org.apache.arrow.vector.complex.ListVector; + +/** + * This class provides the functionality to expand output ListVectors using a callback mechanism from + * gandiva. + */ +public class ListVectorExpander { + private final ListVector[] bufferVectors; + public static final int valueBufferIndex = 1; + public static final int validityBufferIndex = 0; + + public ListVectorExpander(ListVector[] bufferVectors) { + this.bufferVectors = bufferVectors; + } + + /** + * Result of ListVector expansion. + */ + public static class ExpandResult { + public long address; + public long capacity; + public long validityaddress; + + /** + * Result of expanding the buffer. + * @param address Data buffer address + * @param capacity Capacity + * @param validAdd Validity buffer address + * + */ + public ExpandResult(long address, long capacity, long validAdd) { + this.address = address; + this.capacity = capacity; + this.validityaddress = validAdd; + } + } + + /** + * Expand vector at specified index. This is used as a back call from jni, and is only + * relevant for ListVectors. + * + * @param index index of buffer in the list passed to jni. + * @param toCapacity the size to which the buffer should be expanded to. + * + * @return address and size of the buffer after expansion. + */ + public ExpandResult expandOutputVectorAtIndex(int index, long toCapacity) { + if (index >= bufferVectors.length || bufferVectors[index] == null) { + throw new IllegalArgumentException("invalid index " + index); + } + + ListVector vector = bufferVectors[index]; + while (vector.getDataVector().getFieldBuffers().get(ListVectorExpander.valueBufferIndex).capacity() < toCapacity) { + //Just realloc the data vector. + vector.getDataVector().reAlloc(); + } + + return new ExpandResult( + vector.getDataVector().getFieldBuffers().get(ListVectorExpander.valueBufferIndex).memoryAddress(), + vector.getDataVector().getFieldBuffers().get(ListVectorExpander.valueBufferIndex).capacity(), + vector.getDataVector().getFieldBuffers().get(ListVectorExpander.validityBufferIndex).memoryAddress()); + } + +} diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java index c146fce26c150..686539a169f57 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/evaluator/Projector.java @@ -22,16 +22,15 @@ import org.apache.arrow.gandiva.exceptions.EvaluatorClosedException; import org.apache.arrow.gandiva.exceptions.GandivaException; -import org.apache.arrow.gandiva.exceptions.UnsupportedTypeException; import org.apache.arrow.gandiva.expression.ArrowTypeHelper; import org.apache.arrow.gandiva.expression.ExpressionTree; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.gandiva.ipc.GandivaTypes.SelectionVectorType; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.vector.BaseVariableWidthVector; -import org.apache.arrow.vector.FixedWidthVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VariableWidthVector; +import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.ipc.message.ArrowBuffer; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; @@ -352,18 +351,21 @@ private void evaluate(int numRows, List buffers, List buf boolean hasVariableWidthColumns = false; BaseVariableWidthVector[] resizableVectors = new BaseVariableWidthVector[outColumns.size()]; + ListVector[] resizableListVectors = new ListVector[outColumns.size()]; + long[] outAddrs = new long[3 * outColumns.size()]; long[] outSizes = new long[3 * outColumns.size()]; + idx = 0; int outColumnIdx = 0; + final int listVectorBufferCount = 5; for (ValueVector valueVector : outColumns) { - boolean isFixedWith = valueVector instanceof FixedWidthVector; - boolean isVarWidth = valueVector instanceof VariableWidthVector; - if (!isFixedWith && !isVarWidth) { - throw new UnsupportedTypeException( - "Unsupported value vector type " + valueVector.getField().getFieldType()); + if (valueVector instanceof ListVector) { + outAddrs = new long[listVectorBufferCount * outColumns.size()]; + outSizes = new long[listVectorBufferCount * outColumns.size()]; } + boolean isVarWidth = valueVector instanceof VariableWidthVector; outAddrs[idx] = valueVector.getValidityBuffer().memoryAddress(); outSizes[idx++] = valueVector.getValidityBuffer().capacity(); if (isVarWidth) { @@ -374,19 +376,45 @@ private void evaluate(int numRows, List buffers, List buf // save vector to allow for resizing. resizableVectors[outColumnIdx] = (BaseVariableWidthVector) valueVector; } - outAddrs[idx] = valueVector.getDataBuffer().memoryAddress(); - outSizes[idx++] = valueVector.getDataBuffer().capacity(); + if (valueVector instanceof ListVector) { + hasVariableWidthColumns = true; + resizableListVectors[outColumnIdx] = (ListVector) valueVector; + List fieldBufs = ((ListVector) valueVector).getDataVector().getFieldBuffers(); + outAddrs[idx] = valueVector.getOffsetBuffer().memoryAddress(); + outSizes[idx++] = valueVector.getOffsetBuffer().capacity(); + + //vector valid + outAddrs[idx] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.validityBufferIndex).memoryAddress(); + outSizes[idx++] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.validityBufferIndex).capacity(); + + //vector offset + outAddrs[idx] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.valueBufferIndex).memoryAddress(); + outSizes[idx++] = ((ListVector) valueVector).getDataVector().getFieldBuffers() + .get(ListVectorExpander.valueBufferIndex).capacity(); + } else { + outAddrs[idx] = valueVector.getDataBuffer().memoryAddress(); + outSizes[idx++] = valueVector.getDataBuffer().capacity(); + } valueVector.setValueCount(selectionVectorRecordCount); outColumnIdx++; } - wrapper.evaluateProjector( hasVariableWidthColumns ? new VectorExpander(resizableVectors) : null, + hasVariableWidthColumns ? new ListVectorExpander(resizableListVectors) : null, this.moduleId, numRows, bufAddrs, bufSizes, selectionVectorType, selectionVectorRecordCount, selectionVectorAddr, selectionVectorSize, outAddrs, outSizes); + + for (ValueVector valueVector : outColumns) { + if (valueVector instanceof ListVector) { + ((ListVector) valueVector).setLastSet(selectionVectorRecordCount - 1); + } + } } /** diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java index 90f8684b455a8..fd1be362b8404 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/ArrowTypeHelper.java @@ -143,6 +143,15 @@ private static void initArrowTypeDate(ArrowType.Date dateType, } } + private static void initArrowTypeList(ArrowType.List listType, + ArrowType subType, + GandivaTypes.ExtGandivaType.Builder builder) throws GandivaException { + if (subType != null) { + builder.setListType(arrowTypeToProtobuf(subType).getType()); + } + builder.setType(GandivaTypes.GandivaType.LIST); + } + private static void initArrowTypeTime(ArrowType.Time timeType, GandivaTypes.ExtGandivaType.Builder builder) { short timeUnit = timeType.getUnit().getFlatbufID(); @@ -227,11 +236,13 @@ private static void initArrowTypeInterval(ArrowType.Interval interval, * Converts an arrow type into a protobuf. * * @param arrowType Arrow type to be converted + * @param subType optional arrow type for list/complex types + * @param builder the builder to use * @return Protobuf representing the arrow type */ - public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType) + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType, ArrowType subType, + GandivaTypes.ExtGandivaType.Builder builder) throws GandivaException { - GandivaTypes.ExtGandivaType.Builder builder = GandivaTypes.ExtGandivaType.newBuilder(); byte typeId = arrowType.getTypeID().getFlatbufID(); switch (typeId) { @@ -284,6 +295,7 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp break; } case Type.List: { // 12 + ArrowTypeHelper.initArrowTypeList((ArrowType.List) arrowType, subType, builder); break; } case Type.Struct_: { // 13 @@ -315,6 +327,31 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp return builder.build(); } + + /** + * Converts an arrow type into a protobuf. + * + * @param arrowType Arrow type to be converted + * @param f field optional for list/complex types + * @return Protobuf representing the arrow type + */ + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType, ArrowType f) + throws GandivaException { + GandivaTypes.ExtGandivaType.Builder builder = GandivaTypes.ExtGandivaType.newBuilder(); + return arrowTypeToProtobuf(arrowType, f, builder); + } + + /** + * Converts an arrow type into a protobuf. + * + * @param arrowType Arrow type to be converted + * @return Protobuf representing the arrow type + */ + public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowType) + throws GandivaException { + return arrowTypeToProtobuf(arrowType, null); + } + /** * Converts an arrow field object to a protobuf. * @param field Arrow field to be converted @@ -323,12 +360,21 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp public static GandivaTypes.Field arrowFieldToProtobuf(Field field) throws GandivaException { GandivaTypes.Field.Builder builder = GandivaTypes.Field.newBuilder(); builder.setName(field.getName()); - builder.setType(ArrowTypeHelper.arrowTypeToProtobuf(field.getType())); builder.setNullable(field.isNullable()); + ArrowType subType = null; + if (field.getChildren().size() > 0 && field.getChildren().get(0) + .getType().getTypeID().getFlatbufID() != Type.List) { + subType = field.getChildren().get(0).getType(); + } + + builder.setType(ArrowTypeHelper.arrowTypeToProtobuf(field.getType(), subType)); for (Field child : field.getChildren()) { - builder.addChildren(ArrowTypeHelper.arrowFieldToProtobuf(child)); + if (child.getType() != ArrowType.Null.INSTANCE) { + builder.addChildren(ArrowTypeHelper.arrowFieldToProtobuf(child)); + } } + return builder.build(); } diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java index ead1e146d5d8c..e092facfd69ba 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/FunctionNode.java @@ -19,9 +19,12 @@ import java.util.List; +import org.apache.arrow.flatbuf.Type; import org.apache.arrow.gandiva.exceptions.GandivaException; import org.apache.arrow.gandiva.ipc.GandivaTypes; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + /** * Node representing an arbitrary function in an expression. @@ -30,18 +33,40 @@ class FunctionNode implements TreeNode { private final String function; private final List children; private final ArrowType retType; + private final ArrowType retListType; - FunctionNode(String function, List children, ArrowType retType) { + FunctionNode(String function, List children, Field inField) { + this.function = function; + this.children = children; + this.retType = inField.getType(); + if (inField.getChildren().size() > 0 && inField.getChildren().get(0) + .getType().getTypeID().getFlatbufID() != Type.List) { + this.retListType = inField.getChildren().get(0).getType(); + } else { + this.retListType = null; + } + + } + + FunctionNode(String function, List children, ArrowType inType) { + this.function = function; + this.children = children; + this.retType = inType; + this.retListType = null; + } + + FunctionNode(String function, List children, ArrowType inType, ArrowType listType) { this.function = function; this.children = children; - this.retType = retType; + this.retType = inType; + this.retListType = listType; } @Override public GandivaTypes.TreeNode toProtobuf() throws GandivaException { GandivaTypes.FunctionNode.Builder fnNode = GandivaTypes.FunctionNode.newBuilder(); fnNode.setFunctionName(function); - fnNode.setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType)); + fnNode.setReturnType(ArrowTypeHelper.arrowTypeToProtobuf(retType, retListType)); for (TreeNode arg : children) { fnNode.addInArgs(arg.toProtobuf()); diff --git a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java index 8656e886aae24..f8337a25f8377 100644 --- a/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java +++ b/java/gandiva/src/main/java/org/apache/arrow/gandiva/expression/TreeBuilder.java @@ -97,6 +97,35 @@ public static TreeNode makeFunction(String function, return new FunctionNode(function, children, retType); } + /** + * Invoke this function to create a node representing a function. + * + * @param function Name of the function, e.g. add + * @param children The arguments to the function + * @param retType The type of the return value of the operator + * @param listType The type of the list return value of the operator + * @return Node representing a function + */ + public static TreeNode makeFunction(String function, + List children, + ArrowType retType, ArrowType listType) { + return new FunctionNode(function, children, retType, listType); + } + + /** + * Invoke this function to create a node representing a function. + * + * @param function Name of the function, e.g. add + * @param children The arguments to the function + * @param retType The field of the return value of the operator, could be a complex type. + * @return Node representing a function + */ + public static TreeNode makeFunction(String function, + List children, + Field retType) { + return new FunctionNode(function, children, retType); + } + /** * Invoke this function to create a node representing an if-clause. * @@ -161,7 +190,7 @@ public static ExpressionTree makeExpression(String function, children.add(makeField(field)); } - TreeNode root = makeFunction(function, children, resultField.getType()); + TreeNode root = makeFunction(function, children, resultField); return makeExpression(root, resultField); }