Skip to content

Commit

Permalink
[FLINK-36338][State] Properly handle KeyContext when using AsyncKeyed…
Browse files Browse the repository at this point in the history
…StateBackendAdaptor (apache#25367)
  • Loading branch information
Zakelly authored Sep 23, 2024
1 parent 32dc6c0 commit b14ab76
Show file tree
Hide file tree
Showing 22 changed files with 108 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ public class AsyncExecutionController<K> implements StateRequestHandler {
/** The reference of epoch manager. */
final EpochManager epochManager;

/** The listener notified on every context switch; may be {@code null}, in which case nothing is notified. */
final SwitchContextListener<K> switchContextListener;

/**
* The parallel mode of epoch execution. Keep this field internal for now, until we could see
* the concrete need for {@link ParallelMode#PARALLEL_BETWEEN_EPOCH} from average users.
Expand All @@ -124,7 +127,8 @@ public AsyncExecutionController(
int maxParallelism,
int batchSize,
long bufferTimeout,
int maxInFlightRecords) {
int maxInFlightRecords,
SwitchContextListener<K> switchContextListener) {
this.keyAccountingUnit = new KeyAccountingUnit<>(maxInFlightRecords);
this.mailboxExecutor = mailboxExecutor;
this.exceptionHandler = exceptionHandler;
Expand All @@ -148,6 +152,7 @@ public AsyncExecutionController(
"AEC-buffer-timeout"));

this.epochManager = new EpochManager(this);
this.switchContextListener = switchContextListener;
LOG.info(
"Create AsyncExecutionController: batchSize {}, bufferTimeout {}, maxInFlightRecordNum {}, epochParallelMode {}",
this.batchSize,
Expand Down Expand Up @@ -189,6 +194,9 @@ public RecordContext<K> buildContext(Object record, K key) {
*/
public void setCurrentContext(RecordContext<K> switchingContext) {
    // Record the new context first, so it is already current when the listener is invoked.
    currentContext = switchingContext;
    if (switchContextListener == null) {
        // No listener was supplied; there is nobody to notify.
        return;
    }
    switchContextListener.switchContext(switchingContext);
}

/**
Expand Down Expand Up @@ -374,4 +382,9 @@ public StateExecutor getStateExecutor() {
public int getInFlightRecordNum() {
// Returns the value currently held by the inFlightRecordNum counter.
return inFlightRecordNum.get();
}

/**
 * A listener that is notified whenever the current record context (and therefore the current
 * key) is switched.
 *
 * @param <K> the type of the key held by the {@link RecordContext}.
 */
@FunctionalInterface
public interface SwitchContextListener<K> {
    /**
     * Invoked with the record context that is being switched to.
     *
     * @param context the context that is being switched to.
     */
    void switchContext(RecordContext<K> context);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
maxParallelism,
asyncBufferSize,
asyncBufferTimeout,
inFlightRecordsLimit);
inFlightRecordsLimit,
asyncKeyedStateBackend);
asyncKeyedStateBackend.setup(asyncExecutionController);
} else if (stateHandler.getKeyedStateBackend() != null) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ public final void initializeState(StreamTaskStateInitializer streamTaskStateMana
maxParallelism,
asyncBufferSize,
asyncBufferTimeout,
inFlightRecordsLimit);
inFlightRecordsLimit,
asyncKeyedStateBackend);
asyncKeyedStateBackend.setup(asyncExecutionController);
} else if (stateHandler.getKeyedStateBackend() != null) {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ public void setCurrentKey(K newKey) {
KeyGroupRangeAssignment.assignToKeyGroup(newKey, numberOfKeyGroups));
}

/**
 * Sets the current key together with a key group index supplied by the caller, instead of
 * deriving the key group from the key.
 *
 * @param newKey the key to make current.
 * @param newKeyGroupIndex the key group index to store alongside {@code newKey}.
 */
@Override
public void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex) {
    // Key selection listeners are informed before the key context itself is updated.
    notifyKeySelected(newKey);
    keyContext.setCurrentKey(newKey);
    keyContext.setCurrentKeyGroupIndex(newKeyGroupIndex);
}

private void notifyKeySelected(K newKey) {
// we prefer a for-loop over other iteration schemes for performance reasons here.
for (int i = 0; i < keySelectionListeners.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import org.apache.flink.api.common.state.InternalCheckpointListener;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.asyncprocessing.AsyncExecutionController;
import org.apache.flink.runtime.asyncprocessing.RecordContext;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.state.v2.StateDescriptor;
Expand All @@ -36,11 +38,12 @@
* in batch.
*/
@Internal
public interface AsyncKeyedStateBackend
public interface AsyncKeyedStateBackend<K>
extends Snapshotable<SnapshotResult<KeyedStateHandle>>,
InternalCheckpointListener,
Disposable,
Closeable {
Closeable,
AsyncExecutionController.SwitchContextListener<K> {

/**
* Initializes with some contexts.
Expand Down Expand Up @@ -80,6 +83,10 @@ <N, S extends State, SV> S createState(
@Nonnull
StateExecutor createStateExecutor();

/**
 * By default, a state backend does nothing when a key is switched in async processing.
 *
 * @param context the record context being switched to; ignored by this default implementation.
 */
@Override
default void switchContext(RecordContext<K> context) {}

@Override
void dispose();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ public interface KeyedStateBackend<K>
/** @return Current key. */
K getCurrentKey();

/**
 * Act as a fast path for {@link #setCurrentKey} when the key group is known.
 *
 * @param newKey The new current key.
 * @param newKeyGroupIndex The key group index for {@code newKey}, supplied by the caller.
 *     NOTE(review): presumably this must equal the key group that {@link #setCurrentKey}
 *     would derive for the same key; confirm against implementations.
 */
void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex);

/** @return Serializer of the key. */
TypeSerializer<K> getKeySerializer();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ <K> CheckpointableKeyedStateBackend<K> createKeyedStateBackend(
* backend.
*/
@Experimental
default <K> AsyncKeyedStateBackend createAsyncKeyedStateBackend(
default <K> AsyncKeyedStateBackend<K> createAsyncKeyedStateBackend(
KeyedStateBackendParameters<K> parameters) throws Exception {
throw new UnsupportedOperationException(
"Don't support createAsyncKeyedStateBackend by default");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
/** Default implementation of KeyedStateStoreV2. */
public class DefaultKeyedStateStoreV2 implements KeyedStateStoreV2 {

private final AsyncKeyedStateBackend asyncKeyedStateBackend;
private final AsyncKeyedStateBackend<?> asyncKeyedStateBackend;

public DefaultKeyedStateStoreV2(@Nonnull AsyncKeyedStateBackend asyncKeyedStateBackend) {
this.asyncKeyedStateBackend = Preconditions.checkNotNull(asyncKeyedStateBackend);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.flink.api.common.state.InternalCheckpointListener;
import org.apache.flink.api.common.state.v2.State;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.asyncprocessing.RecordContext;
import org.apache.flink.runtime.asyncprocessing.StateExecutor;
import org.apache.flink.runtime.asyncprocessing.StateRequestHandler;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
Expand Down Expand Up @@ -49,7 +50,7 @@
*
* @param <K> The key by which state is keyed.
*/
public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend {
public class AsyncKeyedStateBackendAdaptor<K> implements AsyncKeyedStateBackend<K> {
private final CheckpointableKeyedStateBackend<K> keyedStateBackend;

public AsyncKeyedStateBackendAdaptor(CheckpointableKeyedStateBackend<K> keyedStateBackend) {
Expand Down Expand Up @@ -95,6 +96,11 @@ public StateExecutor createStateExecutor() {
return null;
}

/**
 * Propagates a context switch to the wrapped keyed state backend by setting its current key
 * and key group from the given record context.
 *
 * @param context the record context being switched to.
 */
@Override
public void switchContext(RecordContext<K> context) {
    final K key = context.getKey();
    final int keyGroup = context.getKeyGroup();
    keyedStateBackend.setCurrentKeyAndKeyGroup(key, keyGroup);
}

// Intentionally a no-op: this method releases nothing itself.
// NOTE(review): whether the wrapped keyedStateBackend is disposed elsewhere is not visible here.
@Override
public void dispose() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public class StreamOperatorStateHandler {

protected static final Logger LOG = LoggerFactory.getLogger(StreamOperatorStateHandler.class);

@Nullable private final AsyncKeyedStateBackend asyncKeyedStateBackend;
@Nullable private final AsyncKeyedStateBackend<?> asyncKeyedStateBackend;

@Nullable private final KeyedStateStoreV2 keyedStateStoreV2;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ public K getCurrentKey() {
return currentKey;
}

/**
 * Sets the current key by forwarding to {@code setCurrentKey(newKey)}.
 *
 * @param newKey the key to make current.
 * @param newKeyGroupIndex not used by this implementation.
 */
@Override
public void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex) {
setCurrentKey(newKey);
}

@Override
public TypeSerializer<K> getKeySerializer() {
return keySerializer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,14 @@ public void testPartialLoading() {
TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(100, 3);
AsyncExecutionController aec =
new AsyncExecutionController(
new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1);
new SyncMailboxExecutor(),
(a, b) -> {},
stateExecutor,
1,
100,
1000,
1,
null);
stateExecutor.bindAec(aec);
RecordContext<String> recordContext = aec.buildContext("1", "key1");
aec.setCurrentContext(recordContext);
Expand Down Expand Up @@ -77,7 +84,14 @@ public void testPartialLoadingWithReturnValue() {
TestIteratorStateExecutor stateExecutor = new TestIteratorStateExecutor(100, 3);
AsyncExecutionController aec =
new AsyncExecutionController(
new SyncMailboxExecutor(), (a, b) -> {}, stateExecutor, 1, 100, 1000, 1);
new SyncMailboxExecutor(),
(a, b) -> {},
stateExecutor,
1,
100,
1000,
1,
null);
stateExecutor.bindAec(aec);
RecordContext<String> recordContext = aec.buildContext("1", "key1");
aec.setCurrentContext(recordContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

/** Test for {@link AsyncExecutionController}. */
class AsyncExecutionControllerTest {
AsyncExecutionController aec;
AsyncExecutionController<String> aec;
AtomicInteger output;
TestValueState valueState;

Expand Down Expand Up @@ -90,7 +90,7 @@ void setup(
StateBackend testAsyncStateBackend =
StateBackendTestUtils.buildAsyncStateBackend(stateSupplier, stateExecutor);
assertThat(testAsyncStateBackend.supportsAsyncKeyedStateBackend()).isTrue();
AsyncKeyedStateBackend asyncKeyedStateBackend;
AsyncKeyedStateBackend<String> asyncKeyedStateBackend;
try {
asyncKeyedStateBackend = testAsyncStateBackend.createAsyncKeyedStateBackend(null);
} catch (Exception e) {
Expand All @@ -106,7 +106,8 @@ void setup(
128,
batchSize,
timeout,
maxInFlight);
maxInFlight,
null);
asyncKeyedStateBackend.setup(aec);

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public OperatorStateBackend createOperatorStateBackend(
}
}

private static class TestAsyncKeyedStateBackend implements AsyncKeyedStateBackend {
private static class TestAsyncKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {

private final Supplier<org.apache.flink.api.common.state.v2.State> innerStateSupplier;
private final StateExecutor stateExecutor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ void setup() {
1,
1,
1000,
1);
1,
null);
exception = new AtomicReference<>(null);
}

Expand Down Expand Up @@ -124,9 +125,9 @@ public boolean supportsAsyncKeyedStateBackend() {
}

@Override
public <K> AsyncKeyedStateBackend createAsyncKeyedStateBackend(
public <K> AsyncKeyedStateBackend<K> createAsyncKeyedStateBackend(
KeyedStateBackendParameters<K> parameters) {
return new AsyncKeyedStateBackend() {
return new AsyncKeyedStateBackend<K>() {
@Nonnull
@Override
public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,16 @@ public void testMergeNamespace() throws Exception {
ReduceFunction<Integer> reducer = Integer::sum;
ReducingStateDescriptor<Integer> descriptor =
new ReducingStateDescriptor<>("testState", reducer, BasicTypeInfo.INT_TYPE_INFO);
AsyncExecutionController aec =
new AsyncExecutionController(
AsyncExecutionController<String> aec =
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
(a, b) -> {},
new ReducingStateExecutor(),
1,
100,
10000,
1);
1,
null);
AbstractReducingState<String, String, Integer> reducingState =
new AbstractReducingState<>(aec, descriptor);
aec.setCurrentContext(aec.buildContext("test", "test"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

/** Test for {@link InternalTimerServiceAsyncImpl}. */
class InternalTimerServiceAsyncImplTest {
private AsyncExecutionController asyncExecutionController;
private AsyncExecutionController<String> asyncExecutionController;
private TestKeyContext keyContext;
private TestProcessingTimeService processingTimeService;
private InternalTimerServiceAsyncImpl<Integer, String> service;
Expand All @@ -59,14 +59,15 @@ public void handleException(String message, Throwable exception) {
@BeforeEach
void setup() throws Exception {
asyncExecutionController =
new AsyncExecutionController(
new AsyncExecutionController<>(
new SyncMailboxExecutor(),
exceptionHandler,
new MockStateExecutor(),
128,
2,
1000L,
10);
10,
null);
// ensure arbitrary key is in the key group
int totalKeyGroups = 128;
KeyGroupRange testKeyGroupList = new KeyGroupRange(0, totalKeyGroups - 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ public K getCurrentKey() {
return keyedStateBackend.getCurrentKey();
}

/**
 * Forwards the key and key group index, unchanged, to the wrapped {@code keyedStateBackend}.
 *
 * @param newKey the key to make current.
 * @param newKeyGroupIndex the key group index passed along with {@code newKey}.
 */
@Override
public void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex) {
keyedStateBackend.setCurrentKeyAndKeyGroup(newKey, newKeyGroupIndex);
}

@Override
public TypeSerializer<K> getKeySerializer() {
return keyedStateBackend.getKeySerializer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ public void setCurrentKey(K newKey) {
keyedStateBackend.setCurrentKey(newKey);
}

/**
 * Forwards the key and key group index, unchanged, to the wrapped {@code keyedStateBackend}.
 *
 * @param newKey the key to make current.
 * @param newKeyGroupIndex the key group index passed along with {@code newKey}.
 */
@Override
public void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex) {
keyedStateBackend.setCurrentKeyAndKeyGroup(newKey, newKeyGroupIndex);
}

@Override
public void notifyCheckpointComplete(long checkpointId) throws Exception {
keyedStateBackend.notifyCheckpointComplete(checkpointId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
* A KeyedStateBackend that stores its state in {@code ForSt}. This state backend can store very
* large state that exceeds memory even disk to remote storage.
*/
public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend {
public class ForStKeyedStateBackend<K> implements AsyncKeyedStateBackend<K> {

private static final Logger LOG = LoggerFactory.getLogger(ForStKeyedStateBackend.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,12 @@ public void setCurrentKey(K newKey) {
sharedRocksKeyBuilder.setKeyAndKeyGroup(getCurrentKey(), getCurrentKeyGroupIndex());
}

/**
 * Sets the current key and key group through the superclass, then hands the resulting key and
 * key group to the shared key builder.
 *
 * @param newKey the key to make current.
 * @param newKeyGroupIndex the key group index passed along with {@code newKey}.
 */
@Override
public void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex) {
    super.setCurrentKeyAndKeyGroup(newKey, newKeyGroupIndex);
    // Read the values back through the getters (not the arguments), as the original does.
    final K activeKey = getCurrentKey();
    final int activeKeyGroup = getCurrentKeyGroupIndex();
    sharedRocksKeyBuilder.setKeyAndKeyGroup(activeKey, activeKeyGroup);
}

/** Should only be called by one thread, and only after all accesses to the DB happened. */
@Override
public void dispose() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,12 @@ public void setCurrentKey(K newKey) {
sharedRocksKeyBuilder.setKeyAndKeyGroup(getCurrentKey(), getCurrentKeyGroupIndex());
}

/**
 * Sets the current key and key group through the superclass, then hands the resulting key and
 * key group to the shared key builder.
 *
 * @param newKey the key to make current.
 * @param newKeyGroupIndex the key group index passed along with {@code newKey}.
 */
@Override
public void setCurrentKeyAndKeyGroup(K newKey, int newKeyGroupIndex) {
    super.setCurrentKeyAndKeyGroup(newKey, newKeyGroupIndex);
    // Read the values back through the getters (not the arguments), as the original does.
    final K activeKey = getCurrentKey();
    final int activeKeyGroup = getCurrentKeyGroupIndex();
    sharedRocksKeyBuilder.setKeyAndKeyGroup(activeKey, activeKeyGroup);
}

/** Should only be called by one thread, and only after all accesses to the DB happened. */
@Override
public void dispose() {
Expand Down

0 comments on commit b14ab76

Please sign in to comment.