Skip to content

Commit

Permalink
Almost working
Browse files Browse the repository at this point in the history
  • Loading branch information
lriggs committed Sep 1, 2023
1 parent b3e2c54 commit b52e5c2
Show file tree
Hide file tree
Showing 16 changed files with 543 additions and 49 deletions.
3 changes: 3 additions & 0 deletions cpp/src/gandiva/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ void Annotator::PrepareBuffersForField(const FieldDescriptor& desc,
reinterpret_cast<uint8_t*>(array_data.buffers[buffer_idx].get());
eval_batch->SetBuffer(desc.data_buffer_ptr_idx(), data_buf_ptr, array_data.offset);
} else {
std::cout << "LR Annotator::PrepareBuffersForField is_output index " << desc.data_buffer_ptr_idx() << std::endl;

// list data buffer is in child data buffer
uint8_t* data_buf_ptr = reinterpret_cast<uint8_t*>(
array_data.child_data.at(0)->buffers[buffer_idx].get());
Expand Down Expand Up @@ -181,6 +183,7 @@ EvalBatchPtr Annotator::PrepareEvalBatch(const arrow::RecordBatch& record_batch,
}

// Fill in the entries for the output fields.
std::cout << "LR PrepareEvalBatch preparing output fields" << std::endl;
int idx = 0;
for (auto& arraydata : out_vector) {
const FieldDescriptorPtr& desc = out_descs_.at(idx);
Expand Down
60 changes: 59 additions & 1 deletion cpp/src/gandiva/array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
#include "gandiva/array_ops.h"

#include <iostream>
#include <string>

#include "arrow/util/value_parsing.h"

#include "gandiva/gdv_function_stubs.h"
#include "gandiva/engine.h"
#include "gandiva/exported_funcs.h"

Expand Down Expand Up @@ -50,7 +53,25 @@ bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf,
std::cout << "LR array_int32_contains_int32 offset length=" << entry_offsets_len << std::endl;
for (int i = 0; i < entry_offsets_len; i++) {
std::cout << "LR going to check " << entry_buf + i << std::endl;
int32_t entry_len = *(entry_buf + i);
//LR TODO
//int32_t entry_len = *(entry_buf + i);
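// NOTE: the values appear to arrive as 64-bit entries for some reason, so step through entry_buf with a stride of two int32 slots (i * 2).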
int32_t entry_len = *(entry_buf + (i * 2));
std::cout << "LR checking value " << entry_len << " against target " << contains_data << std::endl;
if (entry_len == contains_data) {
return true;
}
}
return false;
}

bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf,
int32_t entry_offsets_len,
int64_t contains_data) {
std::cout << "LR array_int64_contains_int64 offset length=" << entry_offsets_len << std::endl;
for (int i = 0; i < entry_offsets_len; i++) {
std::cout << "LR going to check " << entry_buf + i << std::endl;
int64_t entry_len = *(entry_buf + (i*2)); //LR TODO sizeof int64?
std::cout << "LR checking value " << entry_len << " against target " << contains_data << std::endl;
if (entry_len == contains_data) {
return true;
Expand All @@ -59,6 +80,23 @@ bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf,
return false;
}


/// Build a fixed five-element int32 list {1, 2, 3, contains_data, 5} in memory
/// obtained from gdv_fn_context_arena_malloc.
///
/// \param context_ptr execution-context handle forwarded to the arena allocator
/// \param contains_data value stored at index 3 of the produced list
/// \param out_len [out] number of elements (not bytes) in the returned buffer;
///   set to 0 when the allocation fails
/// \return pointer to the first element, or nullptr when the allocation fails
int32_t* array_int32_make_array(int64_t context_ptr, int32_t contains_data, int32_t* out_len) {
  std::cout << "LR array_int32_make_array offset data=" << contains_data << std::endl;

  const int32_t values[] = {1, 2, 3, contains_data, 5};
  // The length reported to the caller is an element count, while the arena
  // allocation and the copy are sized in bytes.
  constexpr int32_t kNumValues = static_cast<int32_t>(sizeof(values) / sizeof(values[0]));
  constexpr int32_t kNumBytes = static_cast<int32_t>(sizeof(values));

  uint8_t* ret = gdv_fn_context_arena_malloc(context_ptr, kNumBytes);
  if (ret == nullptr) {
    // Never memcpy into a failed allocation; report an empty result instead.
    // NOTE(review): consider also recording an error on the execution context.
    *out_len = 0;
    return nullptr;
  }
  memcpy(ret, values, kNumBytes);
  *out_len = kNumValues;

  // Read the whole int32 back for the trace (indexing the byte buffer directly
  // would only show the lowest byte of the value).
  int32_t item3 = 0;
  memcpy(&item3, ret + 3 * sizeof(int32_t), sizeof(item3));
  std::cout << "LR made a buffer length" << kNumBytes << " item 3 is = " << item3 << std::endl;

  return reinterpret_cast<int32_t*>(ret);
}


int64_t array_utf8_length(int64_t context_ptr, const char* entry_buf,
int32_t* entry_child_offsets, int32_t entry_offsets_len) {
int64_t res = entry_offsets_len;
Expand Down Expand Up @@ -98,5 +136,25 @@ void ExportedArrayFunctions::AddMappings(Engine* engine) const {
engine->AddGlobalMappingForFunc("array_int32_contains_int32",
types->i1_type() /*return_type*/, args,
reinterpret_cast<void*>(array_int32_contains_int32));

// array_int64_contains_int64(context_ptr, entry_buf, entry_offsets_len, contains_data)
args = {types->i64_type(), // int64_t execution_context
types->i64_ptr_type(), // const int64_t* entry_buf (list data pointer)
types->i32_type(), // int32_t entry_offsets_len
types->i64_type()}; // int64_t contains_data (value to search for)

engine->AddGlobalMappingForFunc("array_int64_contains_int64",
types->i1_type() /*return_type*/, args,
reinterpret_cast<void*>(array_int64_contains_int64));


// array_int32_make_array(context_ptr, contains_data, out_len)
args = {types->i64_type(), // int64_t execution_context
types->i32_type(), // int32_t contains_data
types->i32_ptr_type()}; // int32_t* out_len (receives the element count)

engine->AddGlobalMappingForFunc("array_int32_make_array",
types->i32_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(array_int32_make_array));


}
} // namespace gandiva
13 changes: 13 additions & 0 deletions cpp/src/gandiva/array_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@

#include "gandiva/visibility.h"

// Forward declaration of llvm::VectorType.
namespace llvm { class VectorType; }  // namespace llvm

/// Array functions that can be accessed from LLVM.
extern "C" {
GANDIVA_EXPORT
Expand All @@ -34,4 +38,13 @@ GANDIVA_EXPORT
bool array_int32_contains_int32(int64_t context_ptr, const int32_t* entry_buf,
int32_t entry_offsets_len,
int32_t contains_data);
/// Report whether a list of int64 values contains a given value.
///
/// \param context_ptr execution-context handle
/// \param entry_buf pointer to the list's values
/// \param entry_offsets_len number of entries to examine
/// \param contains_data value to search for
/// \return true as soon as an examined entry equals contains_data
///
/// NOTE(review): the definition currently indexes entry_buf at (i * 2) and
/// carries a TODO about the element size; confirm the intended stride.
GANDIVA_EXPORT
bool array_int64_contains_int64(int64_t context_ptr, const int64_t* entry_buf,
int32_t entry_offsets_len,
int64_t contains_data);

/// Produce the int32 list {1, 2, 3, contains_data, 5} in memory obtained from
/// gdv_fn_context_arena_malloc.
///
/// \param context_ptr execution-context handle passed to the arena allocator
/// \param contains_data value placed at index 3 of the list
/// \param out_len [out] receives the number of elements (not bytes)
/// \return pointer to the first element of the list
GANDIVA_EXPORT
int32_t* array_int32_make_array(int64_t context_ptr,
int32_t contains_data,
int32_t* out_len);
}
4 changes: 3 additions & 1 deletion cpp/src/gandiva/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ Status Engine::FinalizeModule() {
if (!cached_) {
ARROW_RETURN_NOT_OK(RemoveUnusedFunctions());

/*
//LR Turning this off seems to provide better error messages with compilation/generation failures.
if (optimize_) {
// misc passes to allow for inlining, vectorization, ..
std::unique_ptr<llvm::legacy::PassManager> pass_manager(
Expand All @@ -324,7 +326,7 @@ Status Engine::FinalizeModule() {
pass_builder.populateModulePassManager(*pass_manager);
pass_manager->run(*module_);
}

*/
ARROW_RETURN_IF(llvm::verifyModule(*module_, &llvm::errs()),
Status::CodeGenError("Module verification failed after optimizer"));
}
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/gandiva/function_registry_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ std::vector<NativeFunction> GetArrayFunctionRegistry() {
NativeFunction("array_containsGandiva", {}, DataTypeVector{list(int32()), int32()},
boolean(), kResultNullIfNull, "array_int32_contains_int32",
NativeFunction::kNeedsContext),
NativeFunction("array_contains", {}, DataTypeVector{list(int32()), int32()},
boolean(), kResultNullIfNull, "array_int32_contains_int32",
NativeFunction::kNeedsContext),
NativeFunction("array_makeGandiva", {}, DataTypeVector{int32()},
list(int32()), kResultNullIfNull, "array_int32_make_array",
NativeFunction::kNeedsContext),
/*NativeFunction("array_containsGandiva", {}, DataTypeVector{list(int64()), int64()},
boolean(), kResultNullIfNull, "array_int64_contains_int64",
NativeFunction::kNeedsContext),*/
};
return array_fn_registry_;
}
Expand Down
69 changes: 57 additions & 12 deletions cpp/src/gandiva/llvm_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out
std::unique_ptr<CompiledExpr> compiled_expr(new CompiledExpr(value_validity, output));
std::string fn_name = "expr_" + std::to_string(idx) + "_" +
std::to_string(static_cast<int>(selection_vector_mode_));
std::cout << "LR LLVMGenerator::Add " << fn_name << std::endl;
if (!cached_) {
ARROW_RETURN_NOT_OK(engine_->LoadFunctionIRs());
ARROW_RETURN_NOT_OK(CodeGenExprValue(value_validity->value_expr(),
Expand All @@ -100,6 +101,7 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out
}
compiled_expr->SetFunctionName(selection_vector_mode_, fn_name);
compiled_exprs_.push_back(std::move(compiled_expr));
std::cout << "LR LLVMGenerator::Add Done" << std::endl;
return Status::OK();
}

Expand All @@ -108,21 +110,27 @@ Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr out
/// Generate code for every expression, finalize the module, and attach the
/// JIT-compiled entry point to each compiled expression.
///
/// \param exprs expressions to compile; one output field descriptor is
///   registered with the annotator per expression
/// \param mode selection-vector mode stored on the generator and used when
///   looking up and binding the compiled functions
/// \return the first non-OK status from Add() or FinalizeModule(), otherwise OK
Status LLVMGenerator::Build(const ExpressionVector& exprs, SelectionVector::Mode mode) {
  selection_vector_mode_ = mode;
  std::cout << "LR LLVMGenerator::Build " << std::endl;

  // Step 1: register an output descriptor and emit code for each expression.
  for (auto& expr : exprs) {
    auto output = annotator_.AddOutputFieldDescriptor(expr->result());
    ARROW_RETURN_NOT_OK(Add(expr, output));
  }
  std::cout << "LR LLVMGenerator::Build 2" << std::endl;

  // Step 2: compile the generated functions and inject them into the process'
  // memory. (Printing engine_->DumpIR() at this point produced too much
  // output to be useful, so that trace stays disabled.)
  ARROW_RETURN_NOT_OK(engine_->FinalizeModule());
  std::cout << "LR LLVMGenerator::Build FinalizeModule" << std::endl;

  // Step 3: resolve the JIT function for every compiled expression.
  for (auto& compiled : compiled_exprs_) {
    auto name = compiled->GetFunctionName(mode);
    compiled->SetJITFunction(selection_vector_mode_,
                             reinterpret_cast<EvalFunc>(engine_->CompiledFunction(name)));
  }

  std::cout << "LR LLVMGenerator::Build Done" << std::endl;
  return Status::OK();
}

Expand All @@ -144,13 +152,12 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
const SelectionVector* selection_vector,
const ArrayDataVector& output_vector) const {
DCHECK_GT(record_batch.num_rows(), 0);
int jello = 0;
std::cout << "LR LLVMGenerator::Execute " << jello++ << std::endl;
std::cout << "LR LLVMGenerator::Execute 1"<< std::endl;

auto eval_batch = annotator_.PrepareEvalBatch(record_batch, output_vector);
DCHECK_GT(eval_batch->GetNumBuffers(), 0);

std::cout << "LR LLVMGenerator::Execute " << jello++ << std::endl;
std::cout << "LR LLVMGenerator::Execute 2" << std::endl;
auto mode = SelectionVector::MODE_NONE;
if (selection_vector != nullptr) {
mode = selection_vector->GetMode();
Expand All @@ -160,7 +167,7 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
selection_vector_mode_, " received vector with mode ", mode);
}

std::cout << "LR LLVMGenerator::Execute " << jello++ << std::endl;
std::cout << "LR LLVMGenerator::Execute 3" << std::endl;
for (auto& compiled_expr : compiled_exprs_) {
// generate data/offset vectors.
const uint8_t* selection_buffer = nullptr;
Expand All @@ -170,7 +177,7 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
num_output_rows = selection_vector->GetNumSlots();
}

std::cout << "LR LLVMGenerator::Execute A" << jello++ << std::endl;
std::cout << "LR LLVMGenerator::Execute A1" << std::endl;
EvalFunc jit_function = compiled_expr->GetJITFunction(mode);
jit_function(eval_batch->GetBufferArray(), eval_batch->GetBufferOffsetArray(),
eval_batch->GetLocalBitMapArray(), annotator_.GetHolderPointersArray(),
Expand All @@ -182,7 +189,7 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch,
eval_batch->GetExecutionContext()->has_error(),
Status::ExecutionError(eval_batch->GetExecutionContext()->get_error()));

std::cout << "LR LLVMGenerator::Execute A" << jello++ << std::endl;
std::cout << "LR LLVMGenerator::Execute A2" << std::endl;
// generate validity vectors.
ComputeBitMapsForExpr(*compiled_expr, selection_vector, eval_batch.get());
}
Expand Down Expand Up @@ -305,7 +312,8 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
FieldDescriptorPtr output, int suffix_idx,
std::string& fn_name,
SelectionVector::Mode selection_vector_mode) {
std::cout << "LR CodeGenExprValue" << std::endl;
std::cout << "LR CodeGenExprValue for output field " << output->Name()
<< " type " << output->Type()->ToString() << " output type id " << output->Type()->id() << std::endl;
try {
llvm::IRBuilder<>* builder = ir_builder();
// Create fn prototype :
Expand Down Expand Up @@ -404,6 +412,7 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
}

// The visitor can add code to both the entry/loop blocks.
std::cout << "LR calling visitor to get output data for [" << fn_name << "]" << std::endl;
Visitor visitor(this, fn, loop_entry, arg_addrs, arg_local_bitmaps, arg_holder_ptrs,
slice_offsets, arg_context_ptr, position_var);
value_expr->Accept(visitor);
Expand Down Expand Up @@ -441,6 +450,7 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
} else if (output_type_id == arrow::Type::LIST) {
auto output_list_internal_type = output->Type()->field(0)->type()->id();
std::cout << "LR creating list type to store the result with internal type " << output_list_internal_type << std::endl;

if (arrow::is_binary_like(output_list_internal_type)) {
auto output_list_value = std::dynamic_pointer_cast<ListLValue>(output_value);
llvm::Value* child_output_offset_ref = GetChildOffsetsReference(
Expand All @@ -451,6 +461,19 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count,
child_output_offset_ref, loop_var, output_list_value->data(),
output_list_value->child_offsets(), output_list_value->offsets_length()});
} else if (output_list_internal_type == arrow::Type::INT32) {


std::string str1;
llvm::raw_string_ostream output1(str1);
output_value->data()->print(output1);

std::string str2;
llvm::raw_string_ostream output2(str2);
output_value->length()->print(output2);


std::cout << "LR gdv_fn_populate_list_int32_t_vector params are " << arg_context_ptr << "," << output_buffer_ptr_ref << ","
<< output_offset_ref << "," << loop_var << "[[" << str1 << "]] [[" << str2 << "]]" << std::endl;
AddFunctionCall("gdv_fn_populate_list_int32_t_vector", types()->i32_type(),
{arg_context_ptr, output_buffer_ptr_ref, output_offset_ref,
loop_var, output_value->data(), output_value->length()});
Expand Down Expand Up @@ -604,7 +627,7 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
llvm::Function* fn = module()->getFunction(full_name);
DCHECK_NE(fn, nullptr) << "missing function " << full_name;

if (enable_ir_traces_ && !full_name.compare("printf") &&
if (!full_name.compare("printf") &&
!full_name.compare("printff")) {
// Trace for debugging
ADD_TRACE("invoke native fn " + full_name);
Expand All @@ -624,7 +647,7 @@ llvm::Value* LLVMGenerator::AddFunctionCall(const std::string& full_name,
llvm::raw_string_ostream output2(str2);
ret_type->print(output);
value->getType()->print(output2);
std::cout << "LR ret type " << str << " value ret type " << str2 << std::endl;
std::cout << "LR addfunctioncall for " << full_name << " == value->getType " << str2 << " ret_type " << str << std::endl;

DCHECK(value->getType() == ret_type);
}
Expand All @@ -644,9 +667,7 @@ std::shared_ptr<DecimalLValue> LLVMGenerator::BuildDecimalLValue(llvm::Value* va
}

#define ADD_VISITOR_TRACE(...) \
if (generator_->enable_ir_traces_) { \
generator_->AddTrace(__VA_ARGS__); \
}

// Visitor for generating the code for a decomposed expression.
LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* function,
Expand Down Expand Up @@ -1018,6 +1039,8 @@ void LLVMGenerator::Visitor::Visit(const NonNullableFuncDex& dex) {
LLVMTypes* types = generator_->types();
auto arrow_type_id = arrow_return_type->id();
auto result_type = types->DataVecType(arrow_return_type);
//Result type array/list is special.
//auto result_type = types->IRType(arrow_type_id);
std::cout << "LR NonNullableFunc 2 result_type " << printType(result_type) << " arrow_return_type " << arrow_return_type->ToString() << " old type " << types->IRType(arrow_type_id) << std::endl;

// Build combined validity of the args.
Expand Down Expand Up @@ -1477,6 +1500,7 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
auto llvm_return_type = types->DataVecType(arrow_return_type);
DecimalIR decimalIR(generator_->engine_.get());

std::cout << "LR LLVMGenerator::Visitor::BuildFunctionCall for " << func->pc_name() << " llvm return type is " << printType(llvm_return_type) << std::endl;
if (arrow_return_type_id == arrow::Type::DECIMAL) {
// For decimal fns, the output precision/scale are passed along as parameters.
//
Expand Down Expand Up @@ -1504,12 +1528,31 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
// add extra arg for return length for variable len return types (allocated on stack).
llvm::AllocaInst* result_len_ptr = nullptr;
if (arrow::is_binary_like(arrow_return_type_id)) {
std::cout << "LR LLVMGenerator::Visitor::BuildFunctionCall is binary like" << std::endl;
result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
"result_len", entry_block_);
params->push_back(result_len_ptr);
has_arena_allocs_ = true;
}

if (arrow_return_type_id == arrow::Type::LIST) {
std::cout << "LR LLVMGenerator::Visitor::BuildFunctionCall is list" << std::endl;
result_len_ptr = new llvm::AllocaInst(generator_->types()->i32_type(), 0,
"result_len", entry_block_);
params->push_back(result_len_ptr);
has_arena_allocs_ = true;


}

std::cout << "LR LLVMGenerator::Visitor::BuildFunctionCall params are: " << std::endl;
for (auto p : *params) {
std::string str1;
llvm::raw_string_ostream output1(str1);
p->print(output1);
std::cout << str1 << std::endl;
}

// Make the function call
llvm::IRBuilder<>* builder = ir_builder();
auto value =
Expand All @@ -1520,6 +1563,8 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func,
(result_len_ptr == nullptr)
? nullptr
: builder->CreateLoad(result_len_ptr->getAllocatedType(), result_len_ptr);
std::cout << "LR LLVMGenerator::Visitor::BuildFunctionCall is DONE" << std::endl;

return std::make_shared<LValue>(value, value_len);
}
}
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/gandiva/llvm_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) {
{arrow::Type::type::DECIMAL, i128_type()},
{arrow::Type::type::INTERVAL_MONTHS, i32_type()},
{arrow::Type::type::STRUCT, struct_type()},
{arrow::Type::type::INTERVAL_DAY_TIME, i64_type()}};
{arrow::Type::type::INTERVAL_DAY_TIME, i64_type()},
{arrow::Type::type::LIST, list_type()}};
}

} // namespace gandiva
Loading

0 comments on commit b52e5c2

Please sign in to comment.