diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java index f3e2ea1f8f..adafc7a55d 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/accumlator/InMemoryBuffer.java @@ -58,6 +58,7 @@ public InMemoryBuffer(String batchOptionKeyName, OutputCodecContext outputCodecC this.outputCodecContext = outputCodecContext; } + @Override public void addRecord(Record record) { records.add(record); Event event = record.getData(); @@ -72,6 +73,7 @@ public void addRecord(Record record) { eventCount++; } + @Override public List> getRecords() { return records; } diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java index d981cf67ca..61fb36c4ff 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java @@ -22,10 +22,13 @@ import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.model.sink.Sink; import org.opensearch.dataprepper.model.sink.SinkContext; +import org.opensearch.dataprepper.model.types.ByteCount; import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; +import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory; import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; +import org.opensearch.dataprepper.plugins.lambda.common.util.ThresholdCheck; import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; import org.opensearch.dataprepper.model.failures.DlqObject; @@ -38,6 +41,7 @@ import java.time.Duration; import java.util.Collection; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -75,6 +79,13 @@ public class LambdaSink extends AbstractSink> { private final OutputCodecContext outputCodecContext; private volatile boolean sinkInitialized; private DlqPushHandler dlqPushHandler = null; + final int maxEvents; + final long maxBytes; + final Duration maxCollectTime; + + // The partial buffer that may not yet have reached threshold. + // Access must be synchronized + private Buffer statefulBuffer; @DataPrepperPluginConstructor public LambdaSink(final PluginSetting pluginSetting, @@ -90,6 +101,9 @@ public LambdaSink(final PluginSetting pluginSetting, this.lambdaSinkConfig = lambdaSinkConfig; this.expressionEvaluator = expressionEvaluator; this.outputCodecContext = OutputCodecContext.fromSinkContext(sinkContext); + this.maxEvents = lambdaSinkConfig.getBatchOptions().getThresholdOptions().getEventCount(); + this.maxBytes = lambdaSinkConfig.getBatchOptions().getThresholdOptions().getMaximumSize().getBytes(); + this.maxCollectTime = lambdaSinkConfig.getBatchOptions().getThresholdOptions().getEventCollectTimeOut(); this.numberOfRecordsSuccessCounter = pluginMetrics.counter( NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS); @@ -138,57 +152,59 @@ public void doInitialize() { } private void doInitializeInternal() { + // Initialize the partial buffer + statefulBuffer = new InMemoryBuffer( + lambdaSinkConfig.getBatchOptions().getKeyName(), + outputCodecContext + ); sinkInitialized = Boolean.TRUE; } /** - * @param records Records to be output + * We only flush the partial buffer if we're shutting down or if we want to + * do a time-based flush. */ @Override - public void doOutput(final Collection> records) { + public synchronized void shutdown() { + // Flush the partial buffer if any leftover + if (statefulBuffer.getEventCount() > 0) { + flushBuffers(Collections.singletonList(statefulBuffer)); + } + } + + @Override + public synchronized void doOutput(final Collection> records) { + if (!sinkInitialized) { + LOG.warn("LambdaSink doOutput called before initialization"); + return; + } if (records.isEmpty()) { return; } - Map> bufferToFutureMap = new HashMap<>(); - try { - //Result from lambda is not currently processes. - bufferToFutureMap = LambdaCommonHandler.sendRecords( - records, - lambdaSinkConfig, - lambdaAsyncClient, - outputCodecContext); - } catch (Exception e) { - LOG.error("Exception while processing records ", e); - handleFailure(records, e, HttpURLConnection.HTTP_BAD_REQUEST); - } + // We'll collect any "full" buffers in a local list, flush them at the end + List fullBuffers = new ArrayList<>(); - for (Map.Entry> entry : bufferToFutureMap.entrySet()) { - CompletableFuture future = entry.getValue(); - Buffer inputBuffer = entry.getKey(); - try { - InvokeResponse response = future.join(); - Duration latency = inputBuffer.stopLatencyWatch(); - lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); - requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); - if (!isSuccess(response)) { - String errorMessage = String.format("Lambda invoke failed with status code %s error %s ", - response.statusCode(), response.payload().asUtf8String()); - throw new RuntimeException(errorMessage); - } - - releaseEventHandles(inputBuffer.getRecords(), true); - numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); - numberOfRequestsSuccessCounter.increment(); - if (response.payload() != null) { - responsePayloadMetric.record(response.payload().asByteArray().length); - } + // Add to the persistent buffer, check threshold + for (Record record : records) { + //statefulBuffer is either empty or partially filled(from previous run) + statefulBuffer.addRecord(record); - } catch (Exception e) { - LOG.error(NOISY, e.getMessage(), e); - handleFailure(inputBuffer.getRecords(), new RuntimeException("failed"), HttpURLConnection.HTTP_INTERNAL_ERROR); + if (isThresholdExceeded(statefulBuffer)) { + // This buffer is full + fullBuffers.add(statefulBuffer); + // Create new partial buffer + statefulBuffer = new InMemoryBuffer( + lambdaSinkConfig.getBatchOptions().getKeyName(), + outputCodecContext + ); } } + + // Flush any full buffers + if (!fullBuffers.isEmpty()) { + flushBuffers(fullBuffers); + } } @@ -210,7 +226,7 @@ private DlqObject createDlqObjectFromEvent(final Event event, .build(); } - void handleFailure(Collection> failedRecords, Throwable throwable, int statusCode) { + synchronized void handleFailure(Collection> failedRecords, Throwable throwable, int statusCode) { if (failedRecords.isEmpty()) { return; } @@ -249,4 +265,65 @@ private void releaseEventHandles(Collection> records, boolean succ } } } + + private synchronized void flushBuffers(final List buffersToFlush) { + // Combine all their records for a single call to sendRecords + List> combinedRecords = new ArrayList<>(); + for (Buffer buf : buffersToFlush) { + combinedRecords.addAll(buf.getRecords()); + } + + Map> bufferToFutureMap; + try { + bufferToFutureMap = LambdaCommonHandler.sendRecords( + combinedRecords, + lambdaSinkConfig, + lambdaAsyncClient, + outputCodecContext + ); + } catch (Exception e) { + LOG.error(NOISY, "Error sending buffers to Lambda", e); + handleFailure(combinedRecords, e, HttpURLConnection.HTTP_INTERNAL_ERROR); + return; + } + + for (Map.Entry> entry : bufferToFutureMap.entrySet()) { + Buffer inputBuffer = entry.getKey(); + CompletableFuture future = entry.getValue(); + + try { + InvokeResponse response = future.join(); + Duration latency = inputBuffer.stopLatencyWatch(); + lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS); + requestPayloadMetric.record(inputBuffer.getPayloadRequestSize()); + if (!isSuccess(response)) { + String errorMsg = String.format( + "Lambda invoke failed with code %d, error: %s", + response.statusCode(), + response.payload() != null ? response.payload().asUtf8String() : "No payload" + ); + throw new RuntimeException(errorMsg); + } + + releaseEventHandles(inputBuffer.getRecords(), true); + numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount()); + numberOfRequestsSuccessCounter.increment(); + if (response.payload() != null) { + responsePayloadMetric.record(response.payload().asByteArray().length); + } + } catch (Exception ex) { + LOG.error(NOISY, "Error handling future response from Lambda", ex); + handleFailure(inputBuffer.getRecords(), ex, HttpURLConnection.HTTP_INTERNAL_ERROR); + } + } + } + + private boolean isThresholdExceeded(Buffer buffer) { + return ThresholdCheck.checkThresholdExceed( + buffer, + maxEvents, + ByteCount.ofBytes(maxBytes), + maxCollectTime + ); + } } diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java index cf45281cfd..91559ba889 100644 --- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java +++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java @@ -1,8 +1,3 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - package org.opensearch.dataprepper.plugins.lambda.sink; import io.micrometer.core.instrument.Counter; @@ -12,219 +7,335 @@ import org.junit.jupiter.api.Test; import org.mockito.Mock; import org.mockito.MockedStatic; -import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; import org.opensearch.dataprepper.expression.ExpressionEvaluator; import org.opensearch.dataprepper.metrics.PluginMetrics; -import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.record.Record; -import org.opensearch.dataprepper.model.codec.OutputCodec; import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.plugin.PluginFactory; +import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.sink.SinkContext; -import org.opensearch.dataprepper.model.sink.OutputCodecContext; import org.opensearch.dataprepper.model.types.ByteCount; -import org.opensearch.dataprepper.model.configuration.PluginModel; -import org.opensearch.dataprepper.plugins.codec.json.JsonOutputCodec; +import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer; import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer; -import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions; -import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType; -import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions; -import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler; +import org.opensearch.dataprepper.plugins.lambda.common.config.*; import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler; -import software.amazon.awssdk.services.lambda.model.InvokeResponse; -import software.amazon.awssdk.core.SdkBytes; -import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import org.opensearch.dataprepper.plugins.lambda.sink.dlq.LambdaSinkFailedDlqData; +import org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil; import org.opensearch.dataprepper.plugins.dlq.DlqProvider; import org.opensearch.dataprepper.plugins.dlq.DlqWriter; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.lambda.LambdaAsyncClient; +import software.amazon.awssdk.services.lambda.model.InvokeResponse; +import org.mockito.Mockito; import java.lang.reflect.Field; import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Collection; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.concurrent.atomic.AtomicLong; +import java.util.*; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; -import static org.mockito.Mockito.any; -import static org.mockito.ArgumentMatchers.anyList; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.anyString; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; import static org.opensearch.dataprepper.plugins.lambda.utils.LambdaTestSetupUtil.getSampleRecord; -public class LambdaSinkTest { - - private static final String TEST_BUCKET = "test"; - private static final String TEST_ROLE = "arn:aws:iam::524239988122:role/app-test"; - private static final String TEST_REGION = "ap-south-1"; - - @Mock - SinkContext sinkContext; +class LambdaSinkTest { @Mock private LambdaSinkConfig lambdaSinkConfig; @Mock + private SinkContext sinkContext; + @Mock private PluginMetrics pluginMetrics; @Mock private PluginFactory pluginFactory; - - private PluginSetting pluginSetting; - @Mock private AwsCredentialsSupplier awsCredentialsSupplier; @Mock - private DlqPushHandler dlqPushHandler; - @Mock private ExpressionEvaluator expressionEvaluator; + + // Counters @Mock private Counter numberOfRecordsSuccessCounter; @Mock private Counter numberOfRecordsFailedCounter; @Mock - private Timer lambdaLatencyMetric; + private Counter numberOfRequestsSuccessCounter; @Mock - private DistributionSummary responsePayloadMetric; + private Counter numberOfRequestsFailedCounter; + + // Timer and Summaries @Mock - private DistributionSummary lambdaPayloadMetric; + private Timer lambdaLatencyMetric; @Mock - private OutputCodec requestCodec; + private DistributionSummary requestPayloadMetric; @Mock - private Buffer currentBufferPerBatch; + private DistributionSummary responsePayloadMetric; + @Mock - private PluginModel dlqConfig; + private DlqPushHandler dlqPushHandler; @Mock private DlqProvider dlqProvider; @Mock private DlqWriter dlqWriter; - @Mock - CompletableFuture future; - @Mock - InvokeResponse response; + private PluginSetting pluginSetting; private LambdaSink lambdaSink; - @Mock - private AwsAuthenticationOptions awsAuthenticationOptions; - - @BeforeEach - public void setUp() { + void setUp() { MockitoAnnotations.openMocks(this); - // Mock PluginMetrics counters and timers - when(pluginMetrics.counter(LambdaSink.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)).thenReturn( - numberOfRecordsSuccessCounter); - when(pluginMetrics.counter(LambdaSink.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)).thenReturn( - numberOfRecordsFailedCounter); - when(pluginMetrics.timer(anyString())).thenReturn(lambdaLatencyMetric); + // Setup plugin metrics mocks + when(pluginMetrics.counter(LambdaSink.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS)) + .thenReturn(numberOfRecordsSuccessCounter); + when(pluginMetrics.counter(LambdaSink.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED)) + .thenReturn(numberOfRecordsFailedCounter); + when(pluginMetrics.counter(LambdaSink.NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA)) + .thenReturn(numberOfRequestsSuccessCounter); + when(pluginMetrics.counter(LambdaSink.NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA)) + .thenReturn(numberOfRequestsFailedCounter); + + when(pluginMetrics.timer(LambdaSink.LAMBDA_LATENCY_METRIC)).thenReturn(lambdaLatencyMetric); + when(pluginMetrics.summary(LambdaSink.REQUEST_PAYLOAD_SIZE)).thenReturn(requestPayloadMetric); + when(pluginMetrics.summary(LambdaSink.RESPONSE_PAYLOAD_SIZE)).thenReturn(responsePayloadMetric); when(pluginMetrics.gauge(anyString(), any(AtomicLong.class))).thenReturn(new AtomicLong()); - // Mock lambdaSinkConfig + // Mock the Batch/Threshold options + final ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); + when(thresholdOptions.getEventCount()).thenReturn(2); // flush after 2 events + when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb")); + when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofMinutes(5)); + + final BatchOptions batchOptions = mock(BatchOptions.class); + when(batchOptions.getKeyName()).thenReturn("testKey"); + when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); + + // Setup LambdaSinkConfig when(lambdaSinkConfig.getFunctionName()).thenReturn("test-function"); when(lambdaSinkConfig.getInvocationType()).thenReturn(InvocationType.EVENT); - - // Mock BatchOptions and ThresholdOptions - BatchOptions batchOptions = mock(BatchOptions.class); - ThresholdOptions thresholdOptions = mock(ThresholdOptions.class); - when(batchOptions.getKeyName()).thenReturn("test"); when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions); - when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions); - when(thresholdOptions.getEventCount()).thenReturn(10); - when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb")); - when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(1)); - - // Mock JsonOutputCodec - requestCodec = mock(JsonOutputCodec.class); - when(pluginFactory.loadPlugin(eq(OutputCodec.class), any(PluginSetting.class))).thenReturn( - requestCodec); - - // Initialize bufferFactory and buffer - currentBufferPerBatch = mock(Buffer.class); - when(currentBufferPerBatch.getEventCount()).thenReturn(0); - when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); - when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of("us-east-1")); - this.pluginSetting = new PluginSetting("aws_lambda", Collections.emptyMap()); - this.pluginSetting.setPipelineName(UUID.randomUUID().toString()); - this.awsAuthenticationOptions = new AwsAuthenticationOptions(); - - ClientOptions clientOptions = new ClientOptions(); + + final AwsAuthenticationOptions awsAuthOptions = mock(AwsAuthenticationOptions.class); + when(awsAuthOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); + when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthOptions); + + final ClientOptions clientOptions = new ClientOptions(); when(lambdaSinkConfig.getClientOptions()).thenReturn(clientOptions); - dlqConfig = mock(PluginModel.class); - dlqWriter = mock(DlqWriter.class); - dlqProvider = mock(DlqProvider.class); - when(dlqConfig.getPluginName()).thenReturn("testPlugin"); - when(dlqConfig.getPluginSettings()).thenReturn(Map.of("bucket", TEST_BUCKET, DlqPushHandler.REGION, TEST_REGION, DlqPushHandler.STS_ROLE_ARN, TEST_ROLE)); - when(lambdaSinkConfig.getDlq()).thenReturn(dlqConfig); - when(dlqProvider.getDlqWriter(anyString())).thenReturn(Optional.of(dlqWriter)); - when(pluginFactory.loadPlugin(eq(DlqProvider.class), any(PluginSetting.class))).thenReturn(dlqProvider); - this.lambdaSink = new LambdaSink(pluginSetting, lambdaSinkConfig, pluginFactory, sinkContext, - awsCredentialsSupplier, expressionEvaluator); + // For DLQ + when(lambdaSinkConfig.getDlqPluginSetting()).thenReturn(null); // default no DLQ + + // Create pluginSetting + pluginSetting = new PluginSetting("aws_lambda", new HashMap<>()); + pluginSetting.setPipelineName("testPipeline"); + + // Construct the LambdaSink + lambdaSink = new LambdaSink( + pluginSetting, + lambdaSinkConfig, + pluginFactory, + sinkContext, + awsCredentialsSupplier, + expressionEvaluator + ); + + // Insert pluginMetrics, counters, etc. +// setPrivateField(lambdaSink, "pluginMetrics", pluginMetrics); + setPrivateField(lambdaSink, "numberOfRecordsSuccessCounter", numberOfRecordsSuccessCounter); + setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); + setPrivateField(lambdaSink, "numberOfRequestsSuccessCounter", numberOfRequestsSuccessCounter); + setPrivateField(lambdaSink, "numberOfRequestsFailedCounter", numberOfRequestsFailedCounter); + setPrivateField(lambdaSink, "lambdaLatencyMetric", lambdaLatencyMetric); + setPrivateField(lambdaSink, "requestPayloadMetric", requestPayloadMetric); + setPrivateField(lambdaSink, "responsePayloadMetric", responsePayloadMetric); + + // Initialize the sink + lambdaSink.doInitialize(); } @Test - public void testOutput_SuccessfulProcessing() throws Exception { - try ( MockedStatic lambdaCommonHandler = Mockito.mockStatic(LambdaCommonHandler.class)){ - Event event = mock(Event.class); - Record record = new Record<>(event); - Collection> records = Collections.singletonList(record); - - future = mock(CompletableFuture.class); - response = mock(InvokeResponse.class); - when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(true); - when(response.statusCode()).thenReturn(202); - when(response.payload()).thenReturn(SdkBytes.fromUtf8String("{\"k\":\"v\"}")); - when(future.join()).thenReturn(response); - - setPrivateField(lambdaSink, "numberOfRecordsSuccessCounter", numberOfRecordsSuccessCounter); - setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); - - doNothing().when(currentBufferPerBatch).addRecord(eq(record)); - when(currentBufferPerBatch.getRecords()).thenReturn(new ArrayList<>(records)); - when(currentBufferPerBatch.getEventCount()).thenReturn(1); - when(currentBufferPerBatch.getSize()).thenReturn(100L); - when(currentBufferPerBatch.getPayloadRequestSize()).thenReturn(100L); - when(currentBufferPerBatch.getDuration()).thenReturn(Duration.ofMillis(500)); - doNothing().when(lambdaLatencyMetric).record(any(Duration.class)); - doNothing().when(lambdaPayloadMetric).record(any(Double.class)); - doNothing().when(responsePayloadMetric).record(any(Double.class)); - lambdaCommonHandler.when(() -> - LambdaCommonHandler.sendRecords(anyList(), any(LambdaSinkConfig.class), any(LambdaAsyncClient.class), any(OutputCodecContext.class))).thenReturn(Map.of(currentBufferPerBatch, future)); - - lambdaCommonHandler.when(() -> - LambdaCommonHandler.isSuccess(any(InvokeResponse.class))).thenReturn(true); + void testNoFlushIfThresholdNotReached() { + // threshold=2, only pass 1 record => no flush + final List> records = Collections.singletonList( + new Record<>(mock(Event.class)) + ); + + // We mock the static method, expecting zero calls + try (MockedStatic mockedHandler = mockStatic(LambdaCommonHandler.class)) { lambdaSink.doOutput(records); - verify(numberOfRecordsSuccessCounter, times(1)).increment(1.0); + // Because threshold=2 and we only provided 1 event, no flush => 0 calls to sendRecords + mockedHandler.verify( + () -> LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()), + never() + ); } + + // Also no success or fail increments + verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble()); + verify(numberOfRecordsFailedCounter, never()).increment(anyDouble()); + verify(numberOfRequestsSuccessCounter, never()).increment(); + verify(numberOfRequestsFailedCounter, never()).increment(); } + @Test + void testFlushWhenThresholdReached() { + // threshold=2, pass 2 => flush + final List> records = List.of( + new Record<>(mock(Event.class)), + new Record<>(mock(Event.class)) + ); + + // Mock the static call to 'sendRecords(...)' to return a completed future + try (MockedStatic mockedHandler = mockStatic(LambdaCommonHandler.class)) { + // For any invocation of isSuccess(int), call the real method: + mockedHandler.when(() -> LambdaCommonHandler.isSuccess(any())) + .thenCallRealMethod(); + final InvokeResponse mockResponse = mock(InvokeResponse.class); + when(mockResponse.statusCode()).thenReturn(200); // success + when(mockResponse.payload()).thenReturn(SdkBytes.fromUtf8String("{\"msg\":\"OK\"}")); + + CompletableFuture completedFuture = mock(CompletableFuture.class); + when(completedFuture.join()).thenReturn(mockResponse); + + // We can return a single mock Buffer -> future + final Buffer mockBuffer = mock(Buffer.class); + when(mockBuffer.getRecords()).thenReturn(records); + when(mockBuffer.getEventCount()).thenReturn(2); + + Map> resultMap = + Map.of(mockBuffer, completedFuture); + + mockedHandler.when(() -> + LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()) + ).thenReturn(resultMap); + + // ACT + lambdaSink.doOutput(records); - // Helper method to set private fields via reflection - private void setPrivateField(Object targetObject, String fieldName, Object value) { - try { - Field field = targetObject.getClass().getDeclaredField(fieldName); - field.setAccessible(true); - field.set(targetObject, value); - } catch (Exception e) { - throw new RuntimeException(e); + // VERIFY + // Because threshold=2 => flush => we should see exactly 1 call to sendRecords + mockedHandler.verify(() -> + LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()), + times(1) + ); + + // The code should treat it as success => increment success counters + verify(numberOfRecordsSuccessCounter).increment(2.0); // 2 events + verify(numberOfRequestsSuccessCounter).increment(); + // No failures + verify(numberOfRecordsFailedCounter, never()).increment(anyDouble()); + verify(numberOfRequestsFailedCounter, never()).increment(); + } + } + + @Test + void testShutdownFlushesPartialIfAny() { + // threshold=2, pass only 1 => partial + final List> records = Collections.singletonList( + new Record<>(mock(Event.class)) + ); + + try (MockedStatic mockedHandler = mockStatic(LambdaCommonHandler.class)) { + mockedHandler.when(() -> LambdaCommonHandler.isSuccess(any())) + .thenCallRealMethod(); + // 1) doOutput => partial => no flush + lambdaSink.doOutput(records); + mockedHandler.verify( + () -> LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()), + never() + ); + + // 2) Now shutdown => flush leftover + final InvokeResponse mockResponse = mock(InvokeResponse.class); + when(mockResponse.statusCode()).thenReturn(200); + when(mockResponse.payload()).thenReturn(SdkBytes.fromUtf8String("{\"msg\":\"OK\"}")); + + final CompletableFuture completedFuture = + CompletableFuture.completedFuture(mockResponse); + + final Buffer mockBuffer = mock(Buffer.class); + when(mockBuffer.getRecords()).thenReturn(records); + when(mockBuffer.getEventCount()).thenReturn(1); + + mockedHandler.when(() -> + LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()) + ).thenReturn(Map.of(mockBuffer, completedFuture)); + + lambdaSink.shutdown(); + + // Now we expect 1 call on shutdown + mockedHandler.verify(() -> + LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()), + times(1) + ); + + verify(numberOfRecordsSuccessCounter).increment(1.0); + verify(numberOfRequestsSuccessCounter).increment(); + } + } + + @Test + void testFailureDuringSendRecords() { + // pass 2 => threshold => flush => but an exception is thrown + final List> records = List.of( + new Record<>(mock(Event.class)), + new Record<>(mock(Event.class)) + ); + + try (MockedStatic mockedHandler = mockStatic(LambdaCommonHandler.class)) { + // cause the method to throw an exception + mockedHandler.when(() -> + LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()) + ).thenThrow(new RuntimeException("Test flush error")); + + lambdaSink.doOutput(records); + + // We expect the sink to handle that failure => increment fail counters + verify(numberOfRecordsFailedCounter).increment(2.0); + verify(numberOfRequestsFailedCounter).increment(); + // No success increments + verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble()); + } + } + + @Test + void testFailureInFutureJoin() { + // pass 2 => threshold => flush => future join fails + final List> records = List.of( + new Record<>(mock(Event.class)), + new Record<>(mock(Event.class)) + ); + + try (MockedStatic mockedHandler = mockStatic(LambdaCommonHandler.class)) { + final CompletableFuture failingFuture = new CompletableFuture<>(); + failingFuture.completeExceptionally(new RuntimeException("InvokeResponse error")); + + final Buffer bufferMock = mock(Buffer.class); + when(bufferMock.getRecords()).thenReturn(records); + when(bufferMock.getEventCount()).thenReturn(2); + + final Map> mapResult = + Map.of(bufferMock, failingFuture); + + mockedHandler.when(() -> + LambdaCommonHandler.sendRecords(anyCollection(), any(), any(), any()) + ).thenReturn(mapResult); + + lambdaSink.doOutput(records); + + // Because future threw an error, we expect failure counters + verify(numberOfRecordsFailedCounter).increment(2.0); + verify(numberOfRequestsFailedCounter).increment(); + verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble()); } } + @Test public void testHandleFailure_WithDlq() throws Exception { Throwable throwable = new RuntimeException("Test Exception"); @@ -250,43 +361,14 @@ public void testHandleFailure_WithoutDlq() throws Exception { } - @Test - public void testOutput_ExceptionDuringProcessing() throws Exception { - try ( MockedStatic lambdaCommonHandler = Mockito.mockStatic(LambdaCommonHandler.class)) { - Event event = mock(Event.class); - Record record = new Record<>(event); - Collection> records = Collections.singletonList(record); - - future = mock(CompletableFuture.class); - response = mock(InvokeResponse.class); - when(expressionEvaluator.evaluateConditional(anyString(), eq(event))).thenReturn(true); - when(response.statusCode()).thenReturn(202); - when(response.payload()).thenReturn(SdkBytes.fromUtf8String("{\"k\":\"v\"}")); - when(future.join()).thenThrow(new RuntimeException("Test Exception")); - - setPrivateField(lambdaSink, "numberOfRecordsSuccessCounter", numberOfRecordsSuccessCounter); - setPrivateField(lambdaSink, "numberOfRecordsFailedCounter", numberOfRecordsFailedCounter); - - doNothing().when(currentBufferPerBatch).addRecord(eq(record)); - when(currentBufferPerBatch.getRecords()).thenReturn(new ArrayList<>(records)); - when(currentBufferPerBatch.getEventCount()).thenReturn(1); - when(currentBufferPerBatch.getSize()).thenReturn(100L); - when(currentBufferPerBatch.getPayloadRequestSize()).thenReturn(100L); - when(currentBufferPerBatch.getDuration()).thenReturn(Duration.ofMillis(500)); - doNothing().when(lambdaLatencyMetric).record(any(Duration.class)); - doNothing().when(lambdaPayloadMetric).record(any(Double.class)); - doNothing().when(responsePayloadMetric).record(any(Double.class)); - lambdaCommonHandler.when(() -> - LambdaCommonHandler.sendRecords(anyList(), any(LambdaSinkConfig.class), any(LambdaAsyncClient.class), any(OutputCodecContext.class))).thenReturn(Map.of(currentBufferPerBatch, future)); - - lambdaCommonHandler.when(() -> - LambdaCommonHandler.isSuccess(any(InvokeResponse.class))).thenReturn(true); - lambdaSink.doOutput(records); - - verify(numberOfRecordsFailedCounter, times(1)).increment(1); + // Utility to set private fields + private static void setPrivateField(Object target, String fieldName, Object value) { + try { + Field f = target.getClass().getDeclaredField(fieldName); + f.setAccessible(true); + f.set(target, value); + } catch (Exception e) { + throw new RuntimeException(e); } } - - - }