Skip to content

Commit

Permalink
Fix sliding window node failed to restore from check/save point
Browse files Browse the repository at this point in the history
This close alibaba#256
  • Loading branch information
daugraph committed Oct 8, 2023
1 parent 15d6382 commit e410623
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,14 @@ public SlidingWindowPreprocessAggregateFunction(

@Override
public Row createAccumulator() {
Row acc = Row.withNames();
int arity = keyFields.size() + 1 + aggDescriptors.getAggFieldDescriptors().size();
Object[] values = new Object[arity];
for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor :
aggDescriptors.getAggFieldDescriptors()) {
acc.setField(
descriptor.fieldName, descriptor.aggFuncWithoutRetract.createAccumulator());
int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor);
values[pos] = descriptor.aggFuncWithoutRetract.createAccumulator();
}
return acc;
return Row.of(values);
}

@Override
Expand All @@ -239,16 +240,17 @@ public Row add(Row row, Row acc) {
for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor :
aggDescriptors.getAggFieldDescriptors()) {
Object fieldValue = row.getFieldAs(descriptor.fieldName);
Object fieldAcc = acc.getField(descriptor.fieldName);
int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor);
Object fieldAcc = acc.getFieldAs(pos);
descriptor.aggFuncWithoutRetract.add(fieldAcc, fieldValue, timestamp);
}

if (acc.getField(rowTimeFieldName) == null) {
acc.setField(
rowTimeFieldName,
Instant.ofEpochMilli(getWindowTime(timestamp, size, offset)));
if (acc.getField(keyFields.size()) == null) {
acc.setField(keyFields.size(), Instant.ofEpochMilli(getWindowTime(timestamp, size, offset)));
int idx = 0;
for (String key : keyFields) {
acc.setField(key, row.getField(key));
acc.setField(idx, row.getField(key));
idx += 1;
}
}

Expand All @@ -264,14 +266,17 @@ public Row getResult(Row acc) {
public Row merge(Row acc1, Row acc2) {
for (AggregationFieldsDescriptor.AggregationFieldDescriptor descriptor :
aggDescriptors.getAggFieldDescriptors()) {
Object fieldAcc1 = acc1.getField(descriptor.fieldName);
Object fieldAcc2 = acc2.getField(descriptor.fieldName);
int pos = keyFields.size() + 1 + aggDescriptors.getAggFieldIdx(descriptor);
Object fieldAcc1 = acc1.getField(pos);
Object fieldAcc2 = acc2.getField(pos);
descriptor.aggFuncWithoutRetract.merge(fieldAcc1, fieldAcc2);
}
if (acc1.getField(rowTimeFieldName) == null) {
acc1.setField(rowTimeFieldName, acc2.getField(rowTimeFieldName));
if (acc1.getField(keyFields.size()) == null) {
acc1.setField(keyFields.size(), acc2.getField(keyFields.size()));
int idx = 0;
for (String key : keyFields) {
acc1.setField(key, acc2.getField(key));
acc1.setField(idx, acc2.getField(idx));
idx += 1;
}
}
return acc1;
Expand Down Expand Up @@ -354,27 +359,33 @@ public static Table applySlidingWindowAggregationProcess(
rowTimeFieldName,
keyFieldNames);
}
rowDataStream =
rowDataStream
.keyBy(
(KeySelector<Row, Row>)
value ->
Row.of(
Arrays.stream(keyFieldNames)
.map(value::getField)
.toArray()))
.process(
new SlidingWindowKeyedProcessFunction(
aggregationFieldsDescriptor,
rowTypeSerializer,
resultRowTypeInfo.createSerializer(null),
keyFieldNames,
rowTimeFieldName,
windowDescriptor.stepSize.toMillis(),
expiredRowHandler,
skipSameWindowOutput))
.setParallelism(rowDataStream.getParallelism())
.returns(resultRowTypeInfo);
rowDataStream = rowDataStream
.keyBy((KeySelector<Row, Row>) row -> {
List<Object> values = new ArrayList<>();
for (int i = 0; i < keyFieldNames.length; i += 1) {
Object value;
try {
value = row.getField(i);
} catch (IllegalArgumentException e) {
value = row.getField(keyFieldNames[i]);
}
values.add(value);
}
return Row.of(values.toArray(new Object[0]));
})
.process(
new SlidingWindowKeyedProcessFunction(
aggregationFieldsDescriptor,
rowTypeSerializer,
resultRowTypeInfo.createSerializer(null),
keyFieldNames,
rowTimeFieldName,
windowDescriptor.stepSize.toMillis(),
expiredRowHandler,
skipSameWindowOutput)
).setParallelism(rowDataStream.getParallelism())
.returns(resultRowTypeInfo);


Table table =
tEnv.fromDataStream(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,14 @@ public void onTimer(
break;
}
for (Row row : state.timestampToRows.get(rowTime)) {
descriptor.aggFunc.retractAccumulator(
accumulatorState, row.getField(descriptor.fieldName));
Object value;
try {
int idx = keyFieldNames.length + 1 + aggregationFieldsDescriptor.getAggFieldIdx(descriptor);
value = row.getField(idx);
} catch (IllegalArgumentException e) {
value = row.getField(descriptor.fieldName);
}
descriptor.aggFunc.retractAccumulator(accumulatorState, value);
}
}
if (leftIdx < timestampList.size() && timestampList.get(leftIdx) <= timestamp) {
Expand Down

0 comments on commit e410623

Please sign in to comment.