Skip to content

Commit

Permalink
Add retryCondition to lambda Client
Browse files Browse the repository at this point in the history
Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg committed Jan 22, 2025
1 parent 8a84f60 commit 00f9851
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 203 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,28 @@
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import org.opensearch.dataprepper.plugins.lambda.utils.CountingHttpClient;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import software.amazon.awssdk.services.lambda.model.TooManyRequestsException;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -94,9 +105,12 @@ private LambdaProcessor createObjectUnderTest(LambdaProcessorConfig processorCon

@BeforeEach
public void setup() {
lambdaRegion = System.getProperty("tests.lambda.processor.region");
functionName = System.getProperty("tests.lambda.processor.functionName");
role = System.getProperty("tests.lambda.processor.sts_role_arn");
// lambdaRegion = System.getProperty("tests.lambda.processor.region");
// functionName = System.getProperty("tests.lambda.processor.functionName");
// role = System.getProperty("tests.lambda.processor.sts_role_arn");
lambdaRegion = "us-west-2";
functionName = "lambdaNoReturn";
role = "arn:aws:iam::176893235612:role/osis-s3-opensearch-role";

pluginMetrics = mock(PluginMetrics.class);
pluginSetting = mock(PluginSetting.class);
Expand Down Expand Up @@ -373,4 +387,81 @@ private List<Record<Event>> createRecords(int numRecords) {
}
return records;
}

/*
 * For this test, set the Lambda function's concurrency limit to 1.
 *
 * Fires several invocations in parallel against a slow, concurrency-limited function so that
 * some of them are throttled, then checks that the SDK issued at least one HTTP request per
 * invocation (more when the custom retry condition triggers retries).
 */
@Test
void testTooManyRequestsExceptionWithCustomRetryCondition() {
    // The Lambda function used by this test looks like this:
    /*def lambda_handler(event, context):
        # Simulate a slow operation so that
        # if concurrency = 1, multiple parallel invocations
        # will result in TooManyRequestsException for the second+ invocation.
        time.sleep(10)
        # Return a simple success response
        return {
            "statusCode": 200,
            "body": "Hello from concurrency-limited Lambda!"
        }
    */

    // Parallel invocations to force concurrency=1 to throw TooManyRequestsException
    final int parallelInvocations = 10;

    // Custom retry policy: 3 retries, decided by the plugin's retry condition
    final RetryPolicy retryPolicy = RetryPolicy.builder()
            .numRetries(3)
            .retryCondition(new CustomLambdaRetryCondition())
            .build();

    // The counting wrapper is handed to the SDK as an instance, so the SDK does not close it;
    // both it and the Lambda client are closed here via try-with-resources.
    try (CountingHttpClient countingHttpClient = new CountingHttpClient(
                 NettyNioAsyncHttpClient.builder().build());
         LambdaAsyncClient client = LambdaAsyncClient.builder()
                 .overrideConfiguration(
                         ClientOverrideConfiguration.builder()
                                 .retryPolicy(retryPolicy)
                                 .build()
                 )
                 .region(Region.of(lambdaRegion))
                 .httpClient(countingHttpClient)
                 .build()) {

        final CompletableFuture<?>[] futures = new CompletableFuture<?>[parallelInvocations];
        for (int i = 0; i < parallelInvocations; i++) {
            final InvokeRequest request = InvokeRequest.builder()
                    .functionName(functionName)
                    .build();

            futures[i] = client.invoke(request);
        }

        // Wait for every invocation to finish. Throttled invocations complete exceptionally,
        // which would make a plain allOf(...).join() throw before anything is inspected, so the
        // aggregate failure is absorbed here and each future is examined individually below.
        CompletableFuture.allOf(futures).exceptionally(failure -> null).join();

        // Count how many invocations ultimately failed with TooManyRequestsException
        final long tooManyRequestsCount = Arrays.stream(futures)
                .filter(future -> failedWithTooManyRequests(future))
                .count();

        // Observe how many total network requests occurred (including SDK retries)
        final int totalRequests = countingHttpClient.getRequestCount();
        System.out.println("Total network requests (including retries): " + totalRequests);
        System.out.println("Invocations failed with TooManyRequestsException: " + tooManyRequestsCount);

        // The exact number varies with how many calls are throttled. In the worst case, where
        // every call is throttled on every attempt, it is parallelInvocations * (numRetries + 1).
        assertTrue(totalRequests >= parallelInvocations,
                "Should be at least one request per initial invocation, plus retries.");
    }
}

/**
 * Reports whether the given (already completed) future failed because of a
 * {@link TooManyRequestsException}.
 *
 * @param future a completed invocation future
 * @return {@code true} if joining the future raises a {@link CompletionException} whose cause is
 *         a {@link TooManyRequestsException}; {@code false} if it completed normally or failed
 *         with a different cause
 */
private static boolean failedWithTooManyRequests(final CompletableFuture<?> future) {
    try {
        future.join();
        return false; // no error => no throttling
    } catch (CompletionException e) {
        return e.getCause() instanceof TooManyRequestsException;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
Expand Down Expand Up @@ -48,13 +49,14 @@ private static ClientOverrideConfiguration createOverrideConfiguration(
.maxBackoffTime(clientOptions.getMaxBackoff())
.build();

final RetryPolicy retryPolicy = RetryPolicy.builder()
final RetryPolicy customRetryPolicy = RetryPolicy.builder()
.retryCondition(new CustomLambdaRetryCondition())
.numRetries(clientOptions.getMaxConnectionRetries())
.backoffStrategy(backoffStrategy)
.build();

return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.retryPolicy(customRetryPolicy)
.addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics))
.apiCallTimeout(clientOptions.getApiCallTimeout())
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import software.amazon.awssdk.core.retry.RetryPolicyContext;

/**
 * SDK {@link RetryCondition} that defers the retry decision to
 * {@code LambdaRetryStrategy.isRetryableException}.
 */
public class CustomLambdaRetryCondition implements RetryCondition {

    /**
     * Decides whether the failed attempt described by {@code context} should be retried.
     *
     * @param context retry context supplied by the SDK for the attempt that just finished
     * @return {@code true} only when the context carries an exception and
     *         {@code LambdaRetryStrategy} classifies it as retryable; {@code false} when there
     *         is no exception or it is not retryable
     */
    @Override
    public boolean shouldRetry(RetryPolicyContext context) {
        final Throwable failure = context.exception();
        return failure != null && LambdaRetryStrategy.isRetryableException(failure);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,7 @@ private LambdaRetryStrategy() {
)
);


public static boolean isRetryable(final InvokeResponse response) {
if(response == null) return false;
int statusCode = response.statusCode();
public static boolean isRetryable(final int statusCode) {
return TIMEOUT_ERRORS.contains(statusCode) || (statusCode >= 500 && statusCode < 600);
}

Expand Down Expand Up @@ -120,51 +117,5 @@ public static boolean isTimeoutError(final InvokeResponse response) {
return TIMEOUT_ERRORS.contains(response.statusCode());
}

/**
 * Re-invokes the Lambda function with the payload held by {@code buffer}, sleeping for the
 * configured base delay before each attempt, until an attempt succeeds, the response is no
 * longer retryable, or the configured maximum number of retries is reached.
 *
 * @param lambdaAsyncClient client used to invoke the function; each retry blocks on the result
 * @param buffer            buffer that supplies the request payload to resend
 * @param config            supplies function name, invocation type, max retries and base delay
 * @param previousResponse  response of the attempt that triggered the retry; may be {@code null}
 * @param LOG               logger used to report retry progress
 * @return the first successful response, or the last response received once retries are
 *         exhausted or the response stops being retryable
 * @throws RuntimeException if an attempt fails and the most recent response is not retryable,
 *                          or if the calling thread is interrupted while backing off
 */
public static InvokeResponse retryOrFail(
        final LambdaAsyncClient lambdaAsyncClient,
        final Buffer buffer,
        final LambdaCommonConfig config,
        final InvokeResponse previousResponse,
        final Logger LOG
) {
    final int maxRetries = config.getClientOptions().getMaxConnectionRetries();
    final Duration backoff = config.getClientOptions().getBaseDelay();

    int attempt = 1;
    InvokeResponse response = previousResponse;

    do {
        // Log the delay in milliseconds so the value matches the "ms" unit in the message.
        LOG.warn("Retrying Lambda invocation attempt {} of {} after {} ms backoff",
                attempt, maxRetries, backoff.toMillis());
        try {
            // Sleep for backoff
            Thread.sleep(backoff.toMillis());

            // Re-invoke Lambda with the same payload
            final InvokeRequest requestPayload = buffer.getRequestPayload(
                    config.getFunctionName(),
                    config.getInvocationType().getAwsLambdaValue()
            );
            // Do a synchronous call.
            response = lambdaAsyncClient.invoke(requestPayload).join();

            if (isSuccess(response)) {
                LOG.info("Retry attempt {} succeeded with status code {}", attempt, response.statusCode());
                return response;
            }
            // Route the unsuccessful response through the shared failure handling below,
            // carrying the status code so the log line is meaningful.
            throw new RuntimeException("Lambda invocation returned status code " + response.statusCode());
        } catch (InterruptedException e) {
            // Restore the interrupt flag and stop retrying: the thread was asked to stop.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while backing off before Lambda retry", e);
        } catch (Exception e) {
            LOG.error("Lambda invocation failed with exception {} in attempt {}", e.getMessage(), attempt);
            if (!isRetryable(response)) {
                throw new RuntimeException("Failed to invoke failed", e);
            }
        }
        attempt++;
    } while (attempt <= maxRetries && isRetryable(response));

    return response;
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.LambdaRetryStrategy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkBytes;
Expand Down Expand Up @@ -177,26 +176,14 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {
Buffer inputBuffer = entry.getKey();
try {
InvokeResponse response = future.join();

// If this response has a failure is retryable, do a direct retry
if (!isSuccess(response) && LambdaRetryStrategy.isRetryable(response)){
response = LambdaRetryStrategy.retryOrFail(
lambdaAsyncClient,
inputBuffer,
lambdaProcessorConfig,
response,
LOG
);
}
if(response == null || !isSuccess(response)) {
numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
return resultRecords;
}

Duration latency = inputBuffer.stopLatencyWatch();
lambdaLatencyMetric.record(latency.toMillis(), TimeUnit.MILLISECONDS);
requestPayloadMetric.record(inputBuffer.getPayloadRequestSize());
if (!isSuccess(response)) {
String errorMessage = String.format("Lambda invoke failed with status code %s error %s ",
response.statusCode(), response.payload().asUtf8String());
throw new RuntimeException(errorMessage);
}

resultRecords.addAll(convertLambdaResponseToEvent(inputBuffer, response));
numberOfRecordsSuccessCounter.increment(inputBuffer.getEventCount());
Expand All @@ -207,24 +194,10 @@ public Collection<Record<Event>> doExecute(Collection<Record<Event>> records) {

} catch (Exception e) {
LOG.error(NOISY, e.getMessage(), e);
InvokeResponse response = null;
if (LambdaRetryStrategy.isRetryableException(e)){
response = LambdaRetryStrategy.retryOrFail(
lambdaAsyncClient,
inputBuffer,
lambdaProcessorConfig,
null,
LOG
);
String errorMessage = String.format("Lambda invoke failed with status code %s error %s. Will be Retrying the request ",
response.statusCode(), response.payload().asUtf8String());
LOG.error(NOISY, e.getMessage(), e);
}
if(response == null || !isSuccess(response)) {
/* fall through */
numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
}
/* fall through */
numberOfRecordsFailedCounter.increment(inputBuffer.getEventCount());
numberOfRequestsFailedCounter.increment();
resultRecords.addAll(addFailureTags(inputBuffer.getRecords()));
}
}
return resultRecords;
Expand Down Expand Up @@ -294,4 +267,4 @@ public boolean isReadyForShutdown() {
public void shutdown() {
    // No-op: no cleanup is performed here.
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,40 @@
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.core.retry.RetryPolicyContext;
import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import software.amazon.awssdk.services.lambda.model.TooManyRequestsException;

import java.util.HashMap;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.spy;

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
class LambdaClientFactoryTest {

@Mock
Expand Down Expand Up @@ -75,4 +94,42 @@ void testCreateAsyncLambdaClientOverrideConfiguration() {
assertNotNull(overrideConfig.metricPublishers());
assertFalse(overrideConfig.metricPublishers().isEmpty());
}

/**
 * Verifies that a retry condition installed through the client's {@link RetryPolicy} is
 * consulted by the SDK when an invocation is attempted.
 */
@Test
void testCustomRetryConditionWorks_withSpyOrRetryCondition() {
    // Arrange
    final CustomLambdaRetryCondition customRetryCondition = new CustomLambdaRetryCondition();
    final RetryCondition spyRetryCondition = spy(customRetryCondition);

    final RetryPolicy retryPolicy = RetryPolicy.builder()
            // Even though we set numRetries=3,
            // the SDK may only call our custom condition once
            .numRetries(3)
            .retryCondition(spyRetryCondition)
            .build();

    final InvokeRequest request = InvokeRequest.builder()
            .functionName("test-function")
            .build();

    // Passing the HTTP client as a builder lets the SDK own it, so closing the Lambda client
    // at the end of the try-with-resources block also releases the underlying Netty resources.
    try (LambdaAsyncClient lambdaClient = LambdaAsyncClient.builder()
            .httpClientBuilder(NettyNioAsyncHttpClient.builder())
            .overrideConfiguration(ClientOverrideConfiguration.builder()
                    .retryPolicy(retryPolicy)
                    .build())
            .region(Region.US_EAST_1)
            .build()) {

        // Act
        try {
            final CompletableFuture<InvokeResponse> futureResponse = lambdaClient.invoke(request);
            futureResponse.join(); // Force completion
        } catch (Exception ignored) {
            // The outcome of the call is irrelevant to this test; only the interaction with
            // the retry condition is verified below.
        }
    }

    // Assert
    // The AWS SDK's internal 'OrRetryCondition' may only call our condition once
    verify(spyRetryCondition, atLeastOnce())
            .shouldRetry(any(RetryPolicyContext.class));
}

}
Loading

0 comments on commit 00f9851

Please sign in to comment.