diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java index 8ca21db1ad9f2..4e1ec3faf6b36 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationPhase.java @@ -8,13 +8,10 @@ */ package org.elasticsearch.search.aggregations; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.search.aggregations.support.TimeSeriesIndexSearcher; import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.query.QueryPhase; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; @@ -59,7 +56,7 @@ private static AggregatorCollector newAggregatorCollector(SearchContext context) } private static void executeInSortOrder(SearchContext context, BucketCollector collector) { - TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context)); + TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), context.getCancellationChecks()); searcher.setMinimumScore(context.minimumScore()); searcher.setProfiler(context); try { @@ -70,23 +67,4 @@ private static void executeInSortOrder(SearchContext context, BucketCollector co } } - private static List getCancellationChecks(SearchContext context) { - List cancellationChecks = new ArrayList<>(); - if (context.lowLevelCancellation()) { - // This searching doesn't live beyond this phase, so we don't need to remove query cancellation - cancellationChecks.add(() -> { - final SearchShardTask task = context.getTask(); - if (task != null) { - task.ensureNotCancelled(); - } - }); - } - - final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(context); - if (timeoutRunnable != null) { - cancellationChecks.add(timeoutRunnable); - } - - return cancellationChecks; - } } diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index f57b247219ae6..45ac9fa06f399 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -41,6 +41,7 @@ import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext; import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext; import org.elasticsearch.search.profile.Profilers; +import org.elasticsearch.search.query.QueryPhase; import org.elasticsearch.search.query.QuerySearchResult; import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext; import org.elasticsearch.search.rank.feature.RankFeatureResult; @@ -84,6 +85,21 @@ public abstract class SearchContext implements Releasable { protected SearchContext() {} + public final List getCancellationChecks() { + final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(this); + if (lowLevelCancellation()) { + // This searching doesn't live beyond this phase, so we don't need to remove query cancellation + Runnable c = () -> { + final SearchShardTask task = getTask(); + if (task != null) { + task.ensureNotCancelled(); + } + }; + return timeoutRunnable == null ? List.of(c) : List.of(c, timeoutRunnable); + } + return timeoutRunnable == null ? List.of() : List.of(timeoutRunnable); + } + public abstract void setTask(SearchShardTask task); public abstract SearchShardTask getTask(); diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java index c23df9cdfa441..f8b348b383f01 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java @@ -15,19 +15,16 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopFieldDocs; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.Maps; import org.elasticsearch.lucene.grouping.TopFieldGroups; import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.query.QueryPhase; import org.elasticsearch.search.query.SearchTimeoutException; import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortAndFormats; import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -195,21 +192,7 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) { } static Runnable getCancellationChecks(SearchContext context) { - List cancellationChecks = new ArrayList<>(); - if (context.lowLevelCancellation()) { - cancellationChecks.add(() -> { - final SearchShardTask task = context.getTask(); - if (task != null) { - task.ensureNotCancelled(); - } - }); - } - - final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(context); - if (timeoutRunnable != null) { - cancellationChecks.add(timeoutRunnable); - } - + List cancellationChecks = context.getCancellationChecks(); return () -> { for (var check : cancellationChecks) { check.run();