From d560915e265ebe05b6e7c2eb37400734a89e2323 Mon Sep 17 00:00:00 2001 From: Yenda Li Date: Tue, 4 Feb 2025 07:55:29 -0800 Subject: [PATCH] feat: Support cancellation in HashProbe::getOutput [3/n] (#12237) Summary: HashProbe::getOutput can take a long time to run for degenerate query shapes. Check for task cancellation at the top of each output-loop iteration and return early so the Driver can exit gracefully. Reviewed By: Yuhta Differential Revision: D68997164 --- velox/exec/HashProbe.cpp | 7 +++++++ velox/exec/tests/HashJoinTest.cpp | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp index c8454da3a06f..26cd9f7e13c0 100644 --- a/velox/exec/HashProbe.cpp +++ b/velox/exec/HashProbe.cpp @@ -1078,6 +1078,13 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) { initBuffer(outputTableRows_, outputTableRowsCapacity_, pool()); for (;;) { + // If the task owning this operator has been cancelled, there is no point + // to continue executing this procedure, which may be long in degenerate + // cases. Exit the working loop and let the Driver handle exiting gracefully + // in its own loop. 
+ if (operatorCtx_->task()->isCancelled()) { + return nullptr; + } int numOut = 0; if (emptyBuildSide) { diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp index 650ef294bae2..c7b88c4a8c8a 100644 --- a/velox/exec/tests/HashJoinTest.cpp +++ b/velox/exec/tests/HashJoinTest.cpp @@ -401,6 +401,11 @@ class HashJoinBuilder { return *this; } + HashJoinBuilder& injectTaskCancellation(bool injectTaskCancellation) { + injectTaskCancellation_ = injectTaskCancellation; + return *this; + } + HashJoinBuilder& maxSpillLevel(int32_t maxSpillLevel) { maxSpillLevel_ = maxSpillLevel; return *this; @@ -665,6 +670,11 @@ class HashJoinBuilder { memory::spillMemoryPool()->stats().peakBytes; TestScopedSpillInjection scopedSpillInjection(spillPct); auto task = builder.assertResults(referenceQuery_); + + if (injectTaskCancellation_) { + task->requestCancel(); + } + // Wait up to 5 seconds for all the task background activities to complete. // Then we can collect the stats from all the operators. // @@ -758,6 +768,7 @@ class HashJoinBuilder { std::optional runParallelProbe_; std::optional runParallelBuild_; + bool injectTaskCancellation_{false}; bool injectSpill_{true}; // If not set, then the test will run the test with different settings: // 0, 2. 
@@ -1020,6 +1031,22 @@ TEST_P(MultiThreadedHashJoinTest, outOfJoinKeyColumnOrder) { .run(); } +TEST_P(MultiThreadedHashJoinTest, joinWithCancellation) { + HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get()) + .numDrivers(numDrivers_) + .keyTypes({BIGINT()}) + .probeVectors(1600, 5) + .buildVectors(1500, 5) + .injectTaskCancellation(true) + .referenceQuery( + "SELECT t_k0, t_data, u_k0, u_data FROM t, u WHERE t.t_k0 = u.u_k0") + .verifier([&](const std::shared_ptr& task, bool /*unused*/) { + auto stats = task->taskStats(); + EXPECT_GT(stats.terminationTimeMs, 0); + }) + .run(); +} + TEST_P(MultiThreadedHashJoinTest, emptyBuild) { const std::vector finishOnEmptys = {false, true}; for (const auto finishOnEmpty : finishOnEmptys) {