Skip to content

Commit

Permalink
Added retry to Block prefetch (#229)
Browse files Browse the repository at this point in the history
  • Loading branch information
ozkoca authored Feb 25, 2025
1 parent 0cb2aad commit 3d522c2
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.io.Closeable;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;
import lombok.Getter;
import lombok.NonNull;
import org.slf4j.Logger;
Expand Down Expand Up @@ -60,6 +61,7 @@ public class Block implements Closeable {
private static final String OPERATION_BLOCK_GET_JOIN = "block.get.join";

private static final int MAX_RETRIES = 20;
private static final long TIMEOUT_MILLIS = 120_000;

private static final Logger LOG = LoggerFactory.getLogger(Block.class);

Expand All @@ -81,7 +83,8 @@ public Block(
long start,
long end,
long generation,
@NonNull ReadMode readMode) {
@NonNull ReadMode readMode)
throws IOException {

this(objectKey, objectClient, telemetry, start, end, generation, readMode, null);
}
Expand All @@ -106,7 +109,8 @@ public Block(
long end,
long generation,
@NonNull ReadMode readMode,
StreamContext streamContext) {
StreamContext streamContext)
throws IOException {

Preconditions.checkArgument(
0 <= generation, "`generation` must be non-negative; was: %s", generation);
Expand All @@ -130,28 +134,57 @@ public Block(
}

/**
 * Builds the GET request for this block's range and assigns {@code source} (the pending object
 * fetch) and {@code data} (the fetched content converted to a byte array).
 *
 * <p>NOTE(review): this text looks like a rendered diff whose add/remove markers were lost, so
 * two versions of the method are interleaved below. The version declaring {@code throws
 * IOException} wraps the work in a loop bounded by {@code MAX_RETRIES}; the other version makes
 * a single attempt. Confirm against the commit which lines survive.
 *
 * @throws IOException in the retrying version, once {@code MAX_RETRIES} attempts have each
 *     ended in a {@link RuntimeException}; the last such exception is attached as the cause
 */
private void generateSourceAndData() {
GetRequest.GetRequestBuilder getRequestBuilder =
GetRequest.builder()
.s3Uri(this.objectKey.getS3URI())
.range(this.range)
.etag(this.objectKey.getEtag())
.referrer(referrer);
// NOTE(review): the lines above appear to be the removed single-attempt opening; the retrying
// version starts at the next signature — confirm.
private void generateSourceAndData() throws IOException {
int retries = 0;
// Each pass through the loop is one attempt; a successful attempt returns from inside the try.
while (retries < MAX_RETRIES) {
try {
// The request is rebuilt on every attempt from the same key, range, etag and referrer.
GetRequest getRequest =
GetRequest.builder()
.s3Uri(this.objectKey.getS3URI())
.range(this.range)
.etag(this.objectKey.getEtag())
.referrer(referrer)
.build();

// Issue the GET through objectClient and record it under OPERATION_BLOCK_GET_ASYNC with the
// uri, etag, range and generation attached as attributes.
this.source =
this.telemetry.measureCritical(
() ->
Operation.builder()
.name(OPERATION_BLOCK_GET_ASYNC)
.attribute(StreamAttributes.uri(this.objectKey.getS3URI()))
.attribute(StreamAttributes.etag(this.objectKey.getEtag()))
.attribute(StreamAttributes.range(this.range))
.attribute(StreamAttributes.generation(generation))
.build(),
objectClient.getObject(getRequest, streamContext));

// Convert the fetched content to a byte array with a TIMEOUT_MILLIS read timeout. The checked
// IOException / TimeoutException from the conversion are rethrown wrapped in a RuntimeException.
this.data =
this.source.thenApply(
objectContent -> {
try {
return StreamUtils.toByteArray(objectContent, TIMEOUT_MILLIS);
} catch (IOException | TimeoutException e) {
throw new RuntimeException(
"Error while converting InputStream to byte array", e);
}
});

// NOTE(review): the block from here down to the return statement references getRequestBuilder
// from the single-attempt version and looks like removed pre-change code — confirm.
GetRequest getRequest = getRequestBuilder.build();
this.source =
this.telemetry.measureCritical(
() ->
Operation.builder()
.name(OPERATION_BLOCK_GET_ASYNC)
.attribute(StreamAttributes.uri(this.objectKey.getS3URI()))
.attribute(StreamAttributes.etag(this.objectKey.getEtag()))
.attribute(StreamAttributes.range(this.range))
.attribute(StreamAttributes.generation(generation))
.build(),
objectClient.getObject(getRequest, streamContext));
return; // Successfully generated source and data, exit loop
// Only exceptions thrown synchronously inside the try block reach this handler.
// NOTE(review): if source is a CompletableFuture, a failure raised inside the thenApply callback
// would complete data exceptionally instead of being thrown here, so a read timeout would not
// trigger a retry — confirm.
} catch (RuntimeException e) {
retries++;
// Only the exception message is logged per attempt; the next attempt starts immediately, with
// no delay between attempts.
LOG.warn(
"Retry {}/{} - Failed to fetch block data due to: {}",
retries,
MAX_RETRIES,
e.getMessage());

// NOTE(review): the one-argument toByteArray call below looks like a removed pre-change line —
// confirm.
this.data = this.source.thenApply(StreamUtils::toByteArray);
// Out of attempts: surface the last failure to the caller as a checked IOException.
if (retries >= MAX_RETRIES) {
LOG.error("Max retries reached. Unable to fetch block data.");
throw new IOException("Failed to fetch block data after retries", e);
}
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,19 @@ public synchronized void makeRangeAvailable(long pos, long len, ReadMode readMod
List<Range> missingRanges =
ioPlanner.planRead(pos, effectiveEndFinal, getLastObjectByte());
List<Range> splits = rangeOptimiser.splitRanges(missingRanges);
splits.forEach(
r -> {
Block block =
new Block(
objectKey,
objectClient,
telemetry,
r.getStart(),
r.getEnd(),
generation,
readMode,
streamContext);
blockStore.add(block);
});
for (Range r : splits) {
Block block =
new Block(
objectKey,
objectClient,
telemetry,
r.getStart(),
r.getEnd(),
generation,
readMode,
streamContext);
blockStore.add(block);
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,58 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.s3.analyticsaccelerator.request.ObjectContent;

/** Utility class for stream operations. */
public class StreamUtils {

private static final int BUFFER_SIZE = 8 * ONE_KB;
private static final Logger LOG = LoggerFactory.getLogger(StreamUtils.class);

/**
* Convert an InputStream from the underlying object to a byte array.
*
* @param objectContent the part of the object
* @param timeoutMs read timeout in milliseconds
* @return a byte array
*/
public static byte[] toByteArray(ObjectContent objectContent) {
public static byte[] toByteArray(ObjectContent objectContent, long timeoutMs)
throws IOException, TimeoutException {
InputStream inStream = objectContent.getStream();
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
byte[] buffer = new byte[BUFFER_SIZE];

ExecutorService executorService = Executors.newSingleThreadExecutor();
Future<Void> future =
executorService.submit(
() -> {
try {
int numBytesRead;
LOG.info("Starting to read from InputStream");
while ((numBytesRead = inStream.read(buffer, 0, buffer.length)) != -1) {
outStream.write(buffer, 0, numBytesRead);
}
LOG.info("Successfully read from InputStream");
return null;
} finally {
inStream.close();
}
});

try {
int numBytesRead;
while ((numBytesRead = inStream.read(buffer, 0, buffer.length)) != -1) {
outStream.write(buffer, 0, numBytesRead);
}
inStream.close();
} catch (IOException e) {
throw new RuntimeException(e);
future.get(timeoutMs, TimeUnit.MILLISECONDS);

} catch (TimeoutException e) {
future.cancel(true);
LOG.warn("Reading from InputStream has timed out.");
throw new TimeoutException("Read operation timed out");
} catch (Exception e) {
throw new IOException("Error reading stream", e);
} finally {
executorService.shutdown();
}

return outStream.toByteArray();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void testMakeRangeAvailableThrowsExceptionWhenEtagChanges() throws IOException {
.thenThrow(S3Exception.builder().message("PreconditionFailed").statusCode(412).build());

assertThrows(
S3Exception.class,
IOException.class,
() -> blockManager.makePositionAvailable(readAheadBytes + 1, ReadMode.SYNC));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.nio.charset.StandardCharsets;
import java.util.Optional;
import java.util.OptionalLong;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Test;
import software.amazon.s3.analyticsaccelerator.TestTelemetry;
import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata;
Expand All @@ -41,6 +42,7 @@ public class BlockStoreTest {
private static final ObjectKey objectKey = ObjectKey.builder().s3URI(TEST_URI).etag(ETAG).build();
private static final int OBJECT_SIZE = 100;

@SneakyThrows
@Test
public void test__blockStore__getBlockAfterAddBlock() {
// Given: empty BlockStore
Expand Down Expand Up @@ -90,6 +92,7 @@ public void test__blockStore__findNextMissingByteCorrect() throws IOException {
assertEquals(OptionalLong.empty(), blockStore.findNextMissingByte(14));
}

@SneakyThrows
@Test
public void test__blockStore__findNextAvailableByteCorrect() {
// Given: BlockStore with blocks (2,3), (5,10), (12,15)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
package software.amazon.s3.analyticsaccelerator.io.physical.data;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Test;
import software.amazon.s3.analyticsaccelerator.TestTelemetry;
import software.amazon.s3.analyticsaccelerator.request.ObjectClient;
Expand Down Expand Up @@ -182,6 +184,7 @@ void testBoundaries() {
null));
}

@SneakyThrows
@Test
void testReadBoundaries() {
final String TEST_DATA = "test-data";
Expand All @@ -203,6 +206,7 @@ void testReadBoundaries() {
assertThrows(IllegalArgumentException.class, () -> block.read(b, 10, 3, 1));
}

@SneakyThrows
@Test
void testContains() {
final String TEST_DATA = "test-data";
Expand All @@ -220,6 +224,7 @@ void testContains() {
assertFalse(block.contains(TEST_DATA.length() + 1));
}

@SneakyThrows
@Test
void testContainsBoundaries() {
final String TEST_DATA = "test-data";
Expand Down Expand Up @@ -255,6 +260,7 @@ void testReadTimeoutAndRetry() throws IOException {
assertThrows(IOException.class, () -> block.read(4));
}

@SneakyThrows
@Test
void testClose() {
final String TEST_DATA = "test-data";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,40 @@
package software.amazon.s3.analyticsaccelerator.util;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.*;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeoutException;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Test;
import software.amazon.s3.analyticsaccelerator.request.ObjectContent;

public class StreamUtilsTest {

private static final long TIMEOUT_MILLIS = 1_000;

/**
 * Checks that converting object content backed by an empty stream yields a zero-length byte array
 * that decodes to the empty string.
 *
 * <p>NOTE(review): two {@code toByteArray} call lines appear below, the one-argument form and the
 * form that also takes a timeout. They look like the before/after lines of a diff; confirm which
 * one the current file keeps.
 */
@SneakyThrows
@Test
public void testToByteArrayWorksWithEmptyStream() {
// Given: objectContent with an empty stream
ObjectContent objectContent =
ObjectContent.builder().stream(new ByteArrayInputStream(new byte[0])).build();

// When: toByteArray is called
byte[] buf = StreamUtils.toByteArray(objectContent);
byte[] buf = StreamUtils.toByteArray(objectContent, TIMEOUT_MILLIS);

// Then: returned byte array is empty
String content = new String(buf, StandardCharsets.UTF_8);
assertEquals(0, buf.length);
assertEquals("", content);
}

@SneakyThrows
@Test
public void testToByteArrayConvertsCorrectly() {
// Given: objectContent with "Hello World" in it
Expand All @@ -48,9 +58,33 @@ public void testToByteArrayConvertsCorrectly() {
ObjectContent objectContent = ObjectContent.builder().stream(inputStream).build();

// When: toByteArray is called
byte[] buf = StreamUtils.toByteArray(objectContent);
byte[] buf = StreamUtils.toByteArray(objectContent, TIMEOUT_MILLIS);

// Then: 'Hello World' is returned
assertEquals("Hello World", new String(buf, StandardCharsets.UTF_8));
}

/**
 * Verifies that {@code StreamUtils.toByteArray} fails with a {@link TimeoutException} when a read
 * on the underlying stream blocks for longer than the timeout passed to it.
 */
@Test
void toByteArrayShouldThrowTimeoutExceptionWhenStreamReadTakesTooLong() throws Exception {
  // Given: a stream whose read sleeps past the timeout before reporting end of stream
  InputStream stalledStream = mock(InputStream.class);
  when(stalledStream.read(any(byte[].class), anyInt(), anyInt()))
      .thenAnswer(
          ignored -> {
            Thread.sleep(TIMEOUT_MILLIS + 100);
            return -1;
          });

  // And: object content that hands out that stream
  ObjectContent content = mock(ObjectContent.class);
  when(content.getStream()).thenReturn(stalledStream);

  // When / Then: the conversion times out
  assertThrows(TimeoutException.class, () -> StreamUtils.toByteArray(content, TIMEOUT_MILLIS));

  // And: the stream was requested from the content
  verify(content).getStream();
}
}

0 comments on commit 3d522c2

Please sign in to comment.