Skip to content

Commit

Permalink
DX-56436: fix substring_index function (#43)
Browse files Browse the repository at this point in the history
Co-authored-by: Projjal Chanda <[email protected]>
  • Loading branch information
xxlaykxx and projjal authored Jul 31, 2023
1 parent 76d2cc8 commit e4c1c2c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 8 deletions.
10 changes: 7 additions & 3 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,15 +464,15 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "Abc.DE.fGh", 10, ".", 1, -2, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "fGh");
EXPECT_EQ(std::string(out_str, out_len), "DE.fGh");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, 1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "S");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, -1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "DCGS;JO!L");
EXPECT_EQ(std::string(out_str, out_len), "JO!L");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "www.mysql.com", 13, "Q", 1, 1, &out_len);
Expand All @@ -496,7 +496,7 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, -2, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "com");
EXPECT_EQ(std::string(out_str, out_len), "mysql||com");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "MÜNCHEN", 8, "Ü", 2, 1, &out_len);
Expand All @@ -507,6 +507,10 @@ TEST(TestGdvFnStubs, TestSubstringIndex) {
EXPECT_EQ(std::string(out_str, out_len), "NCHEN");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "MÜëCHEN", 9, "Ü", 2, -1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "ëCHEN");
EXPECT_FALSE(ctx.has_error());

out_str = gdv_fn_substring_index(ctx_ptr, "citroën", 8, "ë", 2, -1, &out_len);
EXPECT_EQ(std::string(out_str, out_len), "n");
EXPECT_FALSE(ctx.has_error());
Expand Down
7 changes: 5 additions & 2 deletions cpp/src/gandiva/gdv_string_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,13 @@ const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt
return out;
} else if (static_cast<int32_t>(abs(cnt)) <= static_cast<int32_t>(occ.size()) &&
cnt < 0) {
int32_t sz = static_cast<int32_t>(occ.size());
int32_t temp = static_cast<int32_t>(abs(cnt));
memcpy(out, txt + occ[temp - 1] + pat_len, txt_len - occ[temp - 1] - pat_len);
*out_len = txt_len - occ[temp - 1] - pat_len;

memcpy(out, txt + occ[sz - temp] + pat_len, txt_len - occ[sz - temp] - pat_len);
*out_len = txt_len - occ[sz - temp] - pat_len;
return out;

} else {
*out_len = txt_len;
memcpy(out, txt, txt_len);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -681,14 +681,14 @@ TEST(TestArithmeticOps, TestCeilingFloatDouble) {
}

TEST(TestArithmeticOps, TestFloorFloatDouble) {
// ceiling from floats
// floor from floats
EXPECT_EQ(floor_float32(6.6f), 6.0f);
EXPECT_EQ(floor_float32(-6.6f), -7.0f);
EXPECT_EQ(floor_float32(-6.3f), -7.0f);
EXPECT_EQ(floor_float32(0.0f), 0.0f);
EXPECT_EQ(floor_float32(-0), 0.0);

// ceiling from doubles
// floor from doubles
EXPECT_EQ(floor_float64(6.6), 6.0);
EXPECT_EQ(floor_float64(-6.6), -7.0);
EXPECT_EQ(floor_float64(-6.3), -7.0);
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3194,7 +3194,8 @@ TEST_F(TestProjector, TestSubstringIndex) {

auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array1, array2, array3});

auto out_1 = MakeArrowArrayUtf8({"www||mysql", "com", "DCGS;JO!L"}, {true, true, true});
auto out_1 =
MakeArrowArrayUtf8({"www||mysql", "mysql||com", "JO!L"}, {true, true, true});

arrow::ArrayVector outputs;

Expand Down

0 comments on commit e4c1c2c

Please sign in to comment.