diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java new file mode 100644 index 0000000000000..1f2b8faf83ee3 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ClusterComputeHandler.java @@ -0,0 +1,230 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.compute.EsqlRefCountingListener; +import org.elasticsearch.compute.operator.exchange.ExchangeService; +import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.RemoteClusterService; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportChannel; +import org.elasticsearch.transport.TransportRequestHandler; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; +import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; + +/** + * Manages computes across multiple clusters by sending {@link ClusterComputeRequest} to remote clusters and executing the computes. + * This handler delegates the execution of computes on data nodes within each remote cluster to {@link DataNodeComputeHandler}. 
+ */ +final class ClusterComputeHandler implements TransportRequestHandler { + private final ComputeService computeService; + private final ExchangeService exchangeService; + private final TransportService transportService; + private final Executor esqlExecutor; + private final DataNodeComputeHandler dataNodeComputeHandler; + + ClusterComputeHandler( + ComputeService computeService, + ExchangeService exchangeService, + TransportService transportService, + Executor esqlExecutor, + DataNodeComputeHandler dataNodeComputeHandler + ) { + this.computeService = computeService; + this.exchangeService = exchangeService; + this.esqlExecutor = esqlExecutor; + this.transportService = transportService; + this.dataNodeComputeHandler = dataNodeComputeHandler; + transportService.registerRequestHandler(ComputeService.CLUSTER_ACTION_NAME, esqlExecutor, ClusterComputeRequest::new, this); + } + + void startComputeOnRemoteClusters( + String sessionId, + CancellableTask rootTask, + Configuration configuration, + PhysicalPlan plan, + ExchangeSourceHandler exchangeSource, + List clusters, + ComputeListener computeListener + ) { + var queryPragmas = configuration.pragmas(); + var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(linkExchangeListeners)) { + for (RemoteCluster cluster : clusters) { + final var childSessionId = computeService.newChildSession(sessionId); + ExchangeService.openExchange( + transportService, + cluster.connection, + childSessionId, + queryPragmas.exchangeBufferSize(), + esqlExecutor, + refs.acquire().delegateFailureAndWrap((l, unused) -> { + var remoteSink = exchangeService.newRemoteSink(rootTask, childSessionId, transportService, cluster.connection); + exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); + var remotePlan = new RemoteClusterPlan(plan, cluster.concreteIndices, cluster.originalIndices); + var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, childSessionId, configuration, remotePlan); + var clusterListener = ActionListener.runBefore( + computeListener.acquireCompute(cluster.clusterAlias()), + () -> l.onResponse(null) + ); + transportService.sendChildRequest( + cluster.connection, + ComputeService.CLUSTER_ACTION_NAME, + clusterRequest, + rootTask, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(clusterListener, ComputeResponse::new, esqlExecutor) + ); + }) + ); + } + } + } + + List getRemoteClusters( + Map clusterToConcreteIndices, + Map clusterToOriginalIndices + ) { + List remoteClusters = new ArrayList<>(clusterToConcreteIndices.size()); + RemoteClusterService remoteClusterService = transportService.getRemoteClusterService(); + for (Map.Entry e : clusterToConcreteIndices.entrySet()) { + String clusterAlias = e.getKey(); + OriginalIndices concreteIndices = clusterToConcreteIndices.get(clusterAlias); + OriginalIndices originalIndices = clusterToOriginalIndices.get(clusterAlias); + if (originalIndices == null) { + assert false : "can't find original indices for cluster " + clusterAlias; + throw new IllegalStateException("can't find original indices for cluster " + clusterAlias); + } + if (concreteIndices.indices().length > 0) { + Transport.Connection connection = remoteClusterService.getConnection(clusterAlias); + remoteClusters.add(new RemoteCluster(clusterAlias, connection, concreteIndices.indices(), originalIndices)); + } + } + return 
remoteClusters; + } + + record RemoteCluster(String clusterAlias, Transport.Connection connection, String[] concreteIndices, OriginalIndices originalIndices) { + + } + + @Override + public void messageReceived(ClusterComputeRequest request, TransportChannel channel, Task task) { + ChannelActionListener listener = new ChannelActionListener<>(channel); + RemoteClusterPlan remoteClusterPlan = request.remoteClusterPlan(); + var plan = remoteClusterPlan.plan(); + if (plan instanceof ExchangeSinkExec == false) { + listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + plan)); + return; + } + String clusterAlias = request.clusterAlias(); + /* + * This handler runs only on remote cluster coordinators, so it creates a new local EsqlExecutionInfo object to record + * execution metadata for ES|QL processing local to this cluster. The execution info will be copied into the + * ComputeResponse that is sent back to the primary coordinating cluster. + */ + EsqlExecutionInfo execInfo = new EsqlExecutionInfo(true); + execInfo.swapCluster(clusterAlias, (k, v) -> new EsqlExecutionInfo.Cluster(clusterAlias, Arrays.toString(request.indices()))); + CancellableTask cancellable = (CancellableTask) task; + try (var computeListener = ComputeListener.create(clusterAlias, transportService, cancellable, execInfo, listener)) { + runComputeOnRemoteCluster( + clusterAlias, + request.sessionId(), + (CancellableTask) task, + request.configuration(), + (ExchangeSinkExec) plan, + Set.of(remoteClusterPlan.targetIndices()), + remoteClusterPlan.originalIndices(), + execInfo, + computeListener + ); + } + } + + /** + * Performs a compute on a remote cluster. The output pages are placed in an exchange sink specified by + * {@code globalSessionId}. The coordinator on the main cluster will poll pages from there. + *
<p>
+ * Currently, the coordinator on the remote cluster polls pages from data nodes within the remote cluster + * and performs cluster-level reduction before sending pages to the querying cluster. This reduction aims + * to minimize data transfers across clusters but may require additional CPU resources for operations like + * aggregations. + */ + void runComputeOnRemoteCluster( + String clusterAlias, + String globalSessionId, + CancellableTask parentTask, + Configuration configuration, + ExchangeSinkExec plan, + Set concreteIndices, + OriginalIndices originalIndices, + EsqlExecutionInfo executionInfo, + ComputeListener computeListener + ) { + final var exchangeSink = exchangeService.getSinkHandler(globalSessionId); + parentTask.addListener( + () -> exchangeService.finishSinkHandler(globalSessionId, new TaskCancelledException(parentTask.getReasonCancelled())) + ); + final String localSessionId = clusterAlias + ":" + globalSessionId; + final PhysicalPlan coordinatorPlan = ComputeService.reductionPlan(plan, true); + var exchangeSource = new ExchangeSourceHandler( + configuration.pragmas().exchangeBufferSize(), + transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), + computeListener.acquireAvoid() + ); + try (Releasable ignored = exchangeSource.addEmptySink()) { + exchangeSink.addCompletionListener(computeListener.acquireAvoid()); + computeService.runCompute( + parentTask, + new ComputeContext( + localSessionId, + clusterAlias, + List.of(), + configuration, + configuration.newFoldContext(), + exchangeSource, + exchangeSink + ), + coordinatorPlan, + computeListener.acquireCompute(clusterAlias) + ); + dataNodeComputeHandler.startComputeOnDataNodes( + localSessionId, + clusterAlias, + parentTask, + configuration, + plan, + concreteIndices, + originalIndices, + exchangeSource, + executionInfo, + computeListener + ); + } + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeContext.java new file mode 100644 index 0000000000000..4e178bb740757 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeContext.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; +import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.util.List; + +record ComputeContext( + String sessionId, + String clusterAlias, + List searchContexts, + Configuration configuration, + FoldContext foldCtx, + ExchangeSourceHandler exchangeSource, + ExchangeSinkHandler exchangeSink +) { + List searchExecutionContexts() { + return searchContexts.stream().map(SearchContext::getSearchExecutionContext).toList(); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index a38236fe60954..2cb4b49ec3591 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -8,58 +8,31 @@ package org.elasticsearch.xpack.esql.plugin; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionListenerResponseHandler; -import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchShardsGroup; -import org.elasticsearch.action.search.SearchShardsRequest; -import org.elasticsearch.action.search.SearchShardsResponse; -import org.elasticsearch.action.support.ChannelActionListener; -import org.elasticsearch.action.support.RefCountingRunnable; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverTaskRunner; import org.elasticsearch.compute.operator.exchange.ExchangeService; -import org.elasticsearch.compute.operator.exchange.ExchangeSink; -import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; -import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.core.Tuple; -import org.elasticsearch.index.Index; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.index.shard.IndexShard; -import org.elasticsearch.index.shard.ShardId; -import org.elasticsearch.index.shard.ShardNotFoundException; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.search.SearchService; -import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.lookup.SourceProvider; import org.elasticsearch.tasks.CancellableTask; -import org.elasticsearch.tasks.Task; -import 
org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.RemoteClusterAware; -import org.elasticsearch.transport.RemoteClusterService; -import org.elasticsearch.transport.Transport; -import org.elasticsearch.transport.TransportChannel; -import org.elasticsearch.transport.TransportRequestHandler; -import org.elasticsearch.transport.TransportRequestOptions; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; import org.elasticsearch.xpack.esql.action.EsqlQueryAction; -import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; @@ -75,14 +48,10 @@ import org.elasticsearch.xpack.esql.session.Result; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; @@ -92,20 +61,24 @@ * Computes the result of a {@link PhysicalPlan}. */ public class ComputeService { + public static final String DATA_ACTION_NAME = EsqlQueryAction.NAME + "/data"; + public static final String CLUSTER_ACTION_NAME = EsqlQueryAction.NAME + "/cluster"; + private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); private final SearchService searchService; private final BigArrays bigArrays; private final BlockFactory blockFactory; private final TransportService transportService; - private final Executor esqlExecutor; private final DriverTaskRunner driverRunner; - private final ExchangeService exchangeService; private final EnrichLookupService enrichLookupService; private final LookupFromIndexService lookupFromIndexService; private final ClusterService clusterService; private final AtomicLong childSessionIdGenerator = new AtomicLong(); + private final DataNodeComputeHandler dataNodeComputeHandler; + private final ClusterComputeHandler clusterComputeHandler; + @SuppressWarnings("this-escape") public ComputeService( SearchService searchService, TransportService transportService, @@ -121,19 +94,19 @@ public ComputeService( this.transportService = transportService; this.bigArrays = bigArrays.withCircuitBreaking(); this.blockFactory = blockFactory; - this.esqlExecutor = threadPool.executor(ThreadPool.Names.SEARCH); - transportService.registerRequestHandler(DATA_ACTION_NAME, this.esqlExecutor, DataNodeRequest::new, new DataNodeRequestHandler()); - transportService.registerRequestHandler( - CLUSTER_ACTION_NAME, - this.esqlExecutor, - ClusterComputeRequest::new, - new ClusterRequestHandler() - ); - this.driverRunner = new DriverTaskRunner(transportService, this.esqlExecutor); - this.exchangeService = exchangeService; + var esqlExecutor = threadPool.executor(ThreadPool.Names.SEARCH); + this.driverRunner = new DriverTaskRunner(transportService, esqlExecutor); this.enrichLookupService = enrichLookupService; this.lookupFromIndexService = lookupFromIndexService; this.clusterService = clusterService; + this.dataNodeComputeHandler = new DataNodeComputeHandler(this, searchService, transportService, exchangeService, esqlExecutor); + this.clusterComputeHandler = new ClusterComputeHandler( + this, + 
exchangeService, + transportService, + esqlExecutor, + dataNodeComputeHandler + ); } public void execute( @@ -238,7 +211,7 @@ public void execute( ); // starts computes on data nodes on the main cluster if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) { - startComputeOnDataNodes( + dataNodeComputeHandler.startComputeOnDataNodes( sessionId, RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, rootTask, @@ -252,13 +225,14 @@ public void execute( ); } // starts computes on remote clusters - startComputeOnRemoteClusters( + final var remoteClusters = clusterComputeHandler.getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices); + clusterComputeHandler.startComputeOnRemoteClusters( sessionId, rootTask, configuration, dataNodePlan, exchangeSource, - getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices), + remoteClusters, computeListener ); } @@ -298,145 +272,11 @@ private static void updateExecutionInfoAfterCoordinatorOnlyQuery(EsqlExecutionIn } } - private List getRemoteClusters( - Map clusterToConcreteIndices, - Map clusterToOriginalIndices - ) { - List remoteClusters = new ArrayList<>(clusterToConcreteIndices.size()); - RemoteClusterService remoteClusterService = transportService.getRemoteClusterService(); - for (Map.Entry e : clusterToConcreteIndices.entrySet()) { - String clusterAlias = e.getKey(); - OriginalIndices concreteIndices = clusterToConcreteIndices.get(clusterAlias); - OriginalIndices originalIndices = clusterToOriginalIndices.get(clusterAlias); - if (originalIndices == null) { - assert false : "can't find original indices for cluster " + clusterAlias; - throw new IllegalStateException("can't find original indices for cluster " + clusterAlias); - } - if (concreteIndices.indices().length > 0) { - Transport.Connection connection = remoteClusterService.getConnection(clusterAlias); - remoteClusters.add(new RemoteCluster(clusterAlias, connection, concreteIndices.indices(), originalIndices)); - } - } - return remoteClusters; - } - - private void startComputeOnDataNodes( - String sessionId, - String clusterAlias, - CancellableTask parentTask, - Configuration configuration, - PhysicalPlan dataNodePlan, - Set concreteIndices, - OriginalIndices originalIndices, - ExchangeSourceHandler exchangeSource, - EsqlExecutionInfo executionInfo, - ComputeListener computeListener - ) { - QueryBuilder requestFilter = PlannerUtils.requestTimestampFilter(dataNodePlan); - var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); - // SearchShards API can_match is done in lookupDataNodes - lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> { - try (EsqlRefCountingListener refs = new EsqlRefCountingListener(lookupListener)) { - // update ExecutionInfo with shard counts (total and skipped) - executionInfo.swapCluster( - clusterAlias, - (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(dataNodeResult.totalShards()) - // do not set successful or failed shard count here - do it when search is done - .setSkippedShards(dataNodeResult.skippedShards()) - .build() - ); - - // For each target node, first open a remote exchange on the remote node, then link the exchange source to - // the new remote exchange sink, and initialize the computation on the target node via data-node-request. 
- for (DataNode node : dataNodeResult.dataNodes()) { - var queryPragmas = configuration.pragmas(); - var childSessionId = newChildSession(sessionId); - ExchangeService.openExchange( - transportService, - node.connection, - childSessionId, - queryPragmas.exchangeBufferSize(), - esqlExecutor, - refs.acquire().delegateFailureAndWrap((l, unused) -> { - var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, node.connection); - exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); - ActionListener computeResponseListener = computeListener.acquireCompute(clusterAlias); - var dataNodeListener = ActionListener.runBefore(computeResponseListener, () -> l.onResponse(null)); - final boolean sameNode = transportService.getLocalNode().getId().equals(node.connection.getNode().getId()); - var dataNodeRequest = new DataNodeRequest( - childSessionId, - configuration, - clusterAlias, - node.shardIds, - node.aliasFilters, - dataNodePlan, - originalIndices.indices(), - originalIndices.indicesOptions(), - sameNode == false && queryPragmas.nodeLevelReduction() - ); - transportService.sendChildRequest( - node.connection, - DATA_ACTION_NAME, - dataNodeRequest, - parentTask, - TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(dataNodeListener, ComputeResponse::new, esqlExecutor) - ); - }) - ); - } - } - }, lookupListener::onFailure)); - } - - private void startComputeOnRemoteClusters( - String sessionId, - CancellableTask rootTask, - Configuration configuration, - PhysicalPlan plan, - ExchangeSourceHandler exchangeSource, - List clusters, - ComputeListener computeListener - ) { - var queryPragmas = configuration.pragmas(); - var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); - try (EsqlRefCountingListener refs = new EsqlRefCountingListener(linkExchangeListeners)) { - for (RemoteCluster cluster : clusters) { - final var childSessionId = newChildSession(sessionId); - ExchangeService.openExchange( - transportService, - cluster.connection, - childSessionId, - queryPragmas.exchangeBufferSize(), - esqlExecutor, - refs.acquire().delegateFailureAndWrap((l, unused) -> { - var remoteSink = exchangeService.newRemoteSink(rootTask, childSessionId, transportService, cluster.connection); - exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); - var remotePlan = new RemoteClusterPlan(plan, cluster.concreteIndices, cluster.originalIndices); - var clusterRequest = new ClusterComputeRequest(cluster.clusterAlias, childSessionId, configuration, remotePlan); - var clusterListener = ActionListener.runBefore( - computeListener.acquireCompute(cluster.clusterAlias()), - () -> l.onResponse(null) - ); - transportService.sendChildRequest( - cluster.connection, - CLUSTER_ACTION_NAME, - clusterRequest, - rootTask, - TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(clusterListener, ComputeResponse::new, esqlExecutor) - ); - }) - ); - } - } - } - void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener listener) { - listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts)); - List contexts = new ArrayList<>(context.searchContexts.size()); - for (int i = 0; i < context.searchContexts.size(); i++) { - SearchContext searchContext = context.searchContexts.get(i); + listener = ActionListener.runBefore(listener, () -> 
Releasables.close(context.searchContexts())); + List contexts = new ArrayList<>(context.searchContexts().size()); + for (int i = 0; i < context.searchContexts().size(); i++) { + SearchContext searchContext = context.searchContexts().get(i); var searchExecutionContext = new SearchExecutionContext(searchContext.getSearchExecutionContext()) { @Override @@ -453,13 +293,13 @@ public SourceProvider createSourceProvider() { final List drivers; try { LocalExecutionPlanner planner = new LocalExecutionPlanner( - context.sessionId, - context.clusterAlias, + context.sessionId(), + context.clusterAlias(), task, bigArrays, blockFactory, clusterService.getSettings(), - context.configuration, + context.configuration(), context.exchangeSource(), context.exchangeSink(), enrichLookupService, @@ -469,7 +309,7 @@ public SourceProvider createSourceProvider() { LOGGER.debug("Received physical plan:\n{}", plan); - plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration, context.foldCtx(), plan); + plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration(), context.foldCtx(), plan); // the planner will also set the driver parallelism in LocalExecutionPlanner.LocalExecutionPlan (used down below) // it's doing this in the planning of EsQueryExec (the source of the data) // see also EsPhysicalOperationProviders.sourcePhysicalOperation @@ -477,7 +317,7 @@ public SourceProvider createSourceProvider() { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe()); } - drivers = localExecutionPlan.createDrivers(context.sessionId); + drivers = localExecutionPlan.createDrivers(context.sessionId()); if (drivers.isEmpty()) { throw new IllegalStateException("no drivers created"); } @@ -487,7 +327,7 @@ public SourceProvider createSourceProvider() { return; } ActionListener listenerCollectingStatus = listener.map(ignored -> { - if (context.configuration.profile()) { + if (context.configuration().profile()) { return new ComputeResponse(drivers.stream().map(Driver::profile).toList()); } else { final ComputeResponse response = new ComputeResponse(List.of()); @@ -503,306 +343,7 @@ public SourceProvider createSourceProvider() { ); } - private void acquireSearchContexts( - String clusterAlias, - List shardIds, - Configuration configuration, - Map aliasFilters, - ActionListener> listener - ) { - final List targetShards = new ArrayList<>(); - try { - for (ShardId shardId : shardIds) { - var indexShard = searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id()); - targetShards.add(indexShard); - } - } catch (Exception e) { - listener.onFailure(e); - return; - } - final var doAcquire = ActionRunnable.supply(listener, () -> { - final List searchContexts = new ArrayList<>(targetShards.size()); - boolean success = false; - try { - for (IndexShard shard : targetShards) { - var aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY); - var shardRequest = new ShardSearchRequest( - shard.shardId(), - configuration.absoluteStartedTimeInMillis(), - aliasFilter, - clusterAlias - ); - // TODO: `searchService.createSearchContext` allows opening search contexts without limits, - // we need to limit the number of active search contexts here or in SearchService - SearchContext context = searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT); - searchContexts.add(context); - } - for (SearchContext searchContext : searchContexts) { - searchContext.preProcess(); - } - 
success = true; - return searchContexts; - } finally { - if (success == false) { - IOUtils.close(searchContexts); - } - } - }); - final AtomicBoolean waitedForRefreshes = new AtomicBoolean(); - try (RefCountingRunnable refs = new RefCountingRunnable(() -> { - if (waitedForRefreshes.get()) { - esqlExecutor.execute(doAcquire); - } else { - doAcquire.run(); - } - })) { - for (IndexShard targetShard : targetShards) { - final Releasable ref = refs.acquire(); - targetShard.ensureShardSearchActive(await -> { - try (ref) { - if (await) { - waitedForRefreshes.set(true); - } - } - }); - } - } - } - - record DataNode(Transport.Connection connection, List<ShardId> shardIds, Map<Index, AliasFilter> aliasFilters) { - - } - - /** - * Result from lookupDataNodes where can_match is performed to determine what shards can be skipped - * and which target nodes are needed for running the ES|QL query - * - * @param dataNodes list of DataNode to perform the ES|QL query on - * @param totalShards Total number of shards (from can_match phase), including skipped shards - * @param skippedShards Number of skipped shards (from can_match phase) - */ - record DataNodeResult(List<DataNode> dataNodes, int totalShards, int skippedShards) {} - - record RemoteCluster(String clusterAlias, Transport.Connection connection, String[] concreteIndices, OriginalIndices originalIndices) { - - } - - /** - * Performs can_match and find the target nodes for the given target indices and filter. - *
<p>
- * Ideally, the search_shards API should be called before the field-caps API; however, this can lead - * to a situation where the column structure (i.e., matched data types) differs depending on the query. - */ - private void lookupDataNodes( - Task parentTask, - String clusterAlias, - QueryBuilder filter, - Set concreteIndices, - OriginalIndices originalIndices, - ActionListener listener - ) { - ActionListener searchShardsListener = listener.map(resp -> { - Map nodes = new HashMap<>(); - for (DiscoveryNode node : resp.getNodes()) { - nodes.put(node.getId(), node); - } - Map> nodeToShards = new HashMap<>(); - Map> nodeToAliasFilters = new HashMap<>(); - int totalShards = 0; - int skippedShards = 0; - for (SearchShardsGroup group : resp.getGroups()) { - var shardId = group.shardId(); - if (group.allocatedNodes().isEmpty()) { - throw new ShardNotFoundException(group.shardId(), "no shard copies found {}", group.shardId()); - } - if (concreteIndices.contains(shardId.getIndexName()) == false) { - continue; - } - totalShards++; - if (group.skipped()) { - skippedShards++; - continue; - } - String targetNode = group.allocatedNodes().get(0); - nodeToShards.computeIfAbsent(targetNode, k -> new ArrayList<>()).add(shardId); - AliasFilter aliasFilter = resp.getAliasFilters().get(shardId.getIndex().getUUID()); - if (aliasFilter != null) { - nodeToAliasFilters.computeIfAbsent(targetNode, k -> new HashMap<>()).put(shardId.getIndex(), aliasFilter); - } - } - List dataNodes = new ArrayList<>(nodeToShards.size()); - for (Map.Entry> e : nodeToShards.entrySet()) { - DiscoveryNode node = nodes.get(e.getKey()); - Map aliasFilters = nodeToAliasFilters.getOrDefault(e.getKey(), Map.of()); - dataNodes.add(new DataNode(transportService.getConnection(node), e.getValue(), aliasFilters)); - } - return new DataNodeResult(dataNodes, totalShards, skippedShards); - }); - SearchShardsRequest searchShardsRequest = new SearchShardsRequest( - originalIndices.indices(), - originalIndices.indicesOptions(), - filter, - null, - null, - false, - clusterAlias - ); - transportService.sendChildRequest( - transportService.getLocalNode(), - EsqlSearchShardsAction.TYPE.name(), - searchShardsRequest, - parentTask, - TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor) - ); - } - - // TODO: Use an internal action here - public static final String DATA_ACTION_NAME = EsqlQueryAction.NAME + "/data"; - - private class DataNodeRequestExecutor { - private final DataNodeRequest request; - private final CancellableTask parentTask; - private final ExchangeSinkHandler exchangeSink; - private final ComputeListener computeListener; - private final int maxConcurrentShards; - private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data - - DataNodeRequestExecutor( - DataNodeRequest request, - CancellableTask parentTask, - ExchangeSinkHandler exchangeSink, - int maxConcurrentShards, - ComputeListener computeListener - ) { - this.request = request; - this.parentTask = parentTask; - this.exchangeSink = exchangeSink; - this.computeListener = computeListener; - this.maxConcurrentShards = maxConcurrentShards; - this.blockingSink = exchangeSink.createExchangeSink(); - } - - void start() { - parentTask.addListener( - () -> exchangeService.finishSinkHandler(request.sessionId(), new TaskCancelledException(parentTask.getReasonCancelled())) - ); - runBatch(0); - } - - private void runBatch(int startBatchIndex) { - final 
Configuration configuration = request.configuration(); - final String clusterAlias = request.clusterAlias(); - final var sessionId = request.sessionId(); - final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size()); - List shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex); - ActionListener batchListener = new ActionListener<>() { - final ActionListener ref = computeListener.acquireCompute(); - - @Override - public void onResponse(ComputeResponse result) { - try { - onBatchCompleted(endBatchIndex); - } finally { - ref.onResponse(result); - } - } - - @Override - public void onFailure(Exception e) { - try { - exchangeService.finishSinkHandler(request.sessionId(), e); - } finally { - ref.onFailure(e); - } - } - }; - acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> { - assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH, ESQL_WORKER_THREAD_POOL_NAME); - var computeContext = new ComputeContext( - sessionId, - clusterAlias, - searchContexts, - configuration, - configuration.newFoldContext(), - null, - exchangeSink - ); - runCompute(parentTask, computeContext, request.plan(), batchListener); - }, batchListener::onFailure)); - } - - private void onBatchCompleted(int lastBatchIndex) { - if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) { - runBatch(lastBatchIndex); - } else { - // don't return until all pages are fetched - var completionListener = computeListener.acquireAvoid(); - exchangeSink.addCompletionListener( - ActionListener.runAfter(completionListener, () -> exchangeService.finishSinkHandler(request.sessionId(), null)) - ); - blockingSink.finish(); - } - } - } - - private void runComputeOnDataNode( - CancellableTask task, - String externalId, - PhysicalPlan reducePlan, - DataNodeRequest request, - ComputeListener computeListener - ) { - var parentListener = computeListener.acquireAvoid(); - try { - // run compute with target shards - var internalSink = exchangeService.createSinkHandler(request.sessionId(), request.pragmas().exchangeBufferSize()); - DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor( - request, - task, - internalSink, - request.configuration().pragmas().maxConcurrentShardsPerNode(), - computeListener - ); - dataNodeRequestExecutor.start(); - // run the node-level reduction - var externalSink = exchangeService.getSinkHandler(externalId); - task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled()))); - var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor, computeListener.acquireAvoid()); - exchangeSource.addRemoteSink(internalSink::fetchPageAsync, true, 1, ActionListener.noop()); - ActionListener reductionListener = computeListener.acquireCompute(); - runCompute( - task, - new ComputeContext( - request.sessionId(), - request.clusterAlias(), - List.of(), - request.configuration(), - new FoldContext(request.pragmas().foldLimit().getBytes()), - exchangeSource, - externalSink - ), - reducePlan, - ActionListener.wrap(resp -> { - // don't return until all pages are fetched - externalSink.addCompletionListener(ActionListener.running(() -> { - exchangeService.finishSinkHandler(externalId, null); - reductionListener.onResponse(resp); - })); - }, e -> { - exchangeService.finishSinkHandler(externalId, e); - reductionListener.onFailure(e); - }) - ); - parentListener.onResponse(null); - } catch (Exception 
e) { - exchangeService.finishSinkHandler(externalId, e); - exchangeService.finishSinkHandler(request.sessionId(), e); - parentListener.onFailure(e); - } - } - - private static PhysicalPlan reductionPlan(ExchangeSinkExec plan, boolean enable) { + static PhysicalPlan reductionPlan(ExchangeSinkExec plan, boolean enable) { PhysicalPlan reducePlan = new ExchangeSourceExec(plan.source(), plan.output(), plan.isIntermediateAgg()); if (enable) { PhysicalPlan p = PlannerUtils.reductionPlan(plan); @@ -813,149 +354,7 @@ private static PhysicalPlan reductionPlan(ExchangeSinkExec plan, boolean enable) return new ExchangeSinkExec(plan.source(), plan.output(), plan.isIntermediateAgg(), reducePlan); } - private class DataNodeRequestHandler implements TransportRequestHandler { - @Override - public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) { - final ActionListener listener = new ChannelActionListener<>(channel); - final PhysicalPlan reductionPlan; - if (request.plan() instanceof ExchangeSinkExec plan) { - reductionPlan = reductionPlan(plan, request.runNodeLevelReduction()); - } else { - listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + request.plan())); - return; - } - final String sessionId = request.sessionId(); - request = new DataNodeRequest( - sessionId + "[n]", // internal session - request.configuration(), - request.clusterAlias(), - request.shardIds(), - request.aliasFilters(), - request.plan(), - request.indices(), - request.indicesOptions(), - request.runNodeLevelReduction() - ); - try (var computeListener = ComputeListener.create(transportService, (CancellableTask) task, listener)) { - runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, computeListener); - } - } - } - - public static final String CLUSTER_ACTION_NAME = EsqlQueryAction.NAME + "/cluster"; - - private class ClusterRequestHandler implements TransportRequestHandler { - @Override - public void messageReceived(ClusterComputeRequest request, TransportChannel channel, Task task) { - ChannelActionListener listener = new ChannelActionListener<>(channel); - RemoteClusterPlan remoteClusterPlan = request.remoteClusterPlan(); - var plan = remoteClusterPlan.plan(); - if (plan instanceof ExchangeSinkExec == false) { - listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + plan)); - return; - } - String clusterAlias = request.clusterAlias(); - /* - * This handler runs only on remote cluster coordinators, so it creates a new local EsqlExecutionInfo object to record - * execution metadata for ES|QL processing local to this cluster. The execution info will be copied into the - * ComputeResponse that is sent back to the primary coordinating cluster. - */ - EsqlExecutionInfo execInfo = new EsqlExecutionInfo(true); - execInfo.swapCluster(clusterAlias, (k, v) -> new EsqlExecutionInfo.Cluster(clusterAlias, Arrays.toString(request.indices()))); - CancellableTask cancellable = (CancellableTask) task; - try (var computeListener = ComputeListener.create(clusterAlias, transportService, cancellable, execInfo, listener)) { - runComputeOnRemoteCluster( - clusterAlias, - request.sessionId(), - (CancellableTask) task, - request.configuration(), - (ExchangeSinkExec) plan, - Set.of(remoteClusterPlan.targetIndices()), - remoteClusterPlan.originalIndices(), - execInfo, - computeListener - ); - } - } - } - - /** - * Performs a compute on a remote cluster. 
The output pages are placed in an exchange sink specified by - * {@code globalSessionId}. The coordinator on the main cluster will poll pages from there. - *
<p>
- * Currently, the coordinator on the remote cluster polls pages from data nodes within the remote cluster - * and performs cluster-level reduction before sending pages to the querying cluster. This reduction aims - * to minimize data transfers across clusters but may require additional CPU resources for operations like - * aggregations. - */ - void runComputeOnRemoteCluster( - String clusterAlias, - String globalSessionId, - CancellableTask parentTask, - Configuration configuration, - ExchangeSinkExec plan, - Set concreteIndices, - OriginalIndices originalIndices, - EsqlExecutionInfo executionInfo, - ComputeListener computeListener - ) { - final var exchangeSink = exchangeService.getSinkHandler(globalSessionId); - parentTask.addListener( - () -> exchangeService.finishSinkHandler(globalSessionId, new TaskCancelledException(parentTask.getReasonCancelled())) - ); - final String localSessionId = clusterAlias + ":" + globalSessionId; - final PhysicalPlan coordinatorPlan = reductionPlan(plan, true); - var exchangeSource = new ExchangeSourceHandler( - configuration.pragmas().exchangeBufferSize(), - transportService.getThreadPool().executor(ThreadPool.Names.SEARCH), - computeListener.acquireAvoid() - ); - try (Releasable ignored = exchangeSource.addEmptySink()) { - exchangeSink.addCompletionListener(computeListener.acquireAvoid()); - runCompute( - parentTask, - new ComputeContext( - localSessionId, - clusterAlias, - List.of(), - configuration, - configuration.newFoldContext(), - exchangeSource, - exchangeSink - ), - coordinatorPlan, - computeListener.acquireCompute(clusterAlias) - ); - startComputeOnDataNodes( - localSessionId, - clusterAlias, - parentTask, - configuration, - plan, - concreteIndices, - originalIndices, - exchangeSource, - executionInfo, - computeListener - ); - } - } - - record ComputeContext( - String sessionId, - String clusterAlias, - List searchContexts, - Configuration configuration, - FoldContext foldCtx, - ExchangeSourceHandler exchangeSource, - ExchangeSinkHandler exchangeSink - ) { - public List searchExecutionContexts() { - return searchContexts.stream().map(ctx -> ctx.getSearchExecutionContext()).toList(); - } - } - - private String newChildSession(String session) { + String newChildSession(String session) { return session + "/" + childSessionIdGenerator.incrementAndGet(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java new file mode 100644 index 0000000000000..1a1e5726a487b --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeComputeHandler.java @@ -0,0 +1,476 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.SearchShardsGroup; +import org.elasticsearch.action.search.SearchShardsRequest; +import org.elasticsearch.action.search.SearchShardsResponse; +import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.compute.EsqlRefCountingListener; +import org.elasticsearch.compute.operator.exchange.ExchangeService; +import org.elasticsearch.compute.operator.exchange.ExchangeSink; +import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; +import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.shard.IndexShard; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.index.shard.ShardNotFoundException; +import org.elasticsearch.search.SearchService; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.SearchContext; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportChannel; +import org.elasticsearch.transport.TransportRequestHandler; +import org.elasticsearch.transport.TransportRequestOptions; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.action.EsqlExecutionInfo; +import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; + +/** + * Handles computes within a single cluster by dispatching {@link DataNodeRequest} to data nodes + * and executing these computes on the data nodes. 
+ */ +final class DataNodeComputeHandler implements TransportRequestHandler { + private final ComputeService computeService; + private final SearchService searchService; + private final TransportService transportService; + private final ExchangeService exchangeService; + private final Executor esqlExecutor; + + DataNodeComputeHandler( + ComputeService computeService, + SearchService searchService, + TransportService transportService, + ExchangeService exchangeService, + Executor esqlExecutor + ) { + this.computeService = computeService; + this.searchService = searchService; + this.transportService = transportService; + this.exchangeService = exchangeService; + this.esqlExecutor = esqlExecutor; + transportService.registerRequestHandler(ComputeService.DATA_ACTION_NAME, esqlExecutor, DataNodeRequest::new, this); + } + + void startComputeOnDataNodes( + String sessionId, + String clusterAlias, + CancellableTask parentTask, + Configuration configuration, + PhysicalPlan dataNodePlan, + Set concreteIndices, + OriginalIndices originalIndices, + ExchangeSourceHandler exchangeSource, + EsqlExecutionInfo executionInfo, + ComputeListener computeListener + ) { + QueryBuilder requestFilter = PlannerUtils.requestTimestampFilter(dataNodePlan); + var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); + // SearchShards API can_match is done in lookupDataNodes + lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(lookupListener)) { + // update ExecutionInfo with shard counts (total and skipped) + executionInfo.swapCluster( + clusterAlias, + (k, v) -> new EsqlExecutionInfo.Cluster.Builder(v).setTotalShards(dataNodeResult.totalShards()) + // do not set successful or failed shard count here - do it when search is done + .setSkippedShards(dataNodeResult.skippedShards()) + .build() + ); + + // For each target node, first open a remote exchange on the remote node, then link the exchange source to + // the new remote exchange sink, and initialize the computation on the target node via data-node-request. 
+ for (DataNode node : dataNodeResult.dataNodes()) { + var queryPragmas = configuration.pragmas(); + var childSessionId = computeService.newChildSession(sessionId); + ExchangeService.openExchange( + transportService, + node.connection, + childSessionId, + queryPragmas.exchangeBufferSize(), + esqlExecutor, + refs.acquire().delegateFailureAndWrap((l, unused) -> { + var remoteSink = exchangeService.newRemoteSink(parentTask, childSessionId, transportService, node.connection); + exchangeSource.addRemoteSink(remoteSink, true, queryPragmas.concurrentExchangeClients(), ActionListener.noop()); + ActionListener computeResponseListener = computeListener.acquireCompute(clusterAlias); + var dataNodeListener = ActionListener.runBefore(computeResponseListener, () -> l.onResponse(null)); + final boolean sameNode = transportService.getLocalNode().getId().equals(node.connection.getNode().getId()); + var dataNodeRequest = new DataNodeRequest( + childSessionId, + configuration, + clusterAlias, + node.shardIds, + node.aliasFilters, + dataNodePlan, + originalIndices.indices(), + originalIndices.indicesOptions(), + sameNode == false && queryPragmas.nodeLevelReduction() + ); + transportService.sendChildRequest( + node.connection, + ComputeService.DATA_ACTION_NAME, + dataNodeRequest, + parentTask, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(dataNodeListener, ComputeResponse::new, esqlExecutor) + ); + }) + ); + } + } + }, lookupListener::onFailure)); + } + + private void acquireSearchContexts( + String clusterAlias, + List shardIds, + Configuration configuration, + Map aliasFilters, + ActionListener> listener + ) { + final List targetShards = new ArrayList<>(); + try { + for (ShardId shardId : shardIds) { + var indexShard = searchService.getIndicesService().indexServiceSafe(shardId.getIndex()).getShard(shardId.id()); + targetShards.add(indexShard); + } + } catch (Exception e) { + listener.onFailure(e); + return; + } + final var doAcquire = ActionRunnable.supply(listener, () -> { + final List searchContexts = new ArrayList<>(targetShards.size()); + boolean success = false; + try { + for (IndexShard shard : targetShards) { + var aliasFilter = aliasFilters.getOrDefault(shard.shardId().getIndex(), AliasFilter.EMPTY); + var shardRequest = new ShardSearchRequest( + shard.shardId(), + configuration.absoluteStartedTimeInMillis(), + aliasFilter, + clusterAlias + ); + // TODO: `searchService.createSearchContext` allows opening search contexts without limits, + // we need to limit the number of active search contexts here or in SearchService + SearchContext context = searchService.createSearchContext(shardRequest, SearchService.NO_TIMEOUT); + searchContexts.add(context); + } + for (SearchContext searchContext : searchContexts) { + searchContext.preProcess(); + } + success = true; + return searchContexts; + } finally { + if (success == false) { + IOUtils.close(searchContexts); + } + } + }); + final AtomicBoolean waitedForRefreshes = new AtomicBoolean(); + try (RefCountingRunnable refs = new RefCountingRunnable(() -> { + if (waitedForRefreshes.get()) { + esqlExecutor.execute(doAcquire); + } else { + doAcquire.run(); + } + })) { + for (IndexShard targetShard : targetShards) { + final Releasable ref = refs.acquire(); + targetShard.ensureShardSearchActive(await -> { + try (ref) { + if (await) { + waitedForRefreshes.set(true); + } + } + }); + } + } + } + + record DataNode(Transport.Connection connection, List shardIds, Map aliasFilters) { + + } + + /** + * Result from lookupDataNodes where can_match is 
performed to determine what shards can be skipped + * and which target nodes are needed for running the ES|QL query + * + * @param dataNodes list of DataNode to perform the ES|QL query on + * @param totalShards Total number of shards (from can_match phase), including skipped shards + * @param skippedShards Number of skipped shards (from can_match phase) + */ + record DataNodeResult(List<DataNode> dataNodes, int totalShards, int skippedShards) {} + + /** + * Performs can_match and finds the target nodes for the given target indices and filter. + *
<p>
+ * Ideally, the search_shards API should be called before the field-caps API; however, this can lead + * to a situation where the column structure (i.e., matched data types) differs depending on the query. + */ + private void lookupDataNodes( + Task parentTask, + String clusterAlias, + QueryBuilder filter, + Set concreteIndices, + OriginalIndices originalIndices, + ActionListener listener + ) { + ActionListener searchShardsListener = listener.map(resp -> { + Map nodes = new HashMap<>(); + for (DiscoveryNode node : resp.getNodes()) { + nodes.put(node.getId(), node); + } + Map> nodeToShards = new HashMap<>(); + Map> nodeToAliasFilters = new HashMap<>(); + int totalShards = 0; + int skippedShards = 0; + for (SearchShardsGroup group : resp.getGroups()) { + var shardId = group.shardId(); + if (group.allocatedNodes().isEmpty()) { + throw new ShardNotFoundException(group.shardId(), "no shard copies found {}", group.shardId()); + } + if (concreteIndices.contains(shardId.getIndexName()) == false) { + continue; + } + totalShards++; + if (group.skipped()) { + skippedShards++; + continue; + } + String targetNode = group.allocatedNodes().get(0); + nodeToShards.computeIfAbsent(targetNode, k -> new ArrayList<>()).add(shardId); + AliasFilter aliasFilter = resp.getAliasFilters().get(shardId.getIndex().getUUID()); + if (aliasFilter != null) { + nodeToAliasFilters.computeIfAbsent(targetNode, k -> new HashMap<>()).put(shardId.getIndex(), aliasFilter); + } + } + List dataNodes = new ArrayList<>(nodeToShards.size()); + for (Map.Entry> e : nodeToShards.entrySet()) { + DiscoveryNode node = nodes.get(e.getKey()); + Map aliasFilters = nodeToAliasFilters.getOrDefault(e.getKey(), Map.of()); + dataNodes.add(new DataNode(transportService.getConnection(node), e.getValue(), aliasFilters)); + } + return new DataNodeResult(dataNodes, totalShards, skippedShards); + }); + SearchShardsRequest searchShardsRequest = new SearchShardsRequest( + originalIndices.indices(), + originalIndices.indicesOptions(), + filter, + null, + null, + false, + clusterAlias + ); + transportService.sendChildRequest( + transportService.getLocalNode(), + EsqlSearchShardsAction.TYPE.name(), + searchShardsRequest, + parentTask, + TransportRequestOptions.EMPTY, + new ActionListenerResponseHandler<>(searchShardsListener, SearchShardsResponse::new, esqlExecutor) + ); + } + + private class DataNodeRequestExecutor { + private final DataNodeRequest request; + private final CancellableTask parentTask; + private final ExchangeSinkHandler exchangeSink; + private final ComputeListener computeListener; + private final int maxConcurrentShards; + private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data + + DataNodeRequestExecutor( + DataNodeRequest request, + CancellableTask parentTask, + ExchangeSinkHandler exchangeSink, + int maxConcurrentShards, + ComputeListener computeListener + ) { + this.request = request; + this.parentTask = parentTask; + this.exchangeSink = exchangeSink; + this.computeListener = computeListener; + this.maxConcurrentShards = maxConcurrentShards; + this.blockingSink = exchangeSink.createExchangeSink(); + } + + void start() { + parentTask.addListener( + () -> exchangeService.finishSinkHandler(request.sessionId(), new TaskCancelledException(parentTask.getReasonCancelled())) + ); + runBatch(0); + } + + private void runBatch(int startBatchIndex) { + final Configuration configuration = request.configuration(); + final String clusterAlias = request.clusterAlias(); + final var 
sessionId = request.sessionId(); + final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size()); + List shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex); + ActionListener batchListener = new ActionListener<>() { + final ActionListener ref = computeListener.acquireCompute(); + + @Override + public void onResponse(ComputeResponse result) { + try { + onBatchCompleted(endBatchIndex); + } finally { + ref.onResponse(result); + } + } + + @Override + public void onFailure(Exception e) { + try { + exchangeService.finishSinkHandler(request.sessionId(), e); + } finally { + ref.onFailure(e); + } + } + }; + acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> { + assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH, ESQL_WORKER_THREAD_POOL_NAME); + var computeContext = new ComputeContext( + sessionId, + clusterAlias, + searchContexts, + configuration, + configuration.newFoldContext(), + null, + exchangeSink + ); + computeService.runCompute(parentTask, computeContext, request.plan(), batchListener); + }, batchListener::onFailure)); + } + + private void onBatchCompleted(int lastBatchIndex) { + if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) { + runBatch(lastBatchIndex); + } else { + // don't return until all pages are fetched + var completionListener = computeListener.acquireAvoid(); + exchangeSink.addCompletionListener( + ActionListener.runAfter(completionListener, () -> exchangeService.finishSinkHandler(request.sessionId(), null)) + ); + blockingSink.finish(); + } + } + } + + private void runComputeOnDataNode( + CancellableTask task, + String externalId, + PhysicalPlan reducePlan, + DataNodeRequest request, + ComputeListener computeListener + ) { + var parentListener = computeListener.acquireAvoid(); + try { + // run compute with target shards + var internalSink = exchangeService.createSinkHandler(request.sessionId(), request.pragmas().exchangeBufferSize()); + DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor( + request, + task, + internalSink, + request.configuration().pragmas().maxConcurrentShardsPerNode(), + computeListener + ); + dataNodeRequestExecutor.start(); + // run the node-level reduction + var externalSink = exchangeService.getSinkHandler(externalId); + task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled()))); + var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor, computeListener.acquireAvoid()); + exchangeSource.addRemoteSink(internalSink::fetchPageAsync, true, 1, ActionListener.noop()); + ActionListener reductionListener = computeListener.acquireCompute(); + computeService.runCompute( + task, + new ComputeContext( + request.sessionId(), + request.clusterAlias(), + List.of(), + request.configuration(), + new FoldContext(request.pragmas().foldLimit().getBytes()), + exchangeSource, + externalSink + ), + reducePlan, + ActionListener.wrap(resp -> { + // don't return until all pages are fetched + externalSink.addCompletionListener(ActionListener.running(() -> { + exchangeService.finishSinkHandler(externalId, null); + reductionListener.onResponse(resp); + })); + }, e -> { + exchangeService.finishSinkHandler(externalId, e); + reductionListener.onFailure(e); + }) + ); + parentListener.onResponse(null); + } catch (Exception e) { + exchangeService.finishSinkHandler(externalId, e); + 
exchangeService.finishSinkHandler(request.sessionId(), e); + parentListener.onFailure(e); + } + } + + @Override + public void messageReceived(DataNodeRequest request, TransportChannel channel, Task task) { + final ActionListener listener = new ChannelActionListener<>(channel); + final PhysicalPlan reductionPlan; + if (request.plan() instanceof ExchangeSinkExec plan) { + reductionPlan = ComputeService.reductionPlan(plan, request.runNodeLevelReduction()); + } else { + listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + request.plan())); + return; + } + final String sessionId = request.sessionId(); + request = new DataNodeRequest( + sessionId + "[n]", // internal session + request.configuration(), + request.clusterAlias(), + request.shardIds(), + request.aliasFilters(), + request.plan(), + request.indices(), + request.indicesOptions(), + request.runNodeLevelReduction() + ); + try (var computeListener = ComputeListener.create(transportService, (CancellableTask) task, listener)) { + runComputeOnDataNode((CancellableTask) task, sessionId, reductionPlan, request, computeListener); + } + } +}
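
The fan-out in startComputeOnRemoteClusters and startComputeOnDataNodes has the same shape: acquire one ref per target from a ref-counting listener, open the remote exchange, link the remote sink into the exchange source, dispatch the child request, and release the ref once linking is done; the shared completion fires only after every target is linked. Below is a minimal, dependency-free sketch of that completion pattern. It is not part of the patch: RefCountingListener is a simplified stand-in for EsqlRefCountingListener, and the async task stands in for ExchangeService.openExchange.

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

final class FanOutSketch {
    // Simplified stand-in for EsqlRefCountingListener: counts outstanding refs,
    // runs the completion callback when the last one is released.
    static final class RefCountingListener implements AutoCloseable {
        private final AtomicInteger refs = new AtomicInteger(1); // self ref, held until close()
        private final Runnable onAllReleased;

        RefCountingListener(Runnable onAllReleased) {
            this.onAllReleased = onAllReleased;
        }

        Runnable acquire() {
            refs.incrementAndGet();
            return this::release;
        }

        private void release() {
            if (refs.decrementAndGet() == 0) {
                onAllReleased.run();
            }
        }

        @Override
        public void close() {
            release(); // drop the self ref, mirroring the try-with-resources in the handlers
        }
    }

    public static void main(String[] args) {
        List<String> clusters = List.of("cluster_a", "cluster_b");
        try (RefCountingListener refs = new RefCountingListener(() -> System.out.println("all exchanges linked"))) {
            for (String cluster : clusters) {
                Runnable ref = refs.acquire();
                // stand-in for ExchangeService.openExchange: release the ref once the
                // remote exchange is open and the remote sink has been linked
                CompletableFuture.runAsync(() -> {
                    System.out.println("opened exchange on " + cluster);
                    ref.run();
                }).join(); // joined only to keep the sketch deterministic
            }
        }
    }
}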
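
The session-ID scheme that ties the handlers together is plain string composition, taken directly from the code above: newChildSession appends a slash and a counter, runComputeOnRemoteCluster prefixes the cluster alias to form the local session, and the data-node handler derives its internal session by appending "[n]". A compact sketch (the root id is an arbitrary example):

import java.util.concurrent.atomic.AtomicLong;

final class SessionIdSketch {
    private static final AtomicLong childSessionIdGenerator = new AtomicLong();

    // Mirrors ComputeService.newChildSession
    static String newChildSession(String session) {
        return session + "/" + childSessionIdGenerator.incrementAndGet();
    }

    public static void main(String[] args) {
        String rootSession = "esql-session-1";                   // example root id
        String childSession = newChildSession(rootSession);      // "esql-session-1/1"
        String localOnRemote = "cluster_a" + ":" + childSession; // runComputeOnRemoteCluster
        String internalOnDataNode = childSession + "[n]";        // DataNodeComputeHandler.messageReceived
        System.out.println(childSession + " | " + localOnRemote + " | " + internalOnDataNode);
    }
}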
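
DataNodeRequestExecutor walks its shard list in slices of maxConcurrentShards: runBatch computes endBatchIndex, runs the compute on that slice, and onBatchCompleted re-enters runBatch until every shard is processed or the exchange sink has finished early. The recursion below models just that slicing, with a print in place of runCompute:

import java.util.List;

final class BatchSketch {
    // Models DataNodeRequestExecutor.runBatch/onBatchCompleted: one slice of at most
    // maxConcurrentShards per step, re-entering until the shard list is exhausted.
    static void runBatch(List<String> shardIds, int maxConcurrentShards, int startBatchIndex) {
        int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, shardIds.size());
        System.out.println("compute on " + shardIds.subList(startBatchIndex, endBatchIndex));
        if (endBatchIndex < shardIds.size()) { // the real code also stops if the sink is finished
            runBatch(shardIds, maxConcurrentShards, endBatchIndex);
        } else {
            System.out.println("all batches done; finish the blocking sink");
        }
    }

    public static void main(String[] args) {
        runBatch(List.of("shard-0", "shard-1", "shard-2", "shard-3", "shard-4"), 2, 0);
    }
}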