Skip to content

Commit

Permalink
Deduplicate code for getting cancellation checks out of a search cont…
Browse files Browse the repository at this point in the history
…ext (#120828) (#120839)

Just deduplicating the logic and moving it to a shared location + no need for
a static method like that.
  • Loading branch information
original-brownbear authored Jan 24, 2025
1 parent 47be8e5 commit 7243adc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,10 @@
*/
package org.elasticsearch.search.aggregations;

import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.search.aggregations.support.TimeSeriesIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QueryPhase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

Expand Down Expand Up @@ -59,7 +56,7 @@ private static AggregatorCollector newAggregatorCollector(SearchContext context)
}

private static void executeInSortOrder(SearchContext context, BucketCollector collector) {
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), getCancellationChecks(context));
TimeSeriesIndexSearcher searcher = new TimeSeriesIndexSearcher(context.searcher(), context.getCancellationChecks());
searcher.setMinimumScore(context.minimumScore());
searcher.setProfiler(context);
try {
Expand All @@ -70,23 +67,4 @@ private static void executeInSortOrder(SearchContext context, BucketCollector co
}
}

private static List<Runnable> getCancellationChecks(SearchContext context) {
List<Runnable> cancellationChecks = new ArrayList<>();
if (context.lowLevelCancellation()) {
// This searching doesn't live beyond this phase, so we don't need to remove query cancellation
cancellationChecks.add(() -> {
final SearchShardTask task = context.getTask();
if (task != null) {
task.ensureNotCancelled();
}
});
}

final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(context);
if (timeoutRunnable != null) {
cancellationChecks.add(timeoutRunnable);
}

return cancellationChecks;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.elasticsearch.search.fetch.subphase.ScriptFieldsContext;
import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.query.QueryPhase;
import org.elasticsearch.search.query.QuerySearchResult;
import org.elasticsearch.search.rank.context.QueryPhaseRankShardContext;
import org.elasticsearch.search.rank.feature.RankFeatureResult;
Expand Down Expand Up @@ -84,6 +85,21 @@ public abstract class SearchContext implements Releasable {

protected SearchContext() {}

public final List<Runnable> getCancellationChecks() {
final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(this);
if (lowLevelCancellation()) {
// This searching doesn't live beyond this phase, so we don't need to remove query cancellation
Runnable c = () -> {
final SearchShardTask task = getTask();
if (task != null) {
task.ensureNotCancelled();
}
};
return timeoutRunnable == null ? List.of(c) : List.of(c, timeoutRunnable);
}
return timeoutRunnable == null ? List.of() : List.of(timeoutRunnable);
}

public abstract void setTask(SearchShardTask task);

public abstract SearchShardTask getTask();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.lucene.grouping.TopFieldGroups;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QueryPhase;
import org.elasticsearch.search.query.SearchTimeoutException;
import org.elasticsearch.search.sort.ShardDocSortField;
import org.elasticsearch.search.sort.SortAndFormats;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -195,21 +192,7 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) {
}

static Runnable getCancellationChecks(SearchContext context) {
List<Runnable> cancellationChecks = new ArrayList<>();
if (context.lowLevelCancellation()) {
cancellationChecks.add(() -> {
final SearchShardTask task = context.getTask();
if (task != null) {
task.ensureNotCancelled();
}
});
}

final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(context);
if (timeoutRunnable != null) {
cancellationChecks.add(timeoutRunnable);
}

List<Runnable> cancellationChecks = context.getCancellationChecks();
return () -> {
for (var check : cancellationChecks) {
check.run();
Expand Down

0 comments on commit 7243adc

Please sign in to comment.