diff --git a/src/main/java/io/weaviate/client/base/util/Futures.java b/src/main/java/io/weaviate/client/base/util/Futures.java new file mode 100644 index 00000000..42806192 --- /dev/null +++ b/src/main/java/io/weaviate/client/base/util/Futures.java @@ -0,0 +1,50 @@ +package io.weaviate.client.base.util; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; + +public class Futures { + + private Futures() { + } + + public static CompletableFuture supplyDelayed(Supplier> supplier, long millis, + Executor executor) throws InterruptedException { + if (executor instanceof ScheduledExecutorService) { + return CompletableFuture.supplyAsync( + supplier, + command -> ((ScheduledExecutorService) executor).schedule(command, millis, TimeUnit.MILLISECONDS) + ).thenCompose(f -> f); + } + Thread.sleep(millis); + return supplier.get(); + } + + public static CompletableFuture thenComposeAsync(CompletableFuture future, Function> callback, + Executor executor) { + if (executor != null) { + return future.thenComposeAsync(callback, executor); + } + return future.thenComposeAsync(callback); + } + + public static CompletableFuture handleAsync(CompletableFuture future, BiFunction> callback, + Executor executor) { + if (executor != null) { + return future.handleAsync(callback, executor).thenCompose(f -> f); + } + return future.handleAsync(callback).thenCompose(f -> f); + } + + public static CompletableFuture supplyAsync(Supplier supplier, Executor executor) { + if (executor != null) { + return CompletableFuture.supplyAsync(supplier, executor); + } + return CompletableFuture.supplyAsync(supplier); + } +} diff --git a/src/main/java/io/weaviate/client/v1/async/backup/Backup.java b/src/main/java/io/weaviate/client/v1/async/backup/Backup.java index dbad121f..3a76b358 100644 --- a/src/main/java/io/weaviate/client/v1/async/backup/Backup.java +++ b/src/main/java/io/weaviate/client/v1/async/backup/Backup.java @@ -10,6 +10,8 @@ import lombok.RequiredArgsConstructor; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import java.util.concurrent.Executor; + @RequiredArgsConstructor public class Backup { @@ -18,7 +20,11 @@ public class Backup { public BackupCreator creator() { - return new BackupCreator(client, config, createStatusGetter()); + return creator(null); + } + + public BackupCreator creator(Executor executor) { + return new BackupCreator(client, config, createStatusGetter(), executor); } public BackupCreateStatusGetter createStatusGetter() { @@ -26,7 +32,11 @@ public BackupCreateStatusGetter createStatusGetter() { } public BackupRestorer restorer() { - return new BackupRestorer(client, config, restoreStatusGetter()); + return restorer(null); + } + + public BackupRestorer restorer(Executor executor) { + return new BackupRestorer(client, config, restoreStatusGetter(), executor); } public BackupRestoreStatusGetter restoreStatusGetter() { diff --git a/src/main/java/io/weaviate/client/v1/async/backup/api/BackupCreator.java b/src/main/java/io/weaviate/client/v1/async/backup/api/BackupCreator.java index 2c7f5716..50fbe032 100644 --- a/src/main/java/io/weaviate/client/v1/async/backup/api/BackupCreator.java +++ b/src/main/java/io/weaviate/client/v1/async/backup/api/BackupCreator.java @@ -8,6 +8,7 @@ import io.weaviate.client.base.WeaviateError; import io.weaviate.client.base.WeaviateErrorMessage; import io.weaviate.client.base.WeaviateErrorResponse; +import io.weaviate.client.base.util.Futures; import io.weaviate.client.base.util.UrlEncoder; import io.weaviate.client.v1.backup.model.BackupCreateResponse; import io.weaviate.client.v1.backup.model.BackupCreateStatusResponse; @@ -22,6 +23,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.Executor; import java.util.concurrent.Future; public class BackupCreator extends AsyncBaseClient @@ -36,11 +38,13 @@ public class BackupCreator extends AsyncBaseClient private String backupId; private BackupCreateConfig config; private boolean waitForCompletion; + private final Executor executor; - public BackupCreator(CloseableHttpAsyncClient client, Config config, BackupCreateStatusGetter statusGetter) { + public BackupCreator(CloseableHttpAsyncClient client, Config config, BackupCreateStatusGetter statusGetter, Executor executor) { super(client, config); this.statusGetter = statusGetter; + this.executor = executor; } @@ -158,7 +162,7 @@ public void cancelled() { private CompletableFuture> getStatusRecursively(String backend, String backupId, Result createResult) { - return getStatus(backend, backupId).thenCompose(createStatusResult -> { + return Futures.thenComposeAsync(getStatus(backend, backupId), createStatusResult -> { boolean isRunning = Optional.of(createStatusResult) .filter(r -> !r.hasErrors()) .map(Result::getResult) @@ -176,14 +180,13 @@ private CompletableFuture> getStatusRecursively(Str if (isRunning) { try { - Thread.sleep(WAIT_INTERVAL); - return getStatusRecursively(backend, backupId, createResult); + return Futures.supplyDelayed(() -> getStatusRecursively(backend, backupId, createResult), WAIT_INTERVAL, executor); } catch (InterruptedException e) { throw new CompletionException(e); } } return CompletableFuture.completedFuture(merge(createStatusResult, createResult)); - }); + }, executor); } private Result merge(Result createStatusResult, diff --git a/src/main/java/io/weaviate/client/v1/async/backup/api/BackupRestorer.java b/src/main/java/io/weaviate/client/v1/async/backup/api/BackupRestorer.java index 9a416274..b10da038 100644 --- a/src/main/java/io/weaviate/client/v1/async/backup/api/BackupRestorer.java +++ b/src/main/java/io/weaviate/client/v1/async/backup/api/BackupRestorer.java @@ -8,6 +8,7 @@ import io.weaviate.client.base.WeaviateError; import io.weaviate.client.base.WeaviateErrorMessage; import io.weaviate.client.base.WeaviateErrorResponse; +import io.weaviate.client.base.util.Futures; import io.weaviate.client.base.util.UrlEncoder; import io.weaviate.client.v1.backup.model.BackupRestoreResponse; import io.weaviate.client.v1.backup.model.BackupRestoreStatusResponse; @@ -22,6 +23,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.Executor; import java.util.concurrent.Future; public class BackupRestorer extends AsyncBaseClient @@ -36,11 +38,13 @@ public class BackupRestorer extends AsyncBaseClient private String backupId; private BackupRestoreConfig config; private boolean waitForCompletion; + private final Executor executor; - public BackupRestorer(CloseableHttpAsyncClient client, Config config, BackupRestoreStatusGetter statusGetter) { + public BackupRestorer(CloseableHttpAsyncClient client, Config config, BackupRestoreStatusGetter statusGetter, Executor executor) { super(client, config); this.statusGetter = statusGetter; + this.executor = executor; } @@ -158,7 +162,7 @@ public void cancelled() { private CompletableFuture> getStatusRecursively(String backend, String backupId, Result restoreResult) { - return getStatus(backend, backupId).thenCompose(restoreStatusResult -> { + return Futures.thenComposeAsync(getStatus(backend, backupId), restoreStatusResult -> { boolean isRunning = Optional.of(restoreStatusResult) .filter(r -> !r.hasErrors()) .map(Result::getResult) @@ -176,14 +180,13 @@ private CompletableFuture> getStatusRecursively(St if (isRunning) { try { - Thread.sleep(WAIT_INTERVAL); - return getStatusRecursively(backend, backupId, restoreResult); + return Futures.supplyDelayed(() -> getStatusRecursively(backend, backupId, restoreResult), WAIT_INTERVAL, executor); } catch (InterruptedException e) { throw new CompletionException(e); } } return CompletableFuture.completedFuture(merge(restoreStatusResult, restoreResult)); - }); + }, executor); } private Result merge(Result restoreStatusResult, diff --git a/src/main/java/io/weaviate/client/v1/async/batch/Batch.java b/src/main/java/io/weaviate/client/v1/async/batch/Batch.java index 9b9dc959..3ee8314a 100644 --- a/src/main/java/io/weaviate/client/v1/async/batch/Batch.java +++ b/src/main/java/io/weaviate/client/v1/async/batch/Batch.java @@ -14,6 +14,8 @@ import io.weaviate.client.v1.batch.util.ReferencesPath; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import java.util.concurrent.Executor; + public class Batch { private final CloseableHttpAsyncClient client; private final Config config; @@ -25,7 +27,7 @@ public class Batch { private final AccessTokenProvider tokenProvider; public Batch(CloseableHttpAsyncClient client, Config config, DbVersionSupport dbVersionSupport, - GrpcVersionSupport grpcVersionSupport, AccessTokenProvider tokenProvider, Data data) { + GrpcVersionSupport grpcVersionSupport, AccessTokenProvider tokenProvider, Data data) { this.client = client; this.config = config; this.objectsPath = new ObjectsPath(); @@ -37,37 +39,86 @@ public Batch(CloseableHttpAsyncClient client, Config config, DbVersionSupport db } public ObjectsBatcher objectsBatcher() { - return objectsBatcher(ObjectsBatcher.BatchRetriesConfig.defaultConfig().build()); + return objectsBatcher(ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(), null); + } + + public ObjectsBatcher objectsBatcher(Executor executor) { + return objectsBatcher(ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(), executor); } public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) { - return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig); + return objectsBatcher(batchRetriesConfig, null); + } + + public ObjectsBatcher objectsBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, Executor executor) { + return ObjectsBatcher.create(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, + batchRetriesConfig, executor); } public ObjectsBatcher objectsAutoBatcher() { return objectsAutoBatcher( ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(), - ObjectsBatcher.AutoBatchConfig.defaultConfig().build() + ObjectsBatcher.AutoBatchConfig.defaultConfig().build(), + null + ); + } + + public ObjectsBatcher objectsAutoBatcher(Executor executor) { + return objectsAutoBatcher( + ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(), + ObjectsBatcher.AutoBatchConfig.defaultConfig().build(), + executor ); } public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) { return objectsAutoBatcher( batchRetriesConfig, - ObjectsBatcher.AutoBatchConfig.defaultConfig().build() + ObjectsBatcher.AutoBatchConfig.defaultConfig().build(), + null + ); + } + + public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + Executor executor) { + return objectsAutoBatcher( + batchRetriesConfig, + ObjectsBatcher.AutoBatchConfig.defaultConfig().build(), + executor ); } public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatchConfig) { return objectsAutoBatcher( ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(), - autoBatchConfig + autoBatchConfig, + null + ); + } + + public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.AutoBatchConfig autoBatchConfig, + Executor executor) { + return objectsAutoBatcher( + ObjectsBatcher.BatchRetriesConfig.defaultConfig().build(), + autoBatchConfig, + executor + ); + } + + public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + ObjectsBatcher.AutoBatchConfig autoBatchConfig) { + return objectsAutoBatcher( + batchRetriesConfig, + autoBatchConfig, + null ); } public ObjectsBatcher objectsAutoBatcher(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, - ObjectsBatcher.AutoBatchConfig autoBatchConfig) { - return ObjectsBatcher.createAuto(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, autoBatchConfig); + ObjectsBatcher.AutoBatchConfig autoBatchConfig, + Executor executor) { + return ObjectsBatcher.createAuto(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, + batchRetriesConfig, autoBatchConfig, executor); } public ObjectsBatchDeleter objectsBatchDeleter() { @@ -78,38 +129,85 @@ public ReferencePayloadBuilder referencePayloadBuilder() { return new ReferencePayloadBuilder(beaconPath); } - // TODO: implement async ReferencesBatcher public ReferencesBatcher referencesBatcher() { - return referencesBatcher(ReferencesBatcher.BatchRetriesConfig.defaultConfig().build()); + return referencesBatcher(ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(), null); + } + + public ReferencesBatcher referencesBatcher(Executor executor) { + return referencesBatcher(ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(), executor); } public ReferencesBatcher referencesBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig) { - return ReferencesBatcher.create(client, config, referencesPath, batchRetriesConfig); + return referencesBatcher(batchRetriesConfig, null); + } + + public ReferencesBatcher referencesBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig, + Executor executor) { + return ReferencesBatcher.create(client, config, referencesPath, batchRetriesConfig, executor); } public ReferencesBatcher referencesAutoBatcher() { return referencesAutoBatcher( ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(), - ReferencesBatcher.AutoBatchConfig.defaultConfig().build() + ReferencesBatcher.AutoBatchConfig.defaultConfig().build(), + null + ); + } + + public ReferencesBatcher referencesAutoBatcher(Executor executor) { + return referencesAutoBatcher( + ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(), + ReferencesBatcher.AutoBatchConfig.defaultConfig().build(), + executor ); } public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig) { return referencesAutoBatcher( batchRetriesConfig, - ReferencesBatcher.AutoBatchConfig.defaultConfig().build() + ReferencesBatcher.AutoBatchConfig.defaultConfig().build(), + null + ); + } + + public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig, + Executor executor) { + return referencesAutoBatcher( + batchRetriesConfig, + ReferencesBatcher.AutoBatchConfig.defaultConfig().build(), + executor ); } public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.AutoBatchConfig autoBatchConfig) { return referencesAutoBatcher( ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(), - autoBatchConfig + autoBatchConfig, + null + ); + } + + public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.AutoBatchConfig autoBatchConfig, + Executor executor) { + return referencesAutoBatcher( + ReferencesBatcher.BatchRetriesConfig.defaultConfig().build(), + autoBatchConfig, + executor + ); + } + + public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig, + ReferencesBatcher.AutoBatchConfig autoBatchConfig) { + return referencesAutoBatcher( + batchRetriesConfig, + autoBatchConfig, + null ); } public ReferencesBatcher referencesAutoBatcher(ReferencesBatcher.BatchRetriesConfig batchRetriesConfig, - ReferencesBatcher.AutoBatchConfig autoBatchConfig) { - return ReferencesBatcher.createAuto(client, config, referencesPath, batchRetriesConfig, autoBatchConfig); + ReferencesBatcher.AutoBatchConfig autoBatchConfig, + Executor executor) { + return ReferencesBatcher.createAuto(client, config, referencesPath, batchRetriesConfig, autoBatchConfig, executor); } } diff --git a/src/main/java/io/weaviate/client/v1/async/batch/api/ObjectsBatcher.java b/src/main/java/io/weaviate/client/v1/async/batch/api/ObjectsBatcher.java index 74209782..6965ad07 100644 --- a/src/main/java/io/weaviate/client/v1/async/batch/api/ObjectsBatcher.java +++ b/src/main/java/io/weaviate/client/v1/async/batch/api/ObjectsBatcher.java @@ -1,18 +1,19 @@ package io.weaviate.client.v1.async.batch.api; -import com.google.common.util.concurrent.ListenableFuture; import io.weaviate.client.Config; import io.weaviate.client.base.AsyncBaseClient; import io.weaviate.client.base.AsyncClientResult; -import io.weaviate.client.base.Response; import io.weaviate.client.base.Result; +import io.weaviate.client.base.WeaviateError; import io.weaviate.client.base.WeaviateErrorMessage; import io.weaviate.client.base.WeaviateErrorResponse; import io.weaviate.client.base.grpc.AsyncGrpcClient; import io.weaviate.client.base.util.Assert; +import io.weaviate.client.base.util.Futures; import io.weaviate.client.base.util.GrpcVersionSupport; import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBase; import io.weaviate.client.grpc.protocol.v1.WeaviateProtoBatch; +import io.weaviate.client.v1.async.data.Data; import io.weaviate.client.v1.auth.provider.AccessTokenProvider; import io.weaviate.client.v1.batch.grpc.BatchObjectConverter; import io.weaviate.client.v1.batch.model.ObjectGetResponse; @@ -20,10 +21,22 @@ import io.weaviate.client.v1.batch.model.ObjectsBatchRequestBody; import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result; import io.weaviate.client.v1.batch.util.ObjectsPath; -import io.weaviate.client.v1.async.data.Data; import io.weaviate.client.v1.data.model.WeaviateObject; import io.weaviate.client.v1.data.replication.model.ConsistencyLevel; -import java.io.Closeable; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import lombok.experimental.FieldDefaults; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.ObjectUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import org.apache.hc.core5.concurrent.FutureCallback; +import org.apache.hc.core5.http.HttpStatus; + import java.net.ConnectException; import java.net.SocketTimeoutException; import java.util.ArrayList; @@ -32,94 +45,77 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; -import java.util.concurrent.Executors; import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.function.Consumer; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import lombok.ToString; -import lombok.experimental.FieldDefaults; -import org.apache.commons.lang3.ArrayUtils; -import org.apache.commons.lang3.ObjectUtils; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; -import org.apache.hc.core5.concurrent.FutureCallback; public class ObjectsBatcher extends AsyncBaseClient - implements AsyncClientResult, Closeable { + implements AsyncClientResult { private final Data data; private final ObjectsPath objectsPath; + private final AccessTokenProvider tokenProvider; + private final GrpcVersionSupport grpcVersionSupport; private final ObjectsBatcher.BatchRetriesConfig batchRetriesConfig; private final ObjectsBatcher.AutoBatchConfig autoBatchConfig; + private final Config config; private final boolean autoRunEnabled; - private final ScheduledExecutorService executorService; - private final ObjectsBatcher.DelayedExecutor delayedExecutor; + private final Executor executor; + private final List>> futures; + private final List objects; private String consistencyLevel; - private final List>> undoneFutures; - private final boolean useGRPC; - private final AccessTokenProvider tokenProvider; - private final GrpcVersionSupport grpcVersionSupport; - private final Config config; private ObjectsBatcher(CloseableHttpAsyncClient client, Config config, Data data, ObjectsPath objectsPath, - AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport, - ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, ObjectsBatcher.AutoBatchConfig autoBatchConfig) { + AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport, + ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, ObjectsBatcher.AutoBatchConfig autoBatchConfig, + Executor executor) { super(client, config); this.config = config; - this.useGRPC = config.useGRPC(); this.tokenProvider = tokenProvider; this.data = data; this.objectsPath = objectsPath; this.grpcVersionSupport = grpcVersionSupport; - this.objects = new ArrayList<>(); this.batchRetriesConfig = batchRetriesConfig; + this.objects = Collections.synchronizedList(new ArrayList<>()); + this.futures = Collections.synchronizedList(new ArrayList<>()); + this.executor = executor; if (autoBatchConfig != null) { this.autoRunEnabled = true; this.autoBatchConfig = autoBatchConfig; - this.executorService = Executors.newScheduledThreadPool(autoBatchConfig.poolSize); - this.delayedExecutor = new ObjectsBatcher.ExecutorServiceDelayedExecutor(executorService); - this.undoneFutures = Collections.synchronizedList(new ArrayList<>()); } else { this.autoRunEnabled = false; this.autoBatchConfig = null; - this.executorService = null; - this.delayedExecutor = new ObjectsBatcher.SleepDelayedExecutor(); - this.undoneFutures = null; } } public static ObjectsBatcher create(CloseableHttpAsyncClient client, Config config, Data data, ObjectsPath objectsPath, - AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport, - ObjectsBatcher.BatchRetriesConfig batchRetriesConfig) { + AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport, + ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + Executor executor) { Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig"); - return new ObjectsBatcher(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, null); + return new ObjectsBatcher(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, + batchRetriesConfig, null, executor); } public static ObjectsBatcher createAuto(CloseableHttpAsyncClient client, Config config, Data data, ObjectsPath objectsPath, - AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport, - ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, ObjectsBatcher.AutoBatchConfig autoBatchConfig) { + AccessTokenProvider tokenProvider, GrpcVersionSupport grpcVersionSupport, + ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, ObjectsBatcher.AutoBatchConfig autoBatchConfig, + Executor executor) { Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig"); Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig"); - return new ObjectsBatcher(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, batchRetriesConfig, autoBatchConfig); + return new ObjectsBatcher(client, config, data, objectsPath, tokenProvider, grpcVersionSupport, + batchRetriesConfig, autoBatchConfig, executor); } @@ -139,71 +135,6 @@ public ObjectsBatcher withConsistencyLevel(String consistencyLevel) { return this; } - public Result runBatch() { - if (autoRunEnabled) { - flush(); // fallback to flush in auto run enabled - return null; - } - - if (objects.isEmpty()) { - return new Result<>(0, new ObjectGetResponse[0], null); - } - - List batch = extractBatch(objects.size()); - return runRecursively(batch, 0, 0, null, - (ObjectsBatcher.DelayedExecutor>) delayedExecutor); - } - - @Override - public Future> run(FutureCallback> callback) { - CompletableFuture> result = CompletableFuture.supplyAsync(() -> runBatch()); - if (callback != null) { - return result.whenComplete((res, e) -> { - callback.completed(res); - if (e != null) { - callback.failed(new Exception(e)); - } - }); - } - return result; - } - - public void flush() { - if (!autoRunEnabled) { - run(); // fallback to run if auto run disabled - return; - } - - if (!objects.isEmpty()) { - List batch = extractBatch(objects.size()); - runInThread(batch); - } - - CompletableFuture[] futures = undoneFutures.toArray(new CompletableFuture[0]); - if (futures.length == 0) { - return; - } - - CompletableFuture.allOf(futures).join(); - } - - @Override - public void close() { - if (!autoRunEnabled) { - return; - } - - executorService.shutdown(); - try { - if (!executorService.awaitTermination(autoBatchConfig.awaitTerminationMs, TimeUnit.MILLISECONDS)) { - executorService.shutdownNow(); - } - } catch (InterruptedException e) { - executorService.shutdownNow(); - } - } - - private void addMissingIds(WeaviateObject[] objects) { Arrays.stream(objects) .filter(o -> o.getId() == null) @@ -227,178 +158,252 @@ private void autoRun() { while (objects.size() >= autoBatchConfig.batchSize) { List batch = extractBatch(autoBatchConfig.batchSize); - runInThread(batch); + runBatch(batch); } } - private void runInThread(List batch) { - CompletableFuture> future = CompletableFuture.supplyAsync( - () -> createRunFuture(batch), - executorService - ).thenCompose(f -> f); + @Override + public Future> run(FutureCallback> callback) { + CompletableFuture> future = runAll(); + if (callback != null) { + future = future.whenComplete((result, throwable) -> { + if (throwable != null) { + callback.failed((Exception) throwable); + } else { + callback.completed(result); + } + }); + } + return future; + } - if (autoBatchConfig.callback != null) { - future = future.whenComplete((result, e) -> autoBatchConfig.callback.accept(result)); + private CompletableFuture> runAll() { + if (!autoRunEnabled) { + if (objects.isEmpty()) { + return CompletableFuture.completedFuture(new Result<>(0, new ObjectGetResponse[0], null)); + } + + List batch = extractBatch(objects.size()); + return runBatchRecursively(batch, 0, 0, null); } - CompletableFuture> undoneFuture = future; - undoneFutures.add(undoneFuture); - undoneFuture.whenComplete((result, ex) -> undoneFutures.remove(undoneFuture)); - } + if (!objects.isEmpty()) { + List batch = extractBatch(objects.size()); + runBatch(batch); + } + if (futures.isEmpty()) { + return CompletableFuture.completedFuture(new Result<>(0, new ObjectGetResponse[0], null)); + } - private CompletableFuture> createRunFuture(List batch) { - return runRecursively(batch, 0, 0, null, - (ObjectsBatcher.DelayedExecutor>>) delayedExecutor); - } + CompletableFuture[] futuresAsArray = futures.toArray(new CompletableFuture[0]); + return CompletableFuture.allOf(futuresAsArray).thenApply(v -> { + List allResponses = new ArrayList<>(); + List allMessages = new ArrayList<>(); + int[] lastErrStatusCode = new int[]{HttpStatus.SC_OK}; - private T runRecursively(List batchF, Integer connectionErrorCountF, int timeoutErrorCountF, - List combinedSingleResponsesF, ObjectsBatcher.DelayedExecutor delayedExecutor) { - Future> resultFuture = useGRPC ? internalGrpcRun(batchF, null) : internalRun(batchF); + futures.stream().map(resultCompletableFuture -> { + try { + return resultCompletableFuture.get(); + } catch (InterruptedException | ExecutionException e) { + throw new CompletionException(e); + } + }).forEach(result -> { + Optional.ofNullable(result) + .map(Result::getResult) + .map(Arrays::asList) + .ifPresent(allResponses::addAll); + Optional.ofNullable(result) + .filter(Result::hasErrors) + .map(Result::getError) + .map(WeaviateError::getMessages) + .ifPresent(allMessages::addAll); + Optional.ofNullable(result) + .filter(Result::hasErrors) + .map(Result::getError) + .map(WeaviateError::getStatusCode) + .ifPresent(sc -> lastErrStatusCode[0] = sc); + }); - CompletableFuture runRecursivelyFuture = CompletableFuture.supplyAsync(() -> { - try { - return resultFuture.get(); - } catch (InterruptedException | ExecutionException e) { - throw new CompletionException(e); - } - }).thenApplyAsync(result -> { - List batch = batchF; - Integer connectionErrorCount = connectionErrorCountF; - int timeoutErrorCount = timeoutErrorCountF; - List combinedSingleResponses = combinedSingleResponsesF; - if (result.hasErrors()) { - List messages = result.getError().getMessages(); - if (!messages.isEmpty()) { - Throwable throwable = messages.get(0).getThrowable(); - boolean executeAgain = false; - int delay = 0; - - if (throwable instanceof ConnectException) { - if (connectionErrorCount++ < batchRetriesConfig.maxConnectionRetries) { - executeAgain = true; - delay = connectionErrorCount * batchRetriesConfig.retriesIntervalMs; - } - } else if (throwable instanceof SocketTimeoutException) { - Pair, List> pair = fetchCreatedAndBuildBatchToReRun(batch); - combinedSingleResponses = combineSingleResponses(combinedSingleResponses, pair.getLeft()); - batch = pair.getRight(); - - if (ObjectUtils.isNotEmpty(batch) && timeoutErrorCount++ < batchRetriesConfig.maxTimeoutRetries) { - executeAgain = true; - delay = timeoutErrorCount * batchRetriesConfig.retriesIntervalMs; - } - } + WeaviateErrorResponse errorResponse = allMessages.isEmpty() + ? null + : WeaviateErrorResponse.builder().error(allMessages).code(lastErrStatusCode[0]).build(); + return new Result<>(lastErrStatusCode[0], allResponses.toArray(new ObjectGetResponse[0]), errorResponse); + }); + } - if (executeAgain) { - int lambdaConnectionErrorCount = connectionErrorCount; - int lambdaTimeoutErrorCount = timeoutErrorCount; - List lambdaBatch = batch; - List lambdaCombinedSingleResponses = combinedSingleResponses; + private void runBatch(List batch) { + CompletableFuture> future = runBatchRecursively(batch, 0, 0, null); + if (autoBatchConfig.callback != null) { + future = future.whenComplete((result, t) -> autoBatchConfig.callback.accept(result)); + } + futures.add(future); + } - return delayedExecutor.delayed( - delay, - () -> runRecursively(lambdaBatch, lambdaConnectionErrorCount, lambdaTimeoutErrorCount, lambdaCombinedSingleResponses, delayedExecutor) - ); + private CompletableFuture> runBatchRecursively(List batch, + int connectionErrorCount, int timeoutErrorCount, + List combinedSingleResponses) { + return Futures.handleAsync(internalRun(batch), (result, throwable) -> { + List tempCombinedSingleResponses = combinedSingleResponses; + List tempBatch = batch; + + if (throwable != null) { + boolean executeAgain = false; + int tempConnCount = connectionErrorCount; + int tempTimeCount = timeoutErrorCount; + int delay = 0; + + if (throwable instanceof ConnectException) { + if (tempConnCount++ < batchRetriesConfig.maxConnectionRetries) { + executeAgain = true; + delay = tempConnCount * batchRetriesConfig.retriesIntervalMs; + } + } else if (throwable instanceof SocketTimeoutException) { + Pair, List> pair = fetchCreatedAndBuildBatchToReRun(tempBatch); + tempCombinedSingleResponses = combineSingleResponses(tempCombinedSingleResponses, pair.getLeft()); + tempBatch = pair.getRight(); + + if (ObjectUtils.isNotEmpty(tempBatch) && tempTimeCount++ < batchRetriesConfig.maxTimeoutRetries) { + executeAgain = true; + delay = tempTimeCount * batchRetriesConfig.retriesIntervalMs; + } + } + if (executeAgain) { + try { + List finalCombinedSingleResponses = tempCombinedSingleResponses; + List finalBatch = tempBatch; + int connCount = tempConnCount; + int timeCount = tempTimeCount; + return Futures.supplyDelayed(() -> runBatchRecursively(finalBatch, connCount, timeCount, finalCombinedSingleResponses), delay, executor); + } catch (InterruptedException e) { + throw new CompletionException(e); } } - } else { - batch = null; + } else if (!result.hasErrors()) { + tempBatch = null; } - Result finalResult = createFinalResultFromLastResultAndCombinedSingleResponses(result, combinedSingleResponses, batch); - return delayedExecutor.now(finalResult); - }); - - return runRecursivelyFuture.join(); + return CompletableFuture.completedFuture(createFinalResultFromLastResultAndCombinedSingleResponses(result, + throwable, tempCombinedSingleResponses, tempBatch)); + }, executor); } - private Future> internalRun(List batch) { - ObjectsBatchRequestBody batchRequest = ObjectsBatchRequestBody.builder() - .objects(batch.toArray(new WeaviateObject[0])) - .fields(new String[]{"ALL"}) - .build(); - String path = objectsPath.buildCreate(ObjectsPath.Params.builder() - .consistencyLevel(consistencyLevel) - .build()); - return sendPostRequest(path, batchRequest, ObjectGetResponse[].class, null); + private CompletableFuture> internalRun(List batch) { + return config.useGRPC() ? internalGrpcRun(batch) : internalHttpRun(batch); } - private Future> internalGrpcRun(List batch, FutureCallback> callback) { + private CompletableFuture> internalGrpcRun(List batch) { BatchObjectConverter batchObjectConverter = new BatchObjectConverter(grpcVersionSupport); List batchObjects = batch.stream() .map(batchObjectConverter::toBatchObject) .collect(Collectors.toList()); + WeaviateProtoBatch.BatchObjectsRequest.Builder batchObjectsRequestBuilder = WeaviateProtoBatch.BatchObjectsRequest.newBuilder(); batchObjectsRequestBuilder.addAllObjects(batchObjects); - if (consistencyLevel != null) { - WeaviateProtoBase.ConsistencyLevel cl = WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE; - if (consistencyLevel.equals(ConsistencyLevel.ALL)) { - cl = WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ALL; - } - if (consistencyLevel.equals(ConsistencyLevel.QUORUM)) { - cl = WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_QUORUM; - } - batchObjectsRequestBuilder.setConsistencyLevel(cl); - } + Optional.ofNullable(consistencyLevel) + .map(cl -> { + switch (cl) { + case ConsistencyLevel.ALL: + return WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ALL; + case ConsistencyLevel.QUORUM: + return WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_QUORUM; + default: + return WeaviateProtoBase.ConsistencyLevel.CONSISTENCY_LEVEL_ONE; + } + }).ifPresent(batchObjectsRequestBuilder::setConsistencyLevel); WeaviateProtoBatch.BatchObjectsRequest batchObjectsRequest = batchObjectsRequestBuilder.build(); - CompletableFuture batchObjectsReplyFuture = CompletableFuture.supplyAsync(() -> { - AsyncGrpcClient grpcClient = AsyncGrpcClient.create(this.config, this.tokenProvider); - try { - return grpcClient.batchObjects(batchObjectsRequest).get(); - } catch (InterruptedException | ExecutionException e) { - throw new CompletionException(e); - } finally { - grpcClient.shutdown(); - } - }); + // TODO convert ListenableFuture into CompletableFuture? + return Futures.supplyAsync(() -> { + AsyncGrpcClient grpcClient = AsyncGrpcClient.create(config, tokenProvider); + try { + return grpcClient.batchObjects(batchObjectsRequest).get(); + } catch (InterruptedException | ExecutionException e) { + throw new CompletionException(e); + } finally { + grpcClient.shutdown(); + } + }, executor) + .thenApply(batchObjectsReply -> { + List weaviateErrorMessages = batchObjectsReply.getErrorsList().stream() + .map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError) + .filter(e -> !e.isEmpty()) + .map(msg -> WeaviateErrorMessage.builder().message(msg).build()) + .collect(Collectors.toList()); + + if (!weaviateErrorMessages.isEmpty()) { + int statusCode = HttpStatus.SC_UNPROCESSABLE_CONTENT; + WeaviateErrorResponse weaviateErrorResponse = WeaviateErrorResponse.builder() + .code(statusCode) + .message(StringUtils.join(weaviateErrorMessages, ",")) + .error(weaviateErrorMessages) + .build(); + return new Result<>(statusCode, null, weaviateErrorResponse); + } - CompletableFuture> resultFuture = batchObjectsReplyFuture.thenApplyAsync(batchObjectsReply -> { - List weaviateErrorMessages = batchObjectsReply.getErrorsList().stream() - .map(WeaviateProtoBatch.BatchObjectsReply.BatchError::getError) - .filter(e -> !e.isEmpty()) - .map(msg -> WeaviateErrorMessage.builder().message(msg).build()) - .collect(Collectors.toList()); + ObjectGetResponse[] objectGetResponses = batch.stream().map(o -> { + ObjectsGetResponseAO2Result result = new ObjectsGetResponseAO2Result(); + result.setStatus(ObjectGetResponseStatus.SUCCESS); - if (!weaviateErrorMessages.isEmpty()) { - WeaviateErrorResponse weaviateErrorResponse = WeaviateErrorResponse.builder() - .code(422).message(StringUtils.join(weaviateErrorMessages, ",")).error(weaviateErrorMessages).build(); - return new Result<>(422, null, weaviateErrorResponse); + ObjectGetResponse resp = new ObjectGetResponse(); + resp.setId(o.getId()); + resp.setClassName(o.getClassName()); + resp.setTenant(o.getTenant()); + resp.setResult(result); + return resp; + }).toArray(ObjectGetResponse[]::new); + + return new Result<>(HttpStatus.SC_OK, objectGetResponses, null); + }); + } + + private CompletableFuture> internalHttpRun(List batch) { + CompletableFuture> future = new CompletableFuture<>(); + ObjectsBatchRequestBody payload = ObjectsBatchRequestBody.builder() + .objects(batch.toArray(new WeaviateObject[0])) + .fields(new String[]{"ALL"}) + .build(); + String path = objectsPath.buildCreate(ObjectsPath.Params.builder() + .consistencyLevel(consistencyLevel) + .build()); + sendPostRequest(path, payload, ObjectGetResponse[].class, new FutureCallback>() { + @Override + public void completed(Result batchResult) { + future.complete(batchResult); } - ObjectGetResponse[] objectGetResponses = batch.stream().map(o -> { - ObjectGetResponse resp = new ObjectGetResponse(); - resp.setId(o.getId()); - resp.setClassName(o.getClassName()); - resp.setTenant(o.getTenant()); - ObjectsGetResponseAO2Result result = new ObjectsGetResponseAO2Result(); - result.setStatus(ObjectGetResponseStatus.SUCCESS); - resp.setResult(result); - return resp; - }).toArray(ObjectGetResponse[]::new); - - return new Result<>(200, objectGetResponses, null); - }); + @Override + public void failed(Exception e) { + future.completeExceptionally(e); + } - if (callback != null) { - return resultFuture.whenComplete((res, e) -> { - callback.completed(res); - if (e != null) { - callback.failed(new Exception(e)); - } - }); - } - return resultFuture; + @Override + public void cancelled() { + } + }); + return future; } private Pair, List> fetchCreatedAndBuildBatchToReRun(List batch) { List rerunBatch = new ArrayList<>(batch.size()); List createdResponses = new ArrayList<>(batch.size()); + List>>> futures = new ArrayList<>(batch.size()); + + for (WeaviateObject batchObject : batch) { + futures.add(fetchExistingObject(batchObject)); + } + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + + try { + for (int i = 0; i < batch.size(); i++) { + CompletableFuture>> future = futures.get(i); + WeaviateObject batchObject = batch.get(i); - for (WeaviateObject batchObject: batch) { - try { - Result> existingResult = fetchExistingObject(batchObject).get(); + if (future.isCompletedExceptionally()) { + rerunBatch.add(batchObject); + continue; + } + Result> existingResult = future.get(); if (existingResult.hasErrors() || ObjectUtils.isEmpty(existingResult.getResult())) { rerunBatch.add(batchObject); continue; @@ -411,20 +416,37 @@ private Pair, List> fetchCreatedAndBuild } createdResponses.add(createResponseFromExistingObject(existingObject)); - } catch (InterruptedException | ExecutionException e) { - throw new CompletionException(e); } + } catch (InterruptedException | ExecutionException e) { + throw new CompletionException(e); } return Pair.of(createdResponses, rerunBatch); } - private Future>> fetchExistingObject(WeaviateObject batchObject) { - return data.objectsGetter() + private CompletableFuture>> fetchExistingObject(WeaviateObject batchObject) { + CompletableFuture>> future = new CompletableFuture<>(); + data.objectsGetter() .withID(batchObject.getId()) .withClassName(batchObject.getClassName()) .withVector() - .run(); + .run(new FutureCallback>>() { + @Override + public void completed(Result> objectsResult) { + future.complete(objectsResult); + } + + @Override + public void failed(Exception e) { + future.completeExceptionally(e); + } + + @Override + public void cancelled() { + } + }); + + return future; } private boolean isDifferentObject(WeaviateObject batchObject, WeaviateObject existingObject) { @@ -471,7 +493,7 @@ private ObjectGetResponse createResponseFromExistingObject(WeaviateObject existi private List combineSingleResponses(List combinedSingleResponses, - List createdResponses) { + List createdResponses) { if (ObjectUtils.isNotEmpty(createdResponses)) { combinedSingleResponses = ObjectUtils.isEmpty(combinedSingleResponses) ? createdResponses @@ -483,14 +505,26 @@ private List combineSingleResponses(List c return combinedSingleResponses; } - private Result createFinalResultFromLastResultAndCombinedSingleResponses( - Result lastResult, List combinedSingleResponses, List failedBatch) { + private Result createFinalResultFromLastResultAndCombinedSingleResponses(Result lastResult, + Throwable throwable, + List combinedSingleResponses, + List failedBatch) { + int statusCode = 0; + if (throwable != null && lastResult == null) { + lastResult = new Result<>(statusCode, null, WeaviateErrorResponse.builder() + .error(Collections.singletonList(WeaviateErrorMessage.builder() + .message(throwable.getMessage()) + .throwable(throwable) + .build())) + .code(statusCode) + .build() + ); + } if (ObjectUtils.isEmpty(failedBatch) && ObjectUtils.isEmpty(combinedSingleResponses)) { return lastResult; } - int statusCode = 0; ObjectGetResponse[] allResponses = null; if (ObjectUtils.isNotEmpty(lastResult.getResult())) { allResponses = lastResult.getResult(); @@ -524,49 +558,6 @@ private Result createFinalResultFromLastResultAndCombinedSi ); } - - - private interface DelayedExecutor { - T delayed(int delay, Supplier supplier); - T now(Result result); - } - - @RequiredArgsConstructor - private static class ExecutorServiceDelayedExecutor implements ObjectsBatcher.DelayedExecutor>> { - - private final ScheduledExecutorService executorService; - - @Override - public CompletableFuture> delayed(int delay, Supplier>> supplier) { - Executor executor = (runnable) -> executorService.schedule(runnable, delay, TimeUnit.MILLISECONDS); - return CompletableFuture.supplyAsync(supplier, executor).thenCompose(f -> f); - } - - @Override - public CompletableFuture> now(Result result) { - return CompletableFuture.completedFuture(result); - } - } - - - private static class SleepDelayedExecutor implements ObjectsBatcher.DelayedExecutor> { - - @Override - public Result delayed(int delay, Supplier> supplier) { - try { - Thread.sleep(delay); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - return supplier.get(); - } - - @Override - public Result now(Result result) { - return result; - } - } - @Getter @Builder @ToString @@ -608,31 +599,20 @@ public static ObjectsBatcher.BatchRetriesConfig.BatchRetriesConfigBuilder defaul public static class AutoBatchConfig { public static final int BATCH_SIZE = 100; - public static final int POOL_SIZE = 1; - public static final int AWAIT_TERMINATION_MS = 10_000; int batchSize; - int poolSize; - int awaitTerminationMs; Consumer> callback; - private AutoBatchConfig(int batchSize, int poolSize, int awaitTerminationMs, - Consumer> callback) { + private AutoBatchConfig(int batchSize, Consumer> callback) { Assert.requireGreaterEqual(batchSize, 1, "batchSize"); - Assert.requireGreaterEqual(poolSize, 1, "corePoolSize"); - Assert.requireGreater(awaitTerminationMs, 0, "awaitTerminationMs"); this.batchSize = batchSize; - this.poolSize = poolSize; - this.awaitTerminationMs = awaitTerminationMs; this.callback = callback; } public static ObjectsBatcher.AutoBatchConfig.AutoBatchConfigBuilder defaultConfig() { return ObjectsBatcher.AutoBatchConfig.builder() .batchSize(BATCH_SIZE) - .poolSize(POOL_SIZE) - .awaitTerminationMs(AWAIT_TERMINATION_MS) .callback(null); } } diff --git a/src/main/java/io/weaviate/client/v1/async/batch/api/ReferencesBatcher.java b/src/main/java/io/weaviate/client/v1/async/batch/api/ReferencesBatcher.java index 118dc1b8..036d068a 100644 --- a/src/main/java/io/weaviate/client/v1/async/batch/api/ReferencesBatcher.java +++ b/src/main/java/io/weaviate/client/v1/async/batch/api/ReferencesBatcher.java @@ -8,6 +8,7 @@ import io.weaviate.client.base.WeaviateErrorMessage; import io.weaviate.client.base.WeaviateErrorResponse; import io.weaviate.client.base.util.Assert; +import io.weaviate.client.base.util.Futures; import io.weaviate.client.v1.batch.model.BatchReference; import io.weaviate.client.v1.batch.model.BatchReferenceResponse; import io.weaviate.client.v1.batch.util.ReferencesPath; @@ -31,6 +32,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -39,21 +41,26 @@ public class ReferencesBatcher extends AsyncBaseClient implements AsyncClientResult { private final ReferencesPath referencesPath; + private final BatchRetriesConfig batchRetriesConfig; private final AutoBatchConfig autoBatchConfig; private final boolean autoRunEnabled; + private final Executor executor; private final List>> futures; + private final List references; private String consistencyLevel; private ReferencesBatcher(CloseableHttpAsyncClient client, Config config, ReferencesPath referencesPath, - BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) { + BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig, + Executor executor) { super(client, config); this.referencesPath = referencesPath; this.futures = Collections.synchronizedList(new ArrayList<>()); this.references = Collections.synchronizedList(new ArrayList<>()); this.batchRetriesConfig = batchRetriesConfig; + this.executor = executor; if (autoBatchConfig != null) { this.autoRunEnabled = true; @@ -65,16 +72,17 @@ private ReferencesBatcher(CloseableHttpAsyncClient client, Config config, Refere } public static ReferencesBatcher create(CloseableHttpAsyncClient client, Config config, ReferencesPath referencesPath, - BatchRetriesConfig batchRetriesConfig) { + BatchRetriesConfig batchRetriesConfig, Executor executor) { Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig"); - return new ReferencesBatcher(client, config, referencesPath, batchRetriesConfig, null); + return new ReferencesBatcher(client, config, referencesPath, batchRetriesConfig, null, executor); } public static ReferencesBatcher createAuto(CloseableHttpAsyncClient client, Config config, ReferencesPath referencesPath, - BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig) { + BatchRetriesConfig batchRetriesConfig, AutoBatchConfig autoBatchConfig, + Executor executor) { Assert.requiredNotNull(batchRetriesConfig, "batchRetriesConfig"); Assert.requiredNotNull(autoBatchConfig, "autoBatchConfig"); - return new ReferencesBatcher(client, config, referencesPath, batchRetriesConfig, autoBatchConfig); + return new ReferencesBatcher(client, config, referencesPath, batchRetriesConfig, autoBatchConfig, executor); } @@ -193,38 +201,37 @@ private void runBatch(List batch) { private CompletableFuture> runBatchRecursively(List batch, int connectionErrorCount, int timeoutErrorCount) { - return internalRun(batch).handle((Result result, Throwable throwable) -> { - int lambdaConnectionErrorCount = connectionErrorCount; - int lambdaTimeErrorCount = timeoutErrorCount; - - if (throwable != null) { - boolean executeAgain = false; - int delay = 0; - - if (throwable instanceof ConnectException) { - if (lambdaConnectionErrorCount++ < batchRetriesConfig.maxConnectionRetries) { - executeAgain = true; - delay = lambdaConnectionErrorCount * batchRetriesConfig.retriesIntervalMs; - } - } else if (throwable instanceof SocketTimeoutException) { - if (lambdaTimeErrorCount++ < batchRetriesConfig.maxTimeoutRetries) { - executeAgain = true; - delay = lambdaTimeErrorCount * batchRetriesConfig.retriesIntervalMs; - } + return Futures.handleAsync(internalRun(batch), (result, throwable) -> { + if (throwable != null) { + boolean executeAgain = false; + int tempConnCount = connectionErrorCount; + int tempTimeCount = timeoutErrorCount; + int delay = 0; + + if (throwable instanceof ConnectException) { + if (tempConnCount++ < batchRetriesConfig.maxConnectionRetries) { + executeAgain = true; + delay = tempConnCount * batchRetriesConfig.retriesIntervalMs; } - if (executeAgain) { - try { - Thread.sleep(delay); - return runBatchRecursively(batch, lambdaConnectionErrorCount, lambdaTimeErrorCount); - } catch (InterruptedException e) { - throw new CompletionException(e); - } + } else if (throwable instanceof SocketTimeoutException) { + if (tempTimeCount++ < batchRetriesConfig.maxTimeoutRetries) { + executeAgain = true; + delay = tempTimeCount * batchRetriesConfig.retriesIntervalMs; } } + if (executeAgain) { + int finalConnCount = tempConnCount; + int finalTimeCount = tempTimeCount; + try { + return Futures.supplyDelayed(() -> runBatchRecursively(batch, finalConnCount, finalTimeCount), delay, executor); + } catch (InterruptedException e) { + throw new CompletionException(e); + } + } + } - return CompletableFuture.completedFuture(createFinalResultFromLastResult(result, throwable, batch)); - }) - .thenCompose(f -> f); + return CompletableFuture.completedFuture(createFinalResultFromLastResult(result, throwable, batch)); + }, executor); } private CompletableFuture> internalRun(List batch) { diff --git a/src/main/java/io/weaviate/client/v1/async/classifications/Classifications.java b/src/main/java/io/weaviate/client/v1/async/classifications/Classifications.java index 61e547dd..75d611f8 100644 --- a/src/main/java/io/weaviate/client/v1/async/classifications/Classifications.java +++ b/src/main/java/io/weaviate/client/v1/async/classifications/Classifications.java @@ -6,6 +6,8 @@ import lombok.RequiredArgsConstructor; import org.apache.hc.client5.http.impl.async.CloseableHttpAsyncClient; +import java.util.concurrent.Executor; + @RequiredArgsConstructor public class Classifications { @@ -14,7 +16,11 @@ public class Classifications { public Scheduler scheduler() { - return new Scheduler(client, config, getter()); + return scheduler(null); + } + + public Scheduler scheduler(Executor executor) { + return new Scheduler(client, config, getter(), executor); } public Getter getter() { diff --git a/src/main/java/io/weaviate/client/v1/async/classifications/api/Scheduler.java b/src/main/java/io/weaviate/client/v1/async/classifications/api/Scheduler.java index c72cc0a1..ea37a3a9 100644 --- a/src/main/java/io/weaviate/client/v1/async/classifications/api/Scheduler.java +++ b/src/main/java/io/weaviate/client/v1/async/classifications/api/Scheduler.java @@ -4,6 +4,7 @@ import io.weaviate.client.base.AsyncBaseClient; import io.weaviate.client.base.AsyncClientResult; import io.weaviate.client.base.Result; +import io.weaviate.client.base.util.Futures; import io.weaviate.client.v1.classifications.model.Classification; import io.weaviate.client.v1.classifications.model.ClassificationFilters; import io.weaviate.client.v1.filters.WhereFilter; @@ -14,6 +15,7 @@ import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.Executor; import java.util.concurrent.Future; public class Scheduler extends AsyncBaseClient @@ -32,11 +34,13 @@ public class Scheduler extends AsyncBaseClient private Object settings; private final Getter getter; + private final Executor executor; - public Scheduler(CloseableHttpAsyncClient client, Config config, Getter getter) { + public Scheduler(CloseableHttpAsyncClient client, Config config, Getter getter, Executor executor) { super(client, config); this.getter = getter; + this.executor = executor; } @@ -168,7 +172,7 @@ public void cancelled() { } private CompletableFuture> getByIdRecursively(String id) { - return getById(id).thenCompose(classificationResult -> { + return Futures.thenComposeAsync(getById(id), classificationResult -> { boolean isRunning = Optional.ofNullable(classificationResult) .map(Result::getResult) .map(Classification::getStatus) @@ -177,14 +181,13 @@ private CompletableFuture> getByIdRecursively(String id) if (isRunning) { try { - Thread.sleep(WAIT_INTERVAL); - return getByIdRecursively(id); + return Futures.supplyDelayed(() -> getByIdRecursively(id), WAIT_INTERVAL, executor); } catch (InterruptedException e) { throw new CompletionException(e); } } return CompletableFuture.completedFuture(classificationResult); - }); + }, executor); } private ClassificationFilters getClassificationFilters(WhereFilter sourceWhere, WhereFilter targetWhere, WhereFilter trainingSetWhere) { diff --git a/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateMockServerTest.java b/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateMockServerTest.java new file mode 100644 index 00000000..65aa259d --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateMockServerTest.java @@ -0,0 +1,326 @@ +package io.weaviate.integration.client.async.batch; + +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.base.Serializer; +import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.async.batch.api.ObjectsBatcher; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.integration.tests.batch.BatchObjectsMockServerTestSuite; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockserver.client.MockServerClient; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.Delay; +import org.mockserver.verify.VerificationTimes; + +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.mockserver.integration.ClientAndServer.startClientAndServer; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +@RunWith(JParamsTestRunner.class) +public class ClientBatchCreateMockServerTest { + + private WeaviateClient client; + private ClientAndServer mockServer; + private MockServerClient mockServerClient; + + private static final String MOCK_SERVER_HOST = "localhost"; + private static final int MOCK_SERVER_PORT = 8999; + + @Before + public void before() { + mockServer = startClientAndServer(MOCK_SERVER_PORT); + mockServerClient = new MockServerClient(MOCK_SERVER_HOST, MOCK_SERVER_PORT); + + mockServerClient.when( + request().withMethod("GET").withPath("/v1/meta") + ).respond( + response().withStatusCode(200).withBody(metaBody()) + ); + + Config config = new Config("http", MOCK_SERVER_HOST + ":" + MOCK_SERVER_PORT, null, 1, 1, 1); + client = new WeaviateClient(config); + } + + @After + public void stopMockServer() { + mockServer.stop(); + } + + @Test + @DataMethod(source = ClientBatchCreateMockServerTest.class, method = "provideForNotCreateBatchDueToConnectionIssue") + public void shouldNotCreateBatchDueToConnectionIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + long expectedExecMinMillis, long expectedExecMaxMillis) { + // stop server to simulate connection issues + mockServer.stop(); + + try (WeaviateAsyncClient asyncClient = client.async()) { + Supplier> supplierObjectsBatcher = () -> { + try { + return asyncClient.batch().objectsBatcher(batchRetriesConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + + BatchObjectsMockServerTestSuite.testNotCreateBatchDueToConnectionIssue(supplierObjectsBatcher, + expectedExecMinMillis, expectedExecMaxMillis); + } + } + + @Test + @DataMethod(source = ClientBatchCreateMockServerTest.class, method = "provideForNotCreateBatchDueToConnectionIssue") + public void shouldNotCreateAutoBatchDueToConnectionIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + long expectedExecMinMillis, long expectedExecMaxMillis) { + // stop server to simulate connection issues + mockServer.stop(); + + try (WeaviateAsyncClient asyncClient = client.async()) { + Consumer>> supplierObjectsBatcher = callback -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + try { + asyncClient.batch().objectsAutoBatcher(batchRetriesConfig, autoBatchConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + + BatchObjectsMockServerTestSuite.testNotCreateAutoBatchDueToConnectionIssue(supplierObjectsBatcher, + expectedExecMinMillis, expectedExecMaxMillis); + } + } + + public static Object[][] provideForNotCreateBatchDueToConnectionIssue() { + return new Object[][]{ + new Object[]{ + // final response should be available immediately + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(0) + .build(), + 0, 350 + }, + new Object[]{ + // final response should be available after 1 retry (400 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(1) + .build(), + 400, 750 + }, + new Object[]{ + // final response should be available after 2 retries (400 + 800 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(2) + .build(), + 1200, 1550 + }, + new Object[]{ + // final response should be available after 1 retry (400 + 800 + 1200 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(3) + .build(), + 2400, 2750 + }, + }; + } + + @Test + @DataMethod(source = ClientBatchCreateMockServerTest.class, method = "provideForNotCreateBatchDueToTimeoutIssue") + public void shouldNotCreateBatchDueToTimeoutIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + int expectedBatchCallsCount) { + // given client times out after 1s + + Serializer serializer = new Serializer(); + String pizza1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.PIZZA_1); + String soup1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.SOUP_1); + + // batch request should end up with timeout exception, but Pizza1 and Soup1 should be "added" and available by get + mockServerClient.when( + request().withMethod("POST").withPath("/v1/batch/objects") + ).respond( + response().withDelay(Delay.seconds(2)).withStatusCode(200) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)) + ).respond( + response().withBody(pizza1Str) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)) + ).respond( + response().withBody(soup1Str) + ); + + try (WeaviateAsyncClient asyncClient = client.async()) { + Supplier> supplierObjectsBatcher = () -> { + try { + return asyncClient.batch().objectsBatcher(batchRetriesConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + Consumer assertPostObjectsCallsCount = count -> mockServerClient.verify( + request().withMethod("POST").withPath("/v1/batch/objects"), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_2_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_2_ID)), + VerificationTimes.exactly(count) + ); + + BatchObjectsMockServerTestSuite.testNotCreateBatchDueToTimeoutIssue(supplierObjectsBatcher, + assertPostObjectsCallsCount, assertGetPizza1CallsCount, assertGetPizza2CallsCount, + assertGetSoup1CallsCount, assertGetSoup2CallsCount, expectedBatchCallsCount, "1 SECONDS"); + } + } + + @Test + @DataMethod(source = ClientBatchCreateMockServerTest.class, method = "provideForNotCreateBatchDueToTimeoutIssue") + public void shouldNotCreateAutoBatchDueToTimeoutIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + int expectedBatchCallsCount) { + // given client times out after 1s + + Serializer serializer = new Serializer(); + String pizza1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.PIZZA_1); + String soup1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.SOUP_1); + + // batch request should end up with timeout exception, but Pizza1 and Soup1 should be "added" and available by get + mockServerClient.when( + request().withMethod("POST").withPath("/v1/batch/objects") + ).respond( + response().withDelay(Delay.seconds(2)).withStatusCode(200) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)) + ).respond( + response().withBody(pizza1Str) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)) + ).respond( + response().withBody(soup1Str) + ); + + try (WeaviateAsyncClient asyncClient = client.async()) { + Consumer>> supplierObjectsBatcher = callback -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + try { + asyncClient.batch().objectsAutoBatcher(batchRetriesConfig, autoBatchConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + + Consumer assertPostObjectsCallsCount = count -> mockServerClient.verify( + request().withMethod("POST").withPath("/v1/batch/objects"), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_2_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_2_ID)), + VerificationTimes.exactly(count) + ); + + BatchObjectsMockServerTestSuite.testNotCreateAutoBatchDueToTimeoutIssue(supplierObjectsBatcher, + assertPostObjectsCallsCount, assertGetPizza1CallsCount, assertGetPizza2CallsCount, + assertGetSoup1CallsCount, assertGetSoup2CallsCount, expectedBatchCallsCount, "1 SECONDS"); + } + } + + public static Object[][] provideForNotCreateBatchDueToTimeoutIssue() { + return new Object[][]{ + new Object[]{ + // final response should be available immediately + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(200) + .maxTimeoutRetries(0) + .build(), + 1 + }, + new Object[]{ + // final response should be available after 1 retry (200 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(200) + .maxTimeoutRetries(1) + .build(), + 2 + }, + new Object[]{ + // final response should be available after 2 retries (200 + 400 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(200) + .maxTimeoutRetries(2) + .build(), + 3 + }, + }; + } + + private String metaBody() { + return String.format("{\n" + + " \"hostname\": \"http://[::]:%s\",\n" + + " \"modules\": {},\n" + + " \"version\": \"%s\"\n" + + "}", MOCK_SERVER_PORT, "1.17.999-mock-server-version"); + } +} diff --git a/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateTest.java b/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateTest.java index e663ddb2..f37189cf 100644 --- a/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateTest.java +++ b/src/test/java/io/weaviate/integration/client/async/batch/ClientBatchCreateTest.java @@ -4,22 +4,28 @@ import io.weaviate.client.WeaviateClient; import io.weaviate.client.base.Result; import io.weaviate.client.v1.async.WeaviateAsyncClient; +import io.weaviate.client.v1.async.batch.api.ObjectsBatcher; import io.weaviate.client.v1.batch.model.ObjectGetResponse; import io.weaviate.client.v1.data.model.WeaviateObject; import io.weaviate.client.v1.data.replication.model.ConsistencyLevel; import io.weaviate.integration.client.WeaviateDockerCompose; import io.weaviate.integration.client.WeaviateTestGenerics; -import io.weaviate.integration.tests.batch.BatchTestSuite; -import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.function.Function; -import java.util.function.Supplier; +import io.weaviate.integration.tests.batch.BatchObjectsTestSuite; +import org.jetbrains.annotations.NotNull; import org.junit.After; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + public class ClientBatchCreateTest { + private WeaviateClient client; private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); @@ -28,9 +34,7 @@ public class ClientBatchCreateTest { @Before public void before() { - String httpHost = compose.getHttpHostAddress(); - Config config = new Config("http", httpHost); - + Config config = new Config("http", compose.getHttpHostAddress()); client = new WeaviateClient(config); testGenerics.createWeaviateTestSchemaFood(client); } @@ -42,87 +46,170 @@ public void after() { @Test public void shouldCreateBatch() { - Supplier> resPizza1 = () -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.data().creator() - .withClassName("Pizza") - .withID(BatchTestSuite.PIZZA_1_ID) - .withProperties(BatchTestSuite.PIZZA_1_PROPS) - .run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - Supplier> resSoup1 = () -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.data().creator() - .withClassName("Soup") - .withID(BatchTestSuite.SOUP_1_ID) - .withProperties(BatchTestSuite.SOUP_1_PROPS) - .run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - - Function, Result> resBatchPizzas = (pizza1) -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.batch().objectsBatcher() - .withObjects( - pizza1.getResult(), - WeaviateObject.builder().className("Pizza").id(BatchTestSuite.PIZZA_2_ID).properties(BatchTestSuite.PIZZA_2_PROPS).build() - ) - .withConsistencyLevel(ConsistencyLevel.QUORUM) - .run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - - Function, Result> resBatchSoups = (soup1) -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.batch().objectsBatcher() - .withObjects( - soup1.getResult(), - WeaviateObject.builder().className("Soup").id(BatchTestSuite.SOUP_2_ID).properties(BatchTestSuite.SOUP_2_PROPS).build() - ) - .withConsistencyLevel(ConsistencyLevel.QUORUM) - .run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - - // check if created objects exist - Supplier>> resGetPizza1 = () -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.data().objectsGetter().withID(BatchTestSuite.PIZZA_1_ID).withClassName("Pizza").run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - Supplier>> resGetPizza2 = () -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.data().objectsGetter().withID(BatchTestSuite.PIZZA_2_ID).withClassName("Pizza").run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - Supplier>> resGetSoup1 = () -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.data().objectsGetter().withID(BatchTestSuite.SOUP_1_ID).withClassName("Soup").run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - Supplier>> resGetSoup2 = () -> { - try (WeaviateAsyncClient asyncClient = client.async()) { - return asyncClient.data().objectsGetter().withID(BatchTestSuite.SOUP_2_ID).withClassName("Soup").run().get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException(e); - } - }; - - BatchTestSuite.shouldCreateBatch(resPizza1, resSoup1, resBatchPizzas, resBatchSoups, resGetPizza1, resGetPizza2, resGetSoup1, resGetSoup2); + try (WeaviateAsyncClient asyncClient = client.async()) { + Function> supplierObjectsBatcherPizzas = pizza -> { + try { + return asyncClient.batch().objectsBatcher() + .withObjects(pizza, WeaviateObject.builder() + .className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_2_ID) + .properties(BatchObjectsTestSuite.PIZZA_2_PROPS) + .build()) + .withConsistencyLevel(ConsistencyLevel.QUORUM) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + Function> supplierObjectsBatcherSoups = soup -> { + try { + return asyncClient.batch().objectsBatcher() + .withObjects(soup, WeaviateObject.builder() + .className("Soup") + .id(BatchObjectsTestSuite.SOUP_2_ID) + .properties(BatchObjectsTestSuite.SOUP_2_PROPS) + .build()) + .withConsistencyLevel(ConsistencyLevel.QUORUM) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + + BatchObjectsTestSuite.testCreateBatch(supplierObjectsBatcherPizzas, supplierObjectsBatcherSoups, + createSupplierDataPizza1(), createSupplierDataSoup1(), + createSupplierGetterPizza1(), createSupplierGetterPizza2(), + createSupplierGetterSoup1(), createSupplierGetterSoup2()); + } + } + + @Test + public void shouldCreateAutoBatch() { + try (WeaviateAsyncClient asyncClient = client.async()) { + BiConsumer>> supplierObjectsBatcherPizzas = (pizza, callback) -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + try { + asyncClient.batch().objectsAutoBatcher(autoBatchConfig) + .withObjects(pizza, WeaviateObject.builder().className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_2_ID) + .properties(BatchObjectsTestSuite.PIZZA_2_PROPS) + .build()) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + BiConsumer>> supplierObjectsBatcherSoups = (soup, callback) -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + try { + asyncClient.batch().objectsAutoBatcher(autoBatchConfig) + .withObjects(soup, WeaviateObject.builder() + .className("Soup") + .id(BatchObjectsTestSuite.SOUP_2_ID) + .properties(BatchObjectsTestSuite.SOUP_2_PROPS) + .build()) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + + BatchObjectsTestSuite.testCreateAutoBatch(supplierObjectsBatcherPizzas, supplierObjectsBatcherSoups, + createSupplierDataPizza1(), createSupplierDataSoup1(), + createSupplierGetterPizza1(), createSupplierGetterPizza2(), + createSupplierGetterSoup1(), createSupplierGetterSoup2()); + } + } + + @Test + public void shouldCreateBatchWithPartialError() { + try (WeaviateAsyncClient asyncClient = client.async()) { + Supplier> supplierObjectsBatcherPizzas = () -> { + WeaviateObject pizzaWithError = WeaviateObject.builder() + .className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_1_ID) + .properties(BatchObjectsTestSuite.createFoodProperties(1, "This pizza should throw a invalid name error")) + .build(); + WeaviateObject pizza = WeaviateObject.builder() + .className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_2_ID) + .properties(BatchObjectsTestSuite.PIZZA_2_PROPS) + .build(); + + try { + return asyncClient.batch().objectsBatcher() + .withObjects(pizzaWithError, pizza) + .run() + .get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }; + + BatchObjectsTestSuite.testCreateBatchWithPartialError(supplierObjectsBatcherPizzas, + createSupplierGetterPizza1(), createSupplierGetterPizza2()); + } + } + + @NotNull + private Supplier> createSupplierDataSoup1() { + return () -> client.data().creator() + .withClassName("Soup") + .withID(BatchObjectsTestSuite.SOUP_1_ID) + .withProperties(BatchObjectsTestSuite.SOUP_1_PROPS) + .run(); + } + + @NotNull + private Supplier> createSupplierDataPizza1() { + return () -> client.data().creator() + .withClassName("Pizza") + .withID(BatchObjectsTestSuite.PIZZA_1_ID) + .withProperties(BatchObjectsTestSuite.PIZZA_1_PROPS) + .run(); + } + + @NotNull + private Supplier>> createSupplierGetterPizza1() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.PIZZA_1_ID) + .withClassName("Pizza") + .run(); + } + + @NotNull + private Supplier>> createSupplierGetterPizza2() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.PIZZA_2_ID) + .withClassName("Pizza") + .run(); + } + + @NotNull + private Supplier>> createSupplierGetterSoup1() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.SOUP_1_ID) + .withClassName("Soup") + .run(); + } + + @NotNull + private Supplier>> createSupplierGetterSoup2() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.SOUP_2_ID) + .withClassName("Soup") + .run(); } } diff --git a/src/test/java/io/weaviate/integration/client/auth/AuthAzureClientCredentialsTest.java b/src/test/java/io/weaviate/integration/client/auth/AuthAzureClientCredentialsTest.java index 246d546a..f2b59203 100644 --- a/src/test/java/io/weaviate/integration/client/auth/AuthAzureClientCredentialsTest.java +++ b/src/test/java/io/weaviate/integration/client/auth/AuthAzureClientCredentialsTest.java @@ -7,6 +7,7 @@ import org.apache.commons.lang3.StringUtils; import org.junit.Before; import org.junit.ClassRule; +import org.junit.Ignore; import org.junit.Test; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -33,6 +34,7 @@ public void before() { } @Test + @Ignore("client secret expired") public void testAuthAzure() throws AuthException { String clientSecret = System.getenv("AZURE_CLIENT_SECRET"); if (StringUtils.isNotBlank(clientSecret)) { @@ -50,6 +52,7 @@ public void testAuthAzure() throws AuthException { } @Test + @Ignore("client secret expired") public void testAuthAzureHardcodedScope() throws AuthException { String clientSecret = System.getenv("AZURE_CLIENT_SECRET"); if (StringUtils.isNotBlank(clientSecret)) { diff --git a/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateMockServer2Test.java b/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateMockServer2Test.java new file mode 100644 index 00000000..b63a2d1c --- /dev/null +++ b/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateMockServer2Test.java @@ -0,0 +1,293 @@ +package io.weaviate.integration.client.batch; + +import com.jparams.junit4.JParamsTestRunner; +import com.jparams.junit4.data.DataMethod; +import io.weaviate.client.Config; +import io.weaviate.client.WeaviateClient; +import io.weaviate.client.base.Result; +import io.weaviate.client.base.Serializer; +import io.weaviate.client.v1.batch.api.ObjectsBatcher; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.integration.tests.batch.BatchObjectsMockServerTestSuite; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockserver.client.MockServerClient; +import org.mockserver.integration.ClientAndServer; +import org.mockserver.model.Delay; +import org.mockserver.verify.VerificationTimes; + +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.mockserver.integration.ClientAndServer.startClientAndServer; +import static org.mockserver.model.HttpRequest.request; +import static org.mockserver.model.HttpResponse.response; + +@RunWith(JParamsTestRunner.class) +public class ClientBatchCreateMockServer2Test { + + private WeaviateClient client; + private ClientAndServer mockServer; + private MockServerClient mockServerClient; + + private static final String MOCK_SERVER_HOST = "localhost"; + private static final int MOCK_SERVER_PORT = 8999; + + @Before + public void before() { + mockServer = startClientAndServer(MOCK_SERVER_PORT); + mockServerClient = new MockServerClient(MOCK_SERVER_HOST, MOCK_SERVER_PORT); + + mockServerClient.when( + request().withMethod("GET").withPath("/v1/meta") + ).respond( + response().withStatusCode(200).withBody(metaBody()) + ); + + Config config = new Config("http", MOCK_SERVER_HOST + ":" + MOCK_SERVER_PORT, null, 1, 1, 1); + client = new WeaviateClient(config); + } + + @After + public void stopMockServer() { + mockServer.stop(); + } + + @Test + @DataMethod(source = ClientBatchCreateMockServer2Test.class, method = "provideForNotCreateBatchDueToConnectionIssue") + public void shouldNotCreateBatchDueToConnectionIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + long expectedExecMinMillis, long expectedExecMaxMillis) { + // stop server to simulate connection issues + mockServer.stop(); + + Supplier> supplierObjectsBatcher = () -> client.batch().objectsBatcher(batchRetriesConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .run(); + + BatchObjectsMockServerTestSuite.testNotCreateBatchDueToConnectionIssue(supplierObjectsBatcher, + expectedExecMinMillis, expectedExecMaxMillis); + } + + @Test + @DataMethod(source = ClientBatchCreateMockServer2Test.class, method = "provideForNotCreateBatchDueToConnectionIssue") + public void shouldNotCreateAutoBatchDueToConnectionIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + long expectedExecMinMillis, long expectedExecMaxMillis) { + // stop server to simulate connection issues + mockServer.stop(); + + Consumer>> supplierObjectsBatcher = callback -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + client.batch().objectsAutoBatcher(batchRetriesConfig, autoBatchConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .flush(); + }; + + BatchObjectsMockServerTestSuite.testNotCreateAutoBatchDueToConnectionIssue(supplierObjectsBatcher, + expectedExecMinMillis, expectedExecMaxMillis); + } + + public static Object[][] provideForNotCreateBatchDueToConnectionIssue() { + return new Object[][]{ + new Object[]{ + // final response should be available immediately + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(0) + .build(), + 0, 350 + }, + new Object[]{ + // final response should be available after 1 retry (400 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(1) + .build(), + 400, 750 + }, + new Object[]{ + // final response should be available after 2 retries (400 + 800 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(2) + .build(), + 1200, 1550 + }, + new Object[]{ + // final response should be available after 1 retry (400 + 800 + 1200 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(400) + .maxConnectionRetries(3) + .build(), + 2400, 2750 + }, + }; + } + + @Test + @DataMethod(source = ClientBatchCreateMockServer2Test.class, method = "provideForNotCreateBatchDueToTimeoutIssue") + public void shouldNotCreateBatchDueToTimeoutIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + int expectedBatchCallsCount) { + // given client times out after 1s + + Serializer serializer = new Serializer(); + String pizza1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.PIZZA_1); + String soup1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.SOUP_1); + + // batch request should end up with timeout exception, but Pizza1 and Soup1 should be "added" and available by get + mockServerClient.when( + request().withMethod("POST").withPath("/v1/batch/objects") + ).respond( + response().withDelay(Delay.seconds(2)).withStatusCode(200) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)) + ).respond( + response().withBody(pizza1Str) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)) + ).respond( + response().withBody(soup1Str) + ); + + Supplier> supplierObjectsBatcher = () -> client.batch().objectsBatcher(batchRetriesConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .run(); + Consumer assertPostObjectsCallsCount = count -> mockServerClient.verify( + request().withMethod("POST").withPath("/v1/batch/objects"), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_2_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_2_ID)), + VerificationTimes.exactly(count) + ); + + BatchObjectsMockServerTestSuite.testNotCreateBatchDueToTimeoutIssue(supplierObjectsBatcher, + assertPostObjectsCallsCount, assertGetPizza1CallsCount, assertGetPizza2CallsCount, + assertGetSoup1CallsCount, assertGetSoup2CallsCount, expectedBatchCallsCount, "Read timed out"); + } + + @Test + @DataMethod(source = ClientBatchCreateMockServer2Test.class, method = "provideForNotCreateBatchDueToTimeoutIssue") + public void shouldNotCreateAutoBatchDueToTimeoutIssue(ObjectsBatcher.BatchRetriesConfig batchRetriesConfig, + int expectedBatchCallsCount) { + // given client times out after 1s + + Serializer serializer = new Serializer(); + String pizza1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.PIZZA_1); + String soup1Str = serializer.toJsonString(BatchObjectsMockServerTestSuite.SOUP_1); + + // batch request should end up with timeout exception, but Pizza1 and Soup1 should be "added" and available by get + mockServerClient.when( + request().withMethod("POST").withPath("/v1/batch/objects") + ).respond( + response().withDelay(Delay.seconds(2)).withStatusCode(200) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)) + ).respond( + response().withBody(pizza1Str) + ); + mockServerClient.when( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)) + ).respond( + response().withBody(soup1Str) + ); + + Consumer>> supplierObjectsBatcher = callback -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .poolSize(2) + .callback(callback) + .build(); + + client.batch().objectsAutoBatcher(batchRetriesConfig, autoBatchConfig) + .withObjects(BatchObjectsMockServerTestSuite.PIZZA_1, BatchObjectsMockServerTestSuite.PIZZA_2, + BatchObjectsMockServerTestSuite.SOUP_1, BatchObjectsMockServerTestSuite.SOUP_2) + .flush(); + }; + + Consumer assertPostObjectsCallsCount = count -> mockServerClient.verify( + request().withMethod("POST").withPath("/v1/batch/objects"), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetPizza2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Pizza", BatchObjectsMockServerTestSuite.PIZZA_2_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup1CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_1_ID)), + VerificationTimes.exactly(count) + ); + Consumer assertGetSoup2CallsCount = count -> mockServerClient.verify( + request().withMethod("GET").withPath(String.format("/v1/objects/%s/%s", "Soup", BatchObjectsMockServerTestSuite.SOUP_2_ID)), + VerificationTimes.exactly(count) + ); + + BatchObjectsMockServerTestSuite.testNotCreateAutoBatchDueToTimeoutIssue(supplierObjectsBatcher, + assertPostObjectsCallsCount, assertGetPizza1CallsCount, assertGetPizza2CallsCount, + assertGetSoup1CallsCount, assertGetSoup2CallsCount, expectedBatchCallsCount, "Read timed out"); + } + + public static Object[][] provideForNotCreateBatchDueToTimeoutIssue() { + return new Object[][]{ + new Object[]{ + // final response should be available immediately + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(200) + .maxTimeoutRetries(0) + .build(), + 1 + }, + new Object[]{ + // final response should be available after 1 retry (200 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(200) + .maxTimeoutRetries(1) + .build(), + 2 + }, + new Object[]{ + // final response should be available after 2 retries (200 + 400 ms) + ObjectsBatcher.BatchRetriesConfig.defaultConfig() + .retriesIntervalMs(200) + .maxTimeoutRetries(2) + .build(), + 3 + }, + }; + } + + private String metaBody() { + return String.format("{\n" + + " \"hostname\": \"http://[::]:%s\",\n" + + " \"modules\": {},\n" + + " \"version\": \"%s\"\n" + + "}", MOCK_SERVER_PORT, "1.17.999-mock-server-version"); + } +} diff --git a/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateTest.java b/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateTest.java index 7d17f0d1..f6ff6030 100644 --- a/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateTest.java +++ b/src/test/java/io/weaviate/integration/client/batch/ClientBatchCreateTest.java @@ -5,36 +5,24 @@ import io.weaviate.client.base.Result; import io.weaviate.client.v1.batch.api.ObjectsBatcher; import io.weaviate.client.v1.batch.model.ObjectGetResponse; -import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result; import io.weaviate.client.v1.data.model.WeaviateObject; import io.weaviate.client.v1.data.replication.model.ConsistencyLevel; import io.weaviate.integration.client.WeaviateDockerCompose; import io.weaviate.integration.client.WeaviateTestGenerics; -import io.weaviate.integration.tests.batch.BatchTestSuite; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import static org.assertj.core.api.Assertions.assertThat; +import io.weaviate.integration.tests.batch.BatchObjectsTestSuite; +import org.jetbrains.annotations.NotNull; import org.junit.After; import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; -public class ClientBatchCreateTest { +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; - private static final String PIZZA_1_ID = "abefd256-8574-442b-9293-9205193737ee"; - private static final Map PIZZA_1_PROPS = createFoodProperties("Hawaii", "Universally accepted to be the best pizza ever created."); - private static final String PIZZA_2_ID = "97fa5147-bdad-4d74-9a81-f8babc811b09"; - private static final Map PIZZA_2_PROPS = createFoodProperties("Doener", "A innovation, some say revolution, in the pizza industry."); - private static final String SOUP_1_ID = "565da3b6-60b3-40e5-ba21-e6bfe5dbba91"; - private static final Map SOUP_1_PROPS = createFoodProperties("ChickenSoup", "Used by humans when their inferior genetics are attacked by microscopic organisms."); - private static final String SOUP_2_ID = "07473b34-0ab2-4120-882d-303d9e13f7af"; - private static final Map SOUP_2_PROPS = createFoodProperties("Beautiful", "Putting the game of letter soups to a whole new level."); +public class ClientBatchCreateTest { private WeaviateClient client; private final WeaviateTestGenerics testGenerics = new WeaviateTestGenerics(); @@ -44,9 +32,7 @@ public class ClientBatchCreateTest { @Before public void before() { - String httpHost = compose.getHttpHostAddress(); - Config config = new Config("http", httpHost); - + Config config = new Config("http", compose.getHttpHostAddress()); client = new WeaviateClient(config); testGenerics.createWeaviateTestSchemaFood(client); } @@ -58,172 +44,135 @@ public void after() { @Test public void shouldCreateBatch() { - Supplier> resPizza1 = () -> client.data().creator() - .withClassName("Pizza") - .withID(BatchTestSuite.PIZZA_1_ID) - .withProperties(BatchTestSuite.PIZZA_1_PROPS) + Function> supplierObjectsBatcherPizzas = pizza -> client.batch().objectsBatcher() + .withObjects(pizza, WeaviateObject.builder() + .className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_2_ID) + .properties(BatchObjectsTestSuite.PIZZA_2_PROPS) + .build()) + .withConsistencyLevel(ConsistencyLevel.QUORUM) .run(); - Supplier> resSoup1 = () -> client.data().creator() - .withClassName("Soup") - .withID(BatchTestSuite.SOUP_1_ID) - .withProperties(BatchTestSuite.SOUP_1_PROPS) + Function> supplierObjectsBatcherSoups = soup -> client.batch().objectsBatcher() + .withObjects(soup, WeaviateObject.builder() + .className("Soup") + .id(BatchObjectsTestSuite.SOUP_2_ID) + .properties(BatchObjectsTestSuite.SOUP_2_PROPS) + .build()) + .withConsistencyLevel(ConsistencyLevel.QUORUM) .run(); - Function, Result> resBatchPizzas = (pizza1) -> client.batch().objectsBatcher() - .withObjects( - pizza1.getResult(), - WeaviateObject.builder().className("Pizza").id(BatchTestSuite.PIZZA_2_ID).properties(BatchTestSuite.PIZZA_2_PROPS).build() - ) - .withConsistencyLevel(ConsistencyLevel.QUORUM) - .run(); - - Function, Result> resBatchSoups = (soup1) -> client.batch().objectsBatcher() - .withObjects( - soup1.getResult(), - WeaviateObject.builder().className("Soup").id(BatchTestSuite.SOUP_2_ID).properties(BatchTestSuite.SOUP_2_PROPS).build() - ) - .withConsistencyLevel(ConsistencyLevel.QUORUM) - .run(); - - // check if created objects exist - Supplier>> resGetPizza1 = () -> client.data().objectsGetter().withID(PIZZA_1_ID).withClassName("Pizza").run(); - Supplier>> resGetPizza2 = () -> client.data().objectsGetter().withID(PIZZA_2_ID).withClassName("Pizza").run(); - Supplier>> resGetSoup1 = () -> client.data().objectsGetter().withID(SOUP_1_ID).withClassName("Soup").run(); - Supplier>> resGetSoup2 = () -> client.data().objectsGetter().withID(SOUP_2_ID).withClassName("Soup").run(); - - BatchTestSuite.shouldCreateBatch(resPizza1, resSoup1, resBatchPizzas, resBatchSoups, resGetPizza1, resGetPizza2, resGetSoup1, resGetSoup2); + BatchObjectsTestSuite.testCreateBatch(supplierObjectsBatcherPizzas, supplierObjectsBatcherSoups, + createSupplierDataPizza1(), createSupplierDataSoup1(), + createSupplierGetterPizza1(), createSupplierGetterPizza2(), + createSupplierGetterSoup1(), createSupplierGetterSoup2()); } @Test public void shouldCreateAutoBatch() { - // when - Result resPizza1 = client.data().creator() - .withClassName("Pizza") - .withID(PIZZA_1_ID) - .withProperties(PIZZA_1_PROPS) - .run(); - Result resSoup1 = client.data().creator() - .withClassName("Soup") - .withID(SOUP_1_ID) - .withProperties(SOUP_1_PROPS) - .run(); - - assertThat(resPizza1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).isNotNull(); - assertThat(resSoup1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).isNotNull(); - - List> resBatches = Collections.synchronizedList(new ArrayList<>(2)); - ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() - .batchSize(2) - .callback(resBatches::add) - .build(); - - client.batch().objectsAutoBatcher(autoBatchConfig) - .withObjects( - resPizza1.getResult(), - WeaviateObject.builder().className("Pizza").id(PIZZA_2_ID).properties(PIZZA_2_PROPS).build() - ).flush(); - client.batch().objectsAutoBatcher(autoBatchConfig) - .withObjects( - resSoup1.getResult(), - WeaviateObject.builder().className("Soup").id(SOUP_2_ID).properties(SOUP_2_PROPS).build() - ).flush(); - - // check if created objects exist - Result> resGetPizza1 = client.data().objectsGetter().withID(PIZZA_1_ID).withClassName("Pizza").run(); - Result> resGetPizza2 = client.data().objectsGetter().withID(PIZZA_2_ID).withClassName("Pizza").run(); - Result> resGetSoup1 = client.data().objectsGetter().withID(SOUP_1_ID).withClassName("Soup").run(); - Result> resGetSoup2 = client.data().objectsGetter().withID(SOUP_2_ID).withClassName("Soup").run(); - - // then - assertThat(resBatches.get(0)).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resBatches.get(0).getResult()).hasSize(2); - - assertThat(resBatches.get(1)).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resBatches.get(1).getResult()).hasSize(2); - - assertThat(resGetPizza1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject)o).getId()).first().isEqualTo(PIZZA_1_ID); - - assertThat(resGetPizza2).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject)o).getId()).first().isEqualTo(PIZZA_2_ID); - - assertThat(resGetSoup1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject)o).getId()).first().isEqualTo(SOUP_1_ID); - - assertThat(resGetSoup2).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject)o).getId()).first().isEqualTo(SOUP_2_ID); + BiConsumer>> supplierObjectsBatcherPizzas = (pizza, callback) -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + client.batch().objectsAutoBatcher(autoBatchConfig) + .withObjects(pizza, WeaviateObject.builder().className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_2_ID) + .properties(BatchObjectsTestSuite.PIZZA_2_PROPS) + .build()) + .flush(); + }; + BiConsumer>> supplierObjectsBatcherSoups = (soup, callback) -> { + ObjectsBatcher.AutoBatchConfig autoBatchConfig = ObjectsBatcher.AutoBatchConfig.defaultConfig() + .batchSize(2) + .callback(callback) + .build(); + + client.batch().objectsAutoBatcher(autoBatchConfig) + .withObjects(soup, WeaviateObject.builder() + .className("Soup") + .id(BatchObjectsTestSuite.SOUP_2_ID) + .properties(BatchObjectsTestSuite.SOUP_2_PROPS) + .build()) + .flush(); + }; + + BatchObjectsTestSuite.testCreateAutoBatch(supplierObjectsBatcherPizzas, supplierObjectsBatcherSoups, + createSupplierDataPizza1(), createSupplierDataSoup1(), + createSupplierGetterPizza1(), createSupplierGetterPizza2(), + createSupplierGetterSoup1(), createSupplierGetterSoup2()); } @Test public void shouldCreateBatchWithPartialError() { - WeaviateObject pizzaWithError = WeaviateObject.builder() - .className("Pizza") - .id(PIZZA_1_ID) - .properties(createFoodProperties(1, "This pizza should throw a invalid name error")) - .build(); - WeaviateObject pizza = WeaviateObject.builder() - .className("Pizza") - .id(PIZZA_2_ID) - .properties(PIZZA_2_PROPS) - .build(); - - Result resBatch = client.batch().objectsBatcher() - .withObjects(pizzaWithError, pizza) + Supplier> supplierObjectsBatcherPizzas = () -> { + WeaviateObject pizzaWithError = WeaviateObject.builder() + .className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_1_ID) + .properties(BatchObjectsTestSuite.createFoodProperties(1, "This pizza should throw a invalid name error")) + .build(); + WeaviateObject pizza = WeaviateObject.builder() + .className("Pizza") + .id(BatchObjectsTestSuite.PIZZA_2_ID) + .properties(BatchObjectsTestSuite.PIZZA_2_PROPS) + .build(); + + return client.batch().objectsBatcher() + .withObjects(pizzaWithError, pizza) + .run(); + }; + + BatchObjectsTestSuite.testCreateBatchWithPartialError(supplierObjectsBatcherPizzas, + createSupplierGetterPizza1(), createSupplierGetterPizza2()); + } + + @NotNull + private Supplier> createSupplierDataSoup1() { + return () -> client.data().creator() + .withClassName("Soup") + .withID(BatchObjectsTestSuite.SOUP_1_ID) + .withProperties(BatchObjectsTestSuite.SOUP_1_PROPS) .run(); + } - assertThat(resBatch).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resBatch.getResult()).hasSize(2); - - ObjectGetResponse resPizzaWithError = resBatch.getResult()[0]; - assertThat(resPizzaWithError.getId()).isEqualTo(PIZZA_1_ID); - assertThat(resPizzaWithError.getResult().getErrors()) - .extracting(ObjectsGetResponseAO2Result.ErrorResponse::getError).asList() - .first() - .extracting(i -> ((ObjectsGetResponseAO2Result.ErrorItem) i).getMessage()).asString() - .contains("invalid text property 'name' on class 'Pizza': not a string, but json.Number"); - ObjectGetResponse resPizza = resBatch.getResult()[1]; - assertThat(resPizza.getId()).isEqualTo(PIZZA_2_ID); - assertThat(resPizza.getResult().getErrors()).isNull(); - - Result> resGetPizzaWithError = client.data().objectsGetter() + @NotNull + private Supplier> createSupplierDataPizza1() { + return () -> client.data().creator() .withClassName("Pizza") - .withID(PIZZA_1_ID) + .withID(BatchObjectsTestSuite.PIZZA_1_ID) + .withProperties(BatchObjectsTestSuite.PIZZA_1_PROPS) .run(); - Result> resGetPizza = client.data().objectsGetter() + } + + @NotNull + private Supplier>> createSupplierGetterPizza1() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.PIZZA_1_ID) .withClassName("Pizza") - .withID(PIZZA_2_ID) .run(); - - assertThat(resGetPizzaWithError).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resGetPizzaWithError.getResult()).isNull(); - - assertThat(resGetPizza).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resGetPizza.getResult()).hasSize(1); } + @NotNull + private Supplier>> createSupplierGetterPizza2() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.PIZZA_2_ID) + .withClassName("Pizza") + .run(); + } - private static Map createFoodProperties(Object name, Object description) { - Map props = new HashMap<>(); - props.put("name", name); - props.put("description", description); + @NotNull + private Supplier>> createSupplierGetterSoup1() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.SOUP_1_ID) + .withClassName("Soup") + .run(); + } - return props; + @NotNull + private Supplier>> createSupplierGetterSoup2() { + return () -> client.data().objectsGetter() + .withID(BatchObjectsTestSuite.SOUP_2_ID) + .withClassName("Soup") + .run(); } } diff --git a/src/test/java/io/weaviate/integration/tests/batch/BatchObjectsMockServerTestSuite.java b/src/test/java/io/weaviate/integration/tests/batch/BatchObjectsMockServerTestSuite.java new file mode 100644 index 00000000..4969cf0d --- /dev/null +++ b/src/test/java/io/weaviate/integration/tests/batch/BatchObjectsMockServerTestSuite.java @@ -0,0 +1,183 @@ +package io.weaviate.integration.tests.batch; + +import io.weaviate.client.base.Result; +import io.weaviate.client.base.WeaviateErrorMessage; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.batch.model.ObjectGetResponseStatus; +import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result; +import io.weaviate.client.v1.data.model.WeaviateObject; + +import java.net.ConnectException; +import java.net.SocketTimeoutException; +import java.time.ZonedDateTime; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.assertj.core.api.Assertions.assertThat; + +public class BatchObjectsMockServerTestSuite { + + public static final String PIZZA_1_ID = "abefd256-8574-442b-9293-9205193737ee"; + private static final Map PIZZA_1_PROPS = createFoodProperties( + "Hawaii", "Universally accepted to be the best pizza ever created."); + public static final String PIZZA_2_ID = "97fa5147-bdad-4d74-9a81-f8babc811b09"; + private static final Map PIZZA_2_PROPS = createFoodProperties( + "Doener", "A innovation, some say revolution, in the pizza industry."); + public static final String SOUP_1_ID = "565da3b6-60b3-40e5-ba21-e6bfe5dbba91"; + private static final Map SOUP_1_PROPS = createFoodProperties( + "ChickenSoup", "Used by humans when their inferior genetics are attacked by microscopic organisms."); + public static final String SOUP_2_ID = "07473b34-0ab2-4120-882d-303d9e13f7af"; + private static final Map SOUP_2_PROPS = createFoodProperties( + "Beautiful", "Putting the game of letter soups to a whole new level."); + + public static final WeaviateObject PIZZA_1 = WeaviateObject.builder().className("Pizza").id(PIZZA_1_ID).properties(PIZZA_1_PROPS).build(); + public static final WeaviateObject PIZZA_2 = WeaviateObject.builder().className("Pizza").id(PIZZA_2_ID).properties(PIZZA_2_PROPS).build(); + public static final WeaviateObject SOUP_1 = WeaviateObject.builder().className("Soup").id(SOUP_1_ID).properties(SOUP_1_PROPS).build(); + public static final WeaviateObject SOUP_2 = WeaviateObject.builder().className("Soup").id(SOUP_2_ID).properties(SOUP_2_PROPS).build(); + + + public static void testNotCreateBatchDueToConnectionIssue(Supplier> supplierObjectsBatcher, + long expectedExecMinMillis, long expectedExecMaxMillis) { + ZonedDateTime start = ZonedDateTime.now(); + Result resBatch = supplierObjectsBatcher.get(); + ZonedDateTime end = ZonedDateTime.now(); + + assertThat(ChronoUnit.MILLIS.between(start, end)).isBetween(expectedExecMinMillis, expectedExecMaxMillis); + assertThat(resBatch.getResult()).isNull(); + assertThat(resBatch.hasErrors()).isTrue(); + + List errorMessages = resBatch.getError().getMessages(); + assertThat(errorMessages).hasSize(2); + assertThat(errorMessages.get(0).getThrowable()).isInstanceOf(ConnectException.class); + assertThat(errorMessages.get(0).getMessage()).contains("Connection refused"); + assertThat(errorMessages.get(1).getThrowable()).isNull(); + assertThat(errorMessages.get(1).getMessage()).contains(PIZZA_1_ID, PIZZA_2_ID, SOUP_1_ID, SOUP_2_ID); + } + + public static void testNotCreateAutoBatchDueToConnectionIssue(Consumer>> supplierObjectsBatcher, + long expectedExecMinMillis, long expectedExecMaxMillis) { + List> resBatches = Collections.synchronizedList(new ArrayList<>(2)); + + ZonedDateTime start = ZonedDateTime.now(); + supplierObjectsBatcher.accept(resBatches::add); + ZonedDateTime end = ZonedDateTime.now(); + + assertThat(ChronoUnit.MILLIS.between(start, end)).isBetween(expectedExecMinMillis, expectedExecMaxMillis); + assertThat(resBatches).hasSize(2); + + for (Result resBatch : resBatches) { + assertThat(resBatch.getResult()).isNull(); + assertThat(resBatch.hasErrors()).isTrue(); + + List errorMessages = resBatch.getError().getMessages(); + assertThat(errorMessages).hasSize(2); + assertThat(errorMessages.get(0).getThrowable()).isInstanceOf(ConnectException.class); + assertThat(errorMessages.get(0).getMessage()).contains("Connection refused"); + assertThat(errorMessages.get(1).getThrowable()).isNull(); + + String failedIdsMessage = errorMessages.get(1).getMessage(); + if (failedIdsMessage.contains(PIZZA_1_ID)) { + assertThat(failedIdsMessage).contains(PIZZA_1_ID, PIZZA_2_ID).doesNotContain(SOUP_1_ID, SOUP_2_ID); + } else { + assertThat(failedIdsMessage).contains(SOUP_1_ID, SOUP_2_ID).doesNotContain(PIZZA_1_ID, PIZZA_2_ID); + } + } + } + + public static void testNotCreateBatchDueToTimeoutIssue(Supplier> supplierObjectsBatcher, + Consumer assertPostObjectsCallsCount, + Consumer assertGetPizza1CallsCount, + Consumer assertGetPizza2CallsCount, + Consumer assertGetSoup1CallsCount, + Consumer assertGetSoup2CallsCount, + int expectedBatchCallsCount, String expectedErr) { + Result resBatch = supplierObjectsBatcher.get(); + + assertPostObjectsCallsCount.accept(expectedBatchCallsCount); + assertGetPizza2CallsCount.accept(expectedBatchCallsCount); + assertGetSoup2CallsCount.accept(expectedBatchCallsCount); + assertGetPizza1CallsCount.accept(1); + assertGetSoup1CallsCount.accept(1); + + assertThat(resBatch.getResult()).hasSize(2); + assertThat(resBatch.hasErrors()).isTrue(); + + List errorMessages = resBatch.getError().getMessages(); + assertThat(errorMessages).hasSize(2); + assertThat(errorMessages.get(0).getThrowable()).isInstanceOf(SocketTimeoutException.class); + assertThat(errorMessages.get(0).getMessage()).contains(expectedErr); + assertThat(errorMessages.get(1).getThrowable()).isNull(); + assertThat(errorMessages.get(1).getMessage()).contains(PIZZA_2_ID, SOUP_2_ID).doesNotContain(PIZZA_1_ID, SOUP_1_ID); + + assertThat(resBatch.getResult()[0]) + .returns(PIZZA_1_ID, ObjectGetResponse::getId) + .extracting(ObjectGetResponse::getResult).isNotNull() + .returns(ObjectGetResponseStatus.SUCCESS, ObjectsGetResponseAO2Result::getStatus) + .returns(null, ObjectsGetResponseAO2Result::getErrors); + assertThat(resBatch.getResult()[1]) + .returns(SOUP_1_ID, ObjectGetResponse::getId) + .extracting(ObjectGetResponse::getResult).isNotNull() + .returns(ObjectGetResponseStatus.SUCCESS, ObjectsGetResponseAO2Result::getStatus) + .returns(null, ObjectsGetResponseAO2Result::getErrors); + } + + public static void testNotCreateAutoBatchDueToTimeoutIssue(Consumer>> supplierObjectsBatcher, + Consumer assertPostObjectsCallsCount, + Consumer assertGetPizza1CallsCount, + Consumer assertGetPizza2CallsCount, + Consumer assertGetSoup1CallsCount, + Consumer assertGetSoup2CallsCount, + int expectedBatchCallsCount, String expectedErr) { + List> resBatches = Collections.synchronizedList(new ArrayList<>(2)); + supplierObjectsBatcher.accept(resBatches::add); + + assertPostObjectsCallsCount.accept(expectedBatchCallsCount * 2); + assertGetPizza2CallsCount.accept(expectedBatchCallsCount); + assertGetSoup2CallsCount.accept(expectedBatchCallsCount); + assertGetPizza1CallsCount.accept(1); + assertGetSoup1CallsCount.accept(1); + + assertThat(resBatches).hasSize(2); + for (Result resBatch : resBatches) { + assertThat(resBatch.getResult()).hasSize(1); + assertThat(resBatch.hasErrors()).isTrue(); + + List errorMessages = resBatch.getError().getMessages(); + assertThat(errorMessages).hasSize(2); + assertThat(errorMessages.get(0).getThrowable()).isInstanceOf(SocketTimeoutException.class); + assertThat(errorMessages.get(0).getMessage()).contains(expectedErr); + assertThat(errorMessages.get(1).getThrowable()).isNull(); + + String failedIdsMessage = errorMessages.get(1).getMessage(); + if (failedIdsMessage.contains(PIZZA_2_ID)) { + assertThat(failedIdsMessage).contains(PIZZA_2_ID).doesNotContain(PIZZA_1_ID, SOUP_1_ID, SOUP_2_ID); + assertThat(resBatch.getResult()[0]) + .returns(PIZZA_1_ID, ObjectGetResponse::getId) + .extracting(ObjectGetResponse::getResult).isNotNull() + .returns(ObjectGetResponseStatus.SUCCESS, ObjectsGetResponseAO2Result::getStatus) + .returns(null, ObjectsGetResponseAO2Result::getErrors); + } else { + assertThat(failedIdsMessage).contains(SOUP_2_ID).doesNotContain(PIZZA_1_ID, PIZZA_2_ID, SOUP_1_ID); + assertThat(resBatch.getResult()[0]) + .returns(SOUP_1_ID, ObjectGetResponse::getId) + .extracting(ObjectGetResponse::getResult).isNotNull() + .returns(ObjectGetResponseStatus.SUCCESS, ObjectsGetResponseAO2Result::getStatus) + .returns(null, ObjectsGetResponseAO2Result::getErrors); + } + } + } + + private static Map createFoodProperties(String name, String description) { + Map props = new HashMap<>(); + props.put("name", name); + props.put("description", description); + + return props; + } +} diff --git a/src/test/java/io/weaviate/integration/tests/batch/BatchObjectsTestSuite.java b/src/test/java/io/weaviate/integration/tests/batch/BatchObjectsTestSuite.java new file mode 100644 index 00000000..a1d2a870 --- /dev/null +++ b/src/test/java/io/weaviate/integration/tests/batch/BatchObjectsTestSuite.java @@ -0,0 +1,181 @@ +package io.weaviate.integration.tests.batch; + +import io.weaviate.client.base.Result; +import io.weaviate.client.v1.batch.model.ObjectGetResponse; +import io.weaviate.client.v1.batch.model.ObjectsGetResponseAO2Result; +import io.weaviate.client.v1.data.model.WeaviateObject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.assertj.core.api.Assertions.assertThat; + +public class BatchObjectsTestSuite { + + public static final String PIZZA_1_ID = "abefd256-8574-442b-9293-9205193737ee"; + public static final Map PIZZA_1_PROPS = createFoodProperties( + "Hawaii", "Universally accepted to be the best pizza ever created."); + public static final String PIZZA_2_ID = "97fa5147-bdad-4d74-9a81-f8babc811b09"; + public static final Map PIZZA_2_PROPS = createFoodProperties( + "Doener", "A innovation, some say revolution, in the pizza industry."); + public static final String SOUP_1_ID = "565da3b6-60b3-40e5-ba21-e6bfe5dbba91"; + public static final Map SOUP_1_PROPS = createFoodProperties( + "ChickenSoup", "Used by humans when their inferior genetics are attacked by microscopic organisms."); + public static final String SOUP_2_ID = "07473b34-0ab2-4120-882d-303d9e13f7af"; + public static final Map SOUP_2_PROPS = createFoodProperties( + "Beautiful", "Putting the game of letter soups to a whole new level."); + + + public static void testCreateBatch(Function> supplierObjectsBatcherPizzas, + Function> supplierObjectBatcherSoups, + Supplier> supplierDataPizza1, + Supplier> supplierDataSoup1, + Supplier>> supplierGetterPizza1, + Supplier>> supplierGetterPizza2, + Supplier>> supplierGetterSoup1, + Supplier>> supplierGetterSoup2) { + // when + Result resPizza1 = supplierDataPizza1.get(); + Result resSoup1 = supplierDataSoup1.get(); + + assertThat(resPizza1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull(); + assertThat(resSoup1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull(); + + Result resBatchPizzas = supplierObjectsBatcherPizzas.apply(resPizza1.getResult()); + Result resBatchSoups = supplierObjectBatcherSoups.apply(resSoup1.getResult()); + + assertThat(resBatchPizzas).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resBatchPizzas.getResult()).hasSize(2); + assertThat(resBatchSoups).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resBatchSoups.getResult()).hasSize(2); + + // check if created objects exist + Result> resGetPizza1 = supplierGetterPizza1.get(); + Result> resGetPizza2 = supplierGetterPizza2.get(); + Result> resGetSoup1 = supplierGetterSoup1.get(); + Result> resGetSoup2 = supplierGetterSoup2.get(); + + assertThat(resGetPizza1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(PIZZA_1_ID); + assertThat(resGetPizza2).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(PIZZA_2_ID); + assertThat(resGetSoup1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(SOUP_1_ID); + assertThat(resGetSoup2).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(SOUP_2_ID); + } + + public static void testCreateAutoBatch(BiConsumer>> supplierObjectsBatcherPizzas, + BiConsumer>> supplierObjectsBatcherSoups, + Supplier> supplierDataPizza1, + Supplier> supplierDataSoup1, + Supplier>> supplierGetterPizza1, + Supplier>> supplierGetterPizza2, + Supplier>> supplierGetterSoup1, + Supplier>> supplierGetterSoup2) { + // when + Result resPizza1 = supplierDataPizza1.get(); + Result resSoup1 = supplierDataSoup1.get(); + + assertThat(resPizza1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull(); + assertThat(resSoup1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).isNotNull(); + + List> resBatches = Collections.synchronizedList(new ArrayList<>(2)); + supplierObjectsBatcherPizzas.accept(resPizza1.getResult(), resBatches::add); + supplierObjectsBatcherSoups.accept(resSoup1.getResult(), resBatches::add); + + assertThat(resBatches.get(0)).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resBatches.get(0).getResult()).hasSize(2); + assertThat(resBatches.get(1)).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resBatches.get(1).getResult()).hasSize(2); + + // check if created objects exist + Result> resGetPizza1 = supplierGetterPizza1.get(); + Result> resGetPizza2 = supplierGetterPizza2.get(); + Result> resGetSoup1 = supplierGetterSoup1.get(); + Result> resGetSoup2 = supplierGetterSoup2.get(); + + assertThat(resGetPizza1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(PIZZA_1_ID); + assertThat(resGetPizza2).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(PIZZA_2_ID); + assertThat(resGetSoup1).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(SOUP_1_ID); + assertThat(resGetSoup2).isNotNull() + .returns(false, Result::hasErrors) + .extracting(Result::getResult).asList().hasSize(1) + .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(SOUP_2_ID); + } + + public static void testCreateBatchWithPartialError(Supplier> supplierObjectsBatcherPizzas, + Supplier>> supplierGetterPizza1, + Supplier>> supplierGetterPizza2) { + Result resBatch = supplierObjectsBatcherPizzas.get(); + assertThat(resBatch).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resBatch.getResult()).hasSize(2); + + ObjectGetResponse resPizzaWithError = resBatch.getResult()[0]; + assertThat(resPizzaWithError.getId()).isEqualTo(PIZZA_1_ID); + assertThat(resPizzaWithError.getResult().getErrors()) + .extracting(ObjectsGetResponseAO2Result.ErrorResponse::getError).asList() + .first() + .extracting(i -> ((ObjectsGetResponseAO2Result.ErrorItem) i).getMessage()).asString() + .contains("invalid text property 'name' on class 'Pizza': not a string, but json.Number"); + ObjectGetResponse resPizza = resBatch.getResult()[1]; + assertThat(resPizza.getId()).isEqualTo(PIZZA_2_ID); + assertThat(resPizza.getResult().getErrors()).isNull(); + + Result> resGetPizzaWithError = supplierGetterPizza1.get(); + Result> resGetPizza = supplierGetterPizza2.get(); + + assertThat(resGetPizzaWithError).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resGetPizzaWithError.getResult()).isNull(); + assertThat(resGetPizza).isNotNull() + .returns(false, Result::hasErrors); + assertThat(resGetPizza.getResult()).hasSize(1); + } + + + public static Map createFoodProperties(Object name, Object description) { + Map props = new HashMap<>(); + props.put("name", name); + props.put("description", description); + + return props; + } +} diff --git a/src/test/java/io/weaviate/integration/tests/batch/BatchTestSuite.java b/src/test/java/io/weaviate/integration/tests/batch/BatchTestSuite.java deleted file mode 100644 index 48af5a70..00000000 --- a/src/test/java/io/weaviate/integration/tests/batch/BatchTestSuite.java +++ /dev/null @@ -1,86 +0,0 @@ -package io.weaviate.integration.tests.batch; - -import io.weaviate.client.base.Result; -import io.weaviate.client.v1.batch.model.ObjectGetResponse; -import io.weaviate.client.v1.data.model.WeaviateObject; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.function.Supplier; -import static org.assertj.core.api.Assertions.assertThat; - -public class BatchTestSuite { - - public static final String PIZZA_1_ID = "abefd256-8574-442b-9293-9205193737ee"; - public static final Map PIZZA_1_PROPS = createFoodProperties("Hawaii", "Universally accepted to be the best pizza ever created."); - public static final String PIZZA_2_ID = "97fa5147-bdad-4d74-9a81-f8babc811b09"; - public static final Map PIZZA_2_PROPS = createFoodProperties("Doener", "A innovation, some say revolution, in the pizza industry."); - public static final String SOUP_1_ID = "565da3b6-60b3-40e5-ba21-e6bfe5dbba91"; - public static final Map SOUP_1_PROPS = createFoodProperties("ChickenSoup", "Used by humans when their inferior genetics are attacked by " + - "microscopic organisms."); - public static final String SOUP_2_ID = "07473b34-0ab2-4120-882d-303d9e13f7af"; - public static final Map SOUP_2_PROPS = createFoodProperties("Beautiful", "Putting the game of letter soups to a whole new level."); - - public static void shouldCreateBatch(Supplier> createResPizza1, Supplier> createResSoup1, - Function, Result> supplyResBatchPizzas, Function, Result> supplyResBatchSoups, - Supplier>> supplyResGetPizza1, Supplier>> supplyResGetPizza2, - Supplier>> supplyResGetSoup1, Supplier>> supplyResGetSoup2) { - // when - Result resPizza1 = createResPizza1.get(); - Result resSoup1 = createResSoup1.get(); - - assertThat(resPizza1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).isNotNull(); - assertThat(resSoup1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).isNotNull(); - - Result resBatchPizzas = supplyResBatchPizzas.apply(resPizza1); - Result resBatchSoups = supplyResBatchSoups.apply(resSoup1); - - // check if created objects exist - Result> resGetPizza1 = supplyResGetPizza1.get(); - Result> resGetPizza2 = supplyResGetPizza2.get(); - Result> resGetSoup1 = supplyResGetSoup1.get(); - Result> resGetSoup2 = supplyResGetSoup2.get(); - - // then - assertThat(resBatchPizzas).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resBatchPizzas.getResult()).hasSize(2); - - assertThat(resBatchSoups).isNotNull() - .returns(false, Result::hasErrors); - assertThat(resBatchSoups.getResult()).hasSize(2); - - assertThat(resGetPizza1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(PIZZA_1_ID); - - assertThat(resGetPizza2).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(PIZZA_2_ID); - - assertThat(resGetSoup1).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(SOUP_1_ID); - - assertThat(resGetSoup2).isNotNull() - .returns(false, Result::hasErrors) - .extracting(Result::getResult).asList().hasSize(1) - .extracting(o -> ((WeaviateObject) o).getId()).first().isEqualTo(SOUP_2_ID); - } - - private static Map createFoodProperties(Object name, Object description) { - Map props = new HashMap<>(); - props.put("name", name); - props.put("description", description); - - return props; - } -}