Commit da4a842: Chunk aggregation output pages

dnhatn committed Jan 10, 2025
1 parent de09149 commit da4a842

Showing 8 changed files with 149 additions and 58 deletions.
@@ -181,9 +181,11 @@ private static Operator operator(DriverContext driverContext, String grouping, S
             );
             default -> throw new IllegalArgumentException("unsupported grouping [" + grouping + "]");
         };
+        int pageSize = 16 * 1024;
         return new HashAggregationOperator(
             List.of(supplier(op, dataType, filter, groups.size()).groupingAggregatorFactory(AggregatorMode.SINGLE)),
-            () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false),
+            () -> BlockHash.build(groups, driverContext.blockFactory(), pageSize, false),
+            pageSize,
             driverContext
         );
     }
@@ -156,6 +156,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion REPLACE_FAILURE_STORE_OPTIONS_WITH_SELECTOR_SYNTAX = def(8_821_00_0);
     public static final TransportVersion ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION = def(8_822_00_0);
     public static final TransportVersion KQL_QUERY_TECH_PREVIEW = def(8_823_00_0);
+    public static final TransportVersion ESQL_CHUNK_AGGREGATION_OUTPUT = def(8_824_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,
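
This new transport version gates the extra Status field added further down: both sides of a connection check it before writing or reading the new value, so mixed-version clusters keep a consistent wire format. A minimal, self-contained sketch of that gating pattern follows; plain java.io streams and illustrative names stand in for Elasticsearch's StreamInput/StreamOutput, so treat it as the shape of the idea, not the actual API.

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Sketch of version-gated serialization. Both sides must evaluate the same
// condition against the negotiated wire version, or the stream desyncs.
// Names and types here are illustrative, not Elasticsearch's actual API.
class VersionGatedStatus {
    static final int ESQL_CHUNK_AGGREGATION_OUTPUT = 8_824_00_0;

    static void write(DataOutputStream out, int wireVersion, int pagesProcessed, int pagesEmitted) throws IOException {
        out.writeInt(pagesProcessed);
        if (wireVersion >= ESQL_CHUNK_AGGREGATION_OUTPUT) {
            out.writeInt(pagesEmitted); // only emitted when the peer understands it
        }
    }

    static int[] read(DataInputStream in, int wireVersion) throws IOException {
        int pagesProcessed = in.readInt();
        // An older peer never wrote the field, so fall back to a default.
        int pagesEmitted = wireVersion >= ESQL_CHUNK_AGGREGATION_OUTPUT ? in.readInt() : 0;
        return new int[] { pagesProcessed, pagesEmitted };
    }
}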
@@ -22,6 +22,7 @@
 import org.elasticsearch.compute.data.IntBlock;
 import org.elasticsearch.compute.data.IntVector;
 import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
@@ -58,12 +59,14 @@ public Operator get(DriverContext driverContext) {
                         analysisRegistry,
                         maxPageSize
                     ),
+                    maxPageSize,
                     driverContext
                 );
             }
             return new HashAggregationOperator(
                 aggregators,
                 () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false),
+                maxPageSize,
                 driverContext
             );
         }
@@ -78,9 +81,10 @@ public String describe() {
         }
     }
 
-    private boolean finished;
-    private Page output;
+    private final int maxPageSize;
+    private Emitter emitter;
 
+    private boolean blockHashClosed = false;
     private final BlockHash blockHash;
 
     private final List<GroupingAggregator> aggregators;
@@ -96,16 +100,23 @@ public String describe() {
      */
     private long aggregationNanos;
     /**
-     * Count of pages this operator has processed.
+     * Count of input pages this operator has processed.
      */
     private int pagesProcessed;
 
+    /**
+     * Count of output pages this operator has emitted
+     */
+    private int pagesEmitted;
+
     @SuppressWarnings("this-escape")
     public HashAggregationOperator(
         List<GroupingAggregator.Factory> aggregators,
         Supplier<BlockHash> blockHash,
+        int maxPageSize,
         DriverContext driverContext
     ) {
+        this.maxPageSize = maxPageSize;
         this.aggregators = new ArrayList<>(aggregators.size());
         this.driverContext = driverContext;
         boolean success = false;
@@ -124,7 +135,7 @@ public HashAggregationOperator(
 
     @Override
     public boolean needsInput() {
-        return finished == false;
+        return emitter == null;
     }
 
     @Override
@@ -192,61 +203,102 @@ public void close() {
 
     @Override
     public Page getOutput() {
-        Page p = output;
-        output = null;
-        return p;
+        if (emitter == null) {
+            return null;
+        }
+        return emitter.nextPage();
     }
 
-    @Override
-    public void finish() {
-        if (finished) {
-            return;
-        }
-        finished = true;
-        Block[] blocks = null;
-        IntVector selected = null;
-        boolean success = false;
-        try {
-            selected = blockHash.nonEmpty();
-            Block[] keys = blockHash.getKeys();
-            int[] aggBlockCounts = aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray();
-            blocks = new Block[keys.length + Arrays.stream(aggBlockCounts).sum()];
-            System.arraycopy(keys, 0, blocks, 0, keys.length);
-            int offset = keys.length;
-            for (int i = 0; i < aggregators.size(); i++) {
-                var aggregator = aggregators.get(i);
-                aggregator.evaluate(blocks, offset, selected, driverContext);
-                offset += aggBlockCounts[i];
-            }
-            output = new Page(blocks);
-            success = true;
-        } finally {
-            // selected should always be closed
-            if (selected != null) {
-                selected.close();
-            }
-            if (success == false && blocks != null) {
-                Releasables.closeExpectNoException(blocks);
-            }
-        }
-    }
+    private class Emitter implements Releasable {
+        private final int[] aggBlockCounts;
+        private int position = -1;
+        private IntVector allSelected = null;
+        private Block[] allKeys;
+
+        Emitter(int[] aggBlockCounts) {
+            this.aggBlockCounts = aggBlockCounts;
+        }
+
+        Page nextPage() {
+            if (position == -1) {
+                position = 0;
+                // TODO: chunk selected and keys
+                allKeys = blockHash.getKeys();
+                allSelected = blockHash.nonEmpty();
+                blockHashClosed = true;
+                blockHash.close();
+            }
+            final int endPosition = Math.min(position + maxPageSize, allSelected.getPositionCount());
+            if (endPosition == position) {
+                return null;
+            }
+            final boolean singlePage = position == 0 && endPosition == allSelected.getPositionCount();
+            final Block[] blocks = new Block[allKeys.length + Arrays.stream(aggBlockCounts).sum()];
+            IntVector selected = null;
+            boolean success = false;
+            try {
+                if (singlePage) {
+                    this.allSelected.incRef();
+                    selected = this.allSelected;
+                    for (int i = 0; i < allKeys.length; i++) {
+                        allKeys[i].incRef();
+                        blocks[i] = allKeys[i];
+                    }
+                } else {
+                    final int[] positions = new int[endPosition - position];
+                    for (int i = 0; i < positions.length; i++) {
+                        positions[i] = position + i;
+                    }
+                    selected = allSelected.filter(positions);
+                    for (int keyIndex = 0; keyIndex < allKeys.length; keyIndex++) {
+                        blocks[keyIndex] = allKeys[keyIndex].filter(positions);
+                    }
+                }
+                int blockOffset = allKeys.length;
+                for (int i = 0; i < aggregators.size(); i++) {
+                    aggregators.get(i).evaluate(blocks, blockOffset, selected, driverContext);
+                    blockOffset += aggBlockCounts[i];
+                }
+                var output = new Page(blocks);
+                pagesEmitted++;
+                success = true;
+                return output;
+            } finally {
+                position = endPosition;
+                Releasables.close(selected, success ? null : Releasables.wrap(blocks));
+            }
+        }
+
+        @Override
+        public void close() {
+            Releasables.close(allSelected, allKeys != null ? Releasables.wrap(allKeys) : null);
+        }
+
+        boolean doneEmitting() {
+            return allSelected != null && position >= allSelected.getPositionCount();
+        }
+    }
+
+    @Override
+    public void finish() {
+        if (emitter == null) {
+            emitter = new Emitter(aggregators.stream().mapToInt(GroupingAggregator::evaluateBlockCount).toArray());
+        }
+    }
 
     @Override
     public boolean isFinished() {
-        return finished && output == null;
+        return emitter != null && emitter.doneEmitting();
     }
 
     @Override
     public void close() {
-        if (output != null) {
-            output.releaseBlocks();
-        }
-        Releasables.close(blockHash, () -> Releasables.close(aggregators));
+        Releasables.close(emitter, blockHashClosed ? null : blockHash, () -> Releasables.close(aggregators));
     }
 
     @Override
     public Operator.Status status() {
-        return new Status(hashNanos, aggregationNanos, pagesProcessed);
+        return new Status(hashNanos, aggregationNanos, pagesProcessed, pagesEmitted);
     }
 
     protected static void checkState(boolean condition, String msg) {
@@ -285,33 +337,43 @@ public static class Status implements Operator.Status {
          */
         private final long aggregationNanos;
         /**
-         * Count of pages this operator has processed.
+         * Count of input pages this operator has processed.
          */
         private final int pagesProcessed;
 
+        /**
+         * Count of output pages this operator has emitted
+         */
+        private final int pageEmitted;
+
         /**
          * Build.
         * @param hashNanos Nanoseconds this operator has spent hashing grouping keys.
         * @param aggregationNanos Nanoseconds this operator has spent running the aggregations.
         * @param pagesProcessed Count of pages this operator has processed.
         */
-        public Status(long hashNanos, long aggregationNanos, int pagesProcessed) {
+        public Status(long hashNanos, long aggregationNanos, int pagesProcessed, int pagesEmitted) {
             this.hashNanos = hashNanos;
             this.aggregationNanos = aggregationNanos;
             this.pagesProcessed = pagesProcessed;
+            this.pageEmitted = pagesEmitted;
         }
 
         protected Status(StreamInput in) throws IOException {
             hashNanos = in.readVLong();
             aggregationNanos = in.readVLong();
             pagesProcessed = in.readVInt();
+            pageEmitted = in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHUNK_AGGREGATION_OUTPUT) ? in.readVInt() : 0;
         }
 
         @Override
         public void writeTo(StreamOutput out) throws IOException {
             out.writeVLong(hashNanos);
             out.writeVLong(aggregationNanos);
             out.writeVInt(pagesProcessed);
+            if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CHUNK_AGGREGATION_OUTPUT)) {
+                out.writeVInt(pageEmitted);
+            }
         }
 
         @Override
@@ -334,12 +396,19 @@ public long aggregationNanos() {
         }
 
         /**
-         * Count of pages this operator has processed.
+         * Count of input pages this operator has processed.
         */
         public int pagesProcessed() {
             return pagesProcessed;
         }
 
+        /**
+         * Count of output pages this operator has emitted
+         */
+        public int pagesEmitted() {
+            return pageEmitted;
+        }
+
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject();
@@ -352,6 +421,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
             builder.field("aggregation_time", TimeValue.timeValueNanos(aggregationNanos));
         }
         builder.field("pages_processed", pagesProcessed);
+        builder.field("pages_emitted", pageEmitted);
         return builder.endObject();
 
     }
@@ -361,12 +431,15 @@ public boolean equals(Object o) {
         if (this == o) return true;
         if (o == null || getClass() != o.getClass()) return false;
         Status status = (Status) o;
-        return hashNanos == status.hashNanos && aggregationNanos == status.aggregationNanos && pagesProcessed == status.pagesProcessed;
+        return hashNanos == status.hashNanos
+            && aggregationNanos == status.aggregationNanos
+            && pagesProcessed == status.pagesProcessed
+            && pageEmitted == status.pageEmitted;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(hashNanos, aggregationNanos, pagesProcessed);
+        return Objects.hash(hashNanos, aggregationNanos, pagesProcessed, pageEmitted);
     }
 
     @Override
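
The core of this file's change is the Emitter above: instead of materializing every group into a single output Page in finish(), the operator now keeps a position cursor and, on each getOutput() call, slices out at most maxPageSize positions, with an incRef fast path when one page covers everything. Below is a self-contained sketch of the same cursor-and-slice technique over a plain array; all names are illustrative, and Elasticsearch's Block/Page/IntVector types and reference counting are deliberately left out.

import java.util.Arrays;

// Sketch: emit a large result set as fixed-size chunks driven by a cursor,
// mirroring Emitter.nextPage() above. Names here are illustrative only.
class ChunkedEmitter {
    private final long[] allRows;   // stands in for the materialized keys/values
    private final int maxPageSize;
    private int position = 0;

    ChunkedEmitter(long[] allRows, int maxPageSize) {
        this.allRows = allRows;
        this.maxPageSize = maxPageSize;
    }

    /** Returns the next chunk, or null when everything has been emitted. */
    long[] nextPage() {
        int end = Math.min(position + maxPageSize, allRows.length);
        if (end == position) {
            return null; // done emitting
        }
        // Fast path: one chunk covers everything, so reuse the backing array
        // (the real code incRefs the existing blocks instead of filtering them).
        long[] page = (position == 0 && end == allRows.length)
            ? allRows
            : Arrays.copyOfRange(allRows, position, end);
        position = end;
        return page;
    }

    boolean doneEmitting() {
        return position >= allRows.length;
    }

    public static void main(String[] args) {
        ChunkedEmitter emitter = new ChunkedEmitter(new long[] { 1, 2, 3, 4, 5 }, 2);
        for (long[] page = emitter.nextPage(); page != null; page = emitter.nextPage()) {
            System.out.println(Arrays.toString(page)); // [1, 2], then [3, 4], then [5]
        }
    }
}

The payoff of this shape is bounded peak memory per output page: a driver calls finish() once, then polls getOutput() until isFinished(), and each poll pays only for one page's worth of positions.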
@@ -205,16 +205,16 @@ public Page getOutput() {
             return null;
         }
         if (valuesAggregator != null) {
-            try {
-                return valuesAggregator.getOutput();
-            } finally {
-                final ValuesAggregator aggregator = this.valuesAggregator;
-                this.valuesAggregator = null;
-                Releasables.close(aggregator);
+            final Page output = valuesAggregator.getOutput();
+            if (output == null) {
+                Releasables.close(valuesAggregator, () -> this.valuesAggregator = null);
+            } else {
+                return output;
             }
         }
         if (ordinalAggregators.isEmpty() == false) {
             try {
+                // TODO: chunk output pages
                 return mergeOrdinalsSegmentResults();
             } catch (IOException e) {
                 throw new UncheckedIOException(e);
@@ -510,6 +510,7 @@ private static class ValuesAggregator implements Releasable {
                     maxPageSize,
                     false
                 ),
+                maxPageSize,
                 driverContext
             );
         }
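
The getOutput() change in this file follows from the chunking: ValuesAggregator, which wraps a HashAggregationOperator, can now return several pages, so the caller must keep polling until it sees null and only then release it, rather than closing after the first page. A minimal sketch of that drain-then-close pattern follows; PageSource and the String payload are hypothetical stand-ins, not the actual Operator interface.

// Sketch of drain-then-close, mirroring the getOutput() change above: with a
// chunked producer, null -- not the first page -- signals it can be released.
class OuterOperator {
    /** Hypothetical stand-in for a chunked producer such as ValuesAggregator. */
    interface PageSource extends AutoCloseable {
        String nextPage();      // returns null once fully drained
        @Override void close(); // releases retained buffers
    }

    private PageSource source; // assumed to be set before the first poll

    String getOutput() {
        if (source == null) {
            return null;
        }
        String page = source.nextPage();
        if (page == null) {
            source.close(); // drained: only now is it safe to release
            source = null;
        }
        return page;
    }
}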
@@ -62,6 +62,7 @@ public Operator get(DriverContext driverContext) {
         return new HashAggregationOperator(
             aggregators,
             () -> new TimeSeriesBlockHash(tsHashChannel, timeBucketChannel, driverContext),
+            maxPageSize,
             driverContext
         );
     }
@@ -97,6 +98,7 @@ public Operator get(DriverContext driverContext) {
         return new HashAggregationOperator(
             aggregators,
             () -> BlockHash.build(hashGroups, driverContext.blockFactory(), maxPageSize, false),
+            maxPageSize,
             driverContext
         );
     }
@@ -125,6 +127,7 @@ public Operator get(DriverContext driverContext) {
         return new HashAggregationOperator(
             aggregators,
             () -> BlockHash.build(groupings, driverContext.blockFactory(), maxPageSize, false),
+            maxPageSize,
             driverContext
         );
     }
@@ -211,6 +211,7 @@ public String toString() {
                     randomPageSize(),
                     false
                 ),
+                randomPageSize(),
                 driverContext
             )
         );

[Diff for the remaining 2 changed files not loaded.]