Skip to content

Commit

Permalink
fix: range scan more than one onError callback (#205)
Browse files Browse the repository at this point in the history
* fix: range scan more than one onError callback

* fix spotless

* fix typo

* fix typo

* remove unused field
  • Loading branch information
mattisonchao authored Jan 26, 2025
1 parent 89d51b3 commit 6e5b0fa
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -684,13 +683,13 @@ public void rangeScan(
@NonNull Set<RangeScanOption> options) {
gaugePendingRangeScanRequests.increment();

RangeScanConsumerWithShard timedConsumer =
new RangeScanConsumerWithShard() {
final RangeScanConsumer timedConsumer =
new RangeScanConsumer() {
final long startTime = System.nanoTime();
final AtomicLong totalSize = new AtomicLong();

@Override
public void onNext(long shardId, GetResult result) {
public void onNext(GetResult result) {
totalSize.addAndGet(result.getValue().length);
consumer.onNext(result);
}
Expand All @@ -703,7 +702,7 @@ public void onError(Throwable throwable) {
}

@Override
public void onCompleted(long shardId) {
public void onCompleted() {
gaugePendingRangeScanRequests.decrement();
counterRangeScanBytes.add(totalSize.longValue());
histogramRangeScanLatency.recordSuccess(System.nanoTime() - startTime);
Expand Down Expand Up @@ -731,20 +730,12 @@ public void onCompleted(long shardId) {
}
}

interface RangeScanConsumerWithShard {
void onNext(long shardId, GetResult result);

void onError(Throwable throwable);

void onCompleted(long shardId);
}

private void internalShardRangeScan(
long shardId,
String startKeyInclusive,
String endKeyExclusive,
Optional<String> secondaryIndexName,
RangeScanConsumerWithShard consumer) {
RangeScanConsumer consumer) {
var leader = shardManager.leader(shardId);
var stub = stubManager.getStub(leader);
var requestBuilder =
Expand All @@ -763,8 +754,7 @@ private void internalShardRangeScan(
@Override
public void onNext(RangeScanResponse response) {
for (int i = 0; i < response.getRecordsCount(); i++) {
consumer.onNext(
shardId, ProtoUtil.getResultFromProto("", response.getRecords(i)));
consumer.onNext(ProtoUtil.getResultFromProto("", response.getRecords(i)));
}
}

Expand All @@ -775,7 +765,7 @@ public void onError(Throwable t) {

@Override
public void onCompleted() {
consumer.onCompleted(shardId);
consumer.onCompleted();
}
});
}
Expand All @@ -784,41 +774,60 @@ private void internalRangeScanMultiShards(
String startKeyInclusive,
String endKeyExclusive,
Optional<String> secondaryIndexName,
RangeScanConsumerWithShard consumer) {
Set<Long> shardIds = shardManager.allShardIds();
RangeScanConsumer consumer) {
final Set<Long> shardIds = shardManager.allShardIds();
final RangeScanConsumer multiShardConsumer =
new SharedRangeScanConsumer(shardIds.size(), consumer);
for (long shardId : shardIds) {
internalShardRangeScan(
shardId, startKeyInclusive, endKeyExclusive, secondaryIndexName, multiShardConsumer);
}
}

RangeScanConsumerWithShard multiShardConsumer =
new RangeScanConsumerWithShard() {
private final Set<Long> pendingShards = new HashSet<>(shardIds);
private boolean failed = false;
static class SharedRangeScanConsumer implements RangeScanConsumer {
private final RangeScanConsumer delegate;

@Override
public synchronized void onNext(long shardId, GetResult result) {
if (!failed) {
consumer.onNext(shardId, result);
}
}
private int pendingCompletedRequests;
private boolean completed = false;
private Throwable completedException = null;

@Override
public synchronized void onError(Throwable throwable) {
failed = true;
consumer.onError(throwable);
}
SharedRangeScanConsumer(int shards, RangeScanConsumer delegate) {
this.pendingCompletedRequests = shards;
this.delegate = delegate;
}

@Override
public synchronized void onCompleted(long shardId) {
if (!failed) {
pendingShards.remove(shardId);
if (pendingShards.isEmpty()) {
consumer.onCompleted(shardId);
}
}
}
};
@Override
public synchronized void onNext(GetResult result) {
if (completed) {
return;
}
delegate.onNext(result);
}

for (long shardId : shardIds) {
internalShardRangeScan(
shardId, startKeyInclusive, endKeyExclusive, secondaryIndexName, multiShardConsumer);
@Override
public synchronized void onError(Throwable throwable) {
if (completedException == null) {
completedException = throwable;
} else {
completedException.addSuppressed(throwable);
}
if (completed) {
return;
}
completed = true;
delegate.onError(throwable);
}

@Override
public synchronized void onCompleted() {
if (completed) {
return;
}
pendingCompletedRequests -= 1;
if (pendingCompletedRequests == 0) {
completed = true;
delegate.onCompleted();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.streamnative.oxia.client.api.DeleteOption;
import io.streamnative.oxia.client.api.GetResult;
import io.streamnative.oxia.client.api.PutResult;
import io.streamnative.oxia.client.api.RangeScanConsumer;
import io.streamnative.oxia.client.api.Version;
import io.streamnative.oxia.client.batch.BatchManager;
import io.streamnative.oxia.client.batch.Batcher;
Expand All @@ -49,10 +50,18 @@
import io.streamnative.oxia.proto.ListResponse;
import io.streamnative.oxia.proto.OxiaClientGrpc;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand Down Expand Up @@ -566,4 +575,118 @@ void close() throws Exception {
inOrder.verify(stubManager).close();
client = null;
}

@Test
void testShardShardRangeScanConsumer() {
final int shards = 5;
final List<GetResult> results = new ArrayList<>();
final AtomicInteger onErrorCount = new AtomicInteger(0);
final AtomicInteger onCompletedCount = new AtomicInteger(0);
final Supplier<RangeScanConsumer> newShardRangeScanConsumer =
() ->
new AsyncOxiaClientImpl.SharedRangeScanConsumer(
5,
new RangeScanConsumer() {
@Override
public void onNext(GetResult result) {
results.add(result);
}

@Override
public void onError(Throwable throwable) {
onErrorCount.incrementAndGet();
}

@Override
public void onCompleted() {
onCompletedCount.incrementAndGet();
}
});
final var tasks = new ArrayList<ForkJoinTask<?>>();

// (1) complete ok
final var shardRangeScanConsumer1 = newShardRangeScanConsumer.get();
for (int i = 0; i < shards; i++) {
final int fi = i;
final ForkJoinTask<?> task =
ForkJoinPool.commonPool()
.submit(
() -> {
shardRangeScanConsumer1.onNext(
new GetResult(
"shard-" + fi + "-0",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer1.onNext(
new GetResult(
"shard-" + fi + "-1",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer1.onCompleted();
});
tasks.add(task);
}
tasks.forEach(ForkJoinTask::join);
var keys = results.stream().map(GetResult::getKey).toList();
for (int i = 0; i < shards; i++) {
Assertions.assertTrue(keys.contains("shard-" + i + "-0"));
Assertions.assertTrue(keys.contains("shard-" + i + "-1"));
}
Assertions.assertEquals(0, onErrorCount.get());
Assertions.assertEquals(1, onCompletedCount.get());

tasks.clear();
onErrorCount.set(0);
onCompletedCount.set(0);
results.clear();

// (2) complete partial exception
final var shardRangeScanConsumer2 = newShardRangeScanConsumer.get();
for (int i = 0; i < shards; i++) {
final int fi = i;
final ForkJoinTask<?> task =
ForkJoinPool.commonPool()
.submit(
() -> {
if (fi % 2 == 0) {
shardRangeScanConsumer2.onError(new IllegalStateException());
return;
}
shardRangeScanConsumer2.onNext(
new GetResult(
"shard-" + fi + "-0",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer2.onNext(
new GetResult(
"shard-" + fi + "-1",
new byte[10],
new Version(1, 2, 3, 4, empty(), empty())));
shardRangeScanConsumer2.onCompleted();
});
tasks.add(task);
}
tasks.forEach(ForkJoinTask::join);

Assertions.assertEquals(1, onErrorCount.get());
Assertions.assertEquals(0, onCompletedCount.get());

tasks.clear();
onErrorCount.set(0);
onCompletedCount.set(0);
results.clear();

// (3) complete all exception
final var shardRangeScanConsumer3 = newShardRangeScanConsumer.get();
for (int i = 0; i < shards; i++) {
final ForkJoinTask<?> task =
ForkJoinPool.commonPool()
.submit(() -> shardRangeScanConsumer3.onError(new IllegalStateException()));
tasks.add(task);
}
tasks.forEach(ForkJoinTask::join);
Assertions.assertEquals(1, onErrorCount.get());
Assertions.assertEquals(0, onCompletedCount.get());
Assertions.assertEquals(0, results.size());
}
}

0 comments on commit 6e5b0fa

Please sign in to comment.