From f3c3eca72c81bf00026409ef7582430be7629977 Mon Sep 17 00:00:00 2001 From: arnavb Date: Fri, 24 Jan 2025 07:06:32 +0000 Subject: [PATCH] update --- velox/core/PlanNode.cpp | 20 +++++++- velox/exec/NestedLoopJoinProbe.cpp | 66 ++++++++++++++++++++++++- velox/exec/NestedLoopJoinProbe.h | 11 +++++ velox/exec/tests/NestedLoopJoinTest.cpp | 32 ++++++++++++ velox/exec/tests/utils/PlanBuilder.cpp | 18 +++++-- 5 files changed, 141 insertions(+), 6 deletions(-) diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index a86b9343a182..667d3fb8143d 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -1424,7 +1424,24 @@ NestedLoopJoinNode::NestedLoopJoinNode( auto leftType = sources_[0]->outputType(); auto rightType = sources_[1]->outputType(); - for (const auto& name : outputType_->names()) { + + auto numOutputColumms = outputType_->size(); + if (core::isLeftSemiProjectJoin(joinType) || core::isRightSemiProjectJoin(joinType)) { + --numOutputColumms; + VELOX_CHECK_EQ(outputType_->childAt(numOutputColumms), BOOLEAN()); + const auto& name = outputType_->nameOf(numOutputColumms); + VELOX_CHECK( + !leftType->containsChild(name), + "Match column '{}' cannot be present in left source.", + name); + VELOX_CHECK( + !rightType->containsChild(name), + "Match column '{}' cannot be present in right source.", + name); + } + + for (auto i = 0; i < numOutputColumms; ++i) { + auto name = outputType_->nameOf(i); const bool leftContains = leftType->containsChild(name); const bool rightContains = rightType->containsChild(name); VELOX_USER_CHECK( @@ -1458,6 +1475,7 @@ bool NestedLoopJoinNode::isSupported(core::JoinType joinType) { case core::JoinType::kLeft: case core::JoinType::kRight: case core::JoinType::kFull: + case core::JoinType::kLeftSemiProject: return true; default: diff --git a/velox/exec/NestedLoopJoinProbe.cpp b/velox/exec/NestedLoopJoinProbe.cpp index f84d0484d1fb..995cfefd23dc 100644 --- a/velox/exec/NestedLoopJoinProbe.cpp +++ 
b/velox/exec/NestedLoopJoinProbe.cpp @@ -17,6 +17,7 @@ #include "velox/exec/OperatorUtils.h" #include "velox/exec/Task.h" #include "velox/expression/FieldReference.h" +#include namespace facebook::velox::exec { namespace { @@ -325,6 +326,37 @@ bool NestedLoopJoinProbe::addToOutput() { evaluateJoinFilter(currentBuild); } + /** + * Implements a Left Semi Project Join within NestedLoopJoinProbe. + * The getOutputLeftSemiJoinImpl() will ensure that exactly one row is + * produced for each probe row, along with a boolean "match" column + * which will indicate whether a matching build row exists on the + * build side. + * + * 1. At this point, the filter expressions are applied and we short + * circuit the execution for a LeftSemiProject since we don't require + * mismatch rows or build side projections. For each probe row, we simply + * iterate through decoded filter results to determine if at least + * one build side row satisfies the filter condition. + * 2. If a match is found, the match column is marked as `true`, and + * defaulted to false otherwise. Finally populates the output row with + * the probe row data. + * 3. The function ensures that only one row is produced in the output, + * indicating build side match. After processing the current probe row, + * it skips the rest of the build rows. + * + * Returns a `RowVectorPtr` representing the output row. For left semi project + * this basically contains probe row data with the match column. + * + */ + if (isLeftSemiProjectJoin(joinType_)) { + output_ = getOutputLeftSemiJoinImpl(); + numOutputRows_ = 1; + ++buildIndex_; + buildRow_ = 0; + return false; + } + // Iterate over the filter results. For each match, add an output record. 
for (size_t i = buildRow_; i < decodedFilterResult_.size(); ++i) { if (isJoinConditionMatch(i)) { @@ -414,6 +446,7 @@ void NestedLoopJoinProbe::evaluateJoinFilter(const RowVectorPtr& buildVector) { operatorCtx_->execCtx(), joinCondition_.get(), filterInput.get()); joinCondition_->eval(0, 1, true, filterInputRows_, evalCtx, filterResult); filterOutput_ = filterResult[0]; + decodedFilterResult_.decode(*filterOutput_, filterInputRows_); } @@ -684,4 +717,35 @@ RowVectorPtr NestedLoopJoinProbe::getBuildMismatchedOutput( pool(), outputType_, nullptr, numUnmatched, std::move(projectedChildren)); } -} // namespace facebook::velox::exec +RowVectorPtr NestedLoopJoinProbe::getOutputLeftSemiJoinImpl() { + VELOX_CHECK_NOT_NULL(input_); + + bool matched = false; + numOutputRows_ = 0; + for (auto i = buildRow_; i < decodedFilterResult_.size(); ++i) { + if (isJoinConditionMatch(i)) { + matched = true; + break; + } + } + auto matchVector = BaseVector::create(BOOLEAN(), /*size=*/1, pool()); + auto flatMatch = matchVector->as>(); + flatMatch->set(0, /*matched=*/matched); + + std::vector outputChildren(outputType_->size()); + for (auto& projection : identityProjections_) { + auto indices = allocateIndices(/*size=*/1, pool()); + indices->asMutable()[0] = probeRow_; + outputChildren[projection.outputChannel] = BaseVector::wrapInDictionary( + nullptr, indices, 1, input_->childAt(projection.inputChannel)); + } + + int matchChannel = outputType_->size() - 1; + outputChildren[matchChannel] = matchVector; + + auto singleRow = + std::make_shared(pool(), outputType_, nullptr, 1, outputChildren); + + return singleRow; +} +} diff --git a/velox/exec/NestedLoopJoinProbe.h b/velox/exec/NestedLoopJoinProbe.h index b59457b118aa..419440921834 100644 --- a/velox/exec/NestedLoopJoinProbe.h +++ b/velox/exec/NestedLoopJoinProbe.h @@ -365,6 +365,17 @@ class NestedLoopJoinProbe : public Operator { std::vector filterBuildProjections_; BufferPtr buildOutMapping_; + + // Returns the 'match' column in the 
output for semi project joins. + VectorPtr& matchColumn() const { + VELOX_DCHECK( + isRightSemiProjectJoin(joinType_) || isLeftSemiProjectJoin(joinType_)); + return output_->children().back(); + } + + bool isLeftSemiJoinProject(core::JoinType joinType); + facebook::velox::RowVectorPtr getOutputLeftSemiJoinImpl(); + }; } // namespace facebook::velox::exec diff --git a/velox/exec/tests/NestedLoopJoinTest.cpp b/velox/exec/tests/NestedLoopJoinTest.cpp index a2dd8d591135..c4c005a61835 100644 --- a/velox/exec/tests/NestedLoopJoinTest.cpp +++ b/velox/exec/tests/NestedLoopJoinTest.cpp @@ -622,5 +622,37 @@ TEST_F(NestedLoopJoinTest, mergeBuildVectors) { ASSERT_TRUE(waitForTaskCompletion(cursor->task().get())); } +TEST_F(NestedLoopJoinTest, leftSemiJoinProjectDataValidation) { + auto probeVectors = makeRowVector( + {"t0"}, + {sequence(5)}); + + auto buildVectors = makeRowVector( + {"u0"}, + {sequence(3, 2)}); + + auto expected = makeRowVector( + {"t0", "match"}, + {makeFlatVector({0, 1, 2, 3, 4}), + makeFlatVector({false, false, true, true, true})}); + + auto planNodeIdGenerator = std::make_shared(); + auto op = PlanBuilder(planNodeIdGenerator) + .values({probeVectors}) + .nestedLoopJoin( + PlanBuilder(planNodeIdGenerator) + .values({buildVectors}) + .planNode(), + "t0 = u0", + {"t0", "match"}, + core::JoinType::kLeftSemiProject) + .planNode(); + + AssertQueryBuilder builder{op}; + auto result = builder.copyResults(pool()); + + assertEqualVectors(expected, result); +} + } // namespace } // namespace facebook::velox::exec::test diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index cd419343d65c..c6b77f46f5d4 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -440,9 +440,10 @@ PlanBuilder& PlanBuilder::optionalFilter(const std::string& optionalFilter) { PlanBuilder& PlanBuilder::filter(const std::string& filter) { VELOX_CHECK_NOT_NULL(planNode_, "Filter cannot be the source node"); - 
auto expr = parseExpr(filter, planNode_->outputType(), options_, pool_); - planNode_ = - std::make_shared(nextPlanNodeId(), expr, planNode_); + planNode_ = std::make_shared( + nextPlanNodeId(), + parseExpr(filter, planNode_->outputType(), options_, pool_), + planNode_); return *this; } @@ -1554,7 +1555,13 @@ PlanBuilder& PlanBuilder::nestedLoopJoin( const std::vector& outputLayout, core::JoinType joinType) { VELOX_CHECK_NOT_NULL(planNode_, "NestedLoopJoin cannot be the source node"); + auto resultType = concat(planNode_->outputType(), right->outputType()); + + if (isLeftSemiProjectJoin(joinType) || isRightSemiProjectJoin(joinType)) { + resultType = concat(resultType, ROW({"match"}, {BOOLEAN()})); + } + auto outputType = extract(resultType, outputLayout); core::TypedExprPtr joinConditionExpr{}; @@ -1562,13 +1569,16 @@ PlanBuilder& PlanBuilder::nestedLoopJoin( joinConditionExpr = parseExpr(joinCondition, resultType, options_, pool_); } + RowTypePtr finalOutputType; + finalOutputType = outputType; + planNode_ = std::make_shared( nextPlanNodeId(), joinType, std::move(joinConditionExpr), std::move(planNode_), right, - outputType); + finalOutputType); return *this; }