Skip to content

Commit

Permalink
Ensure S3InputStream is lazy
Browse files: browse the repository at this point in the history
  • Loading branch information
cyberdelia committed Nov 7, 2020
1 parent e1895a0 commit a558ad2
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 29 deletions.
58 changes: 35 additions & 23 deletions src/main/kotlin/com/lapanthere/signals/S3InputStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import software.amazon.awssdk.services.s3.model.GetObjectRequest
import software.amazon.awssdk.services.s3.model.HeadObjectRequest
import java.io.InputStream
import java.io.SequenceInputStream
import java.time.Instant
import java.util.Enumeration

internal val AVAILABLE_PROCESSORS = Runtime.getRuntime().availableProcessors()
Expand All @@ -35,14 +36,13 @@ public class S3InputStream(
mutator: (GetObjectRequest.Builder) -> Unit = {}
) : InputStream() {
private val scope = CoroutineScope(Dispatchers.IO)
private val parts = byteRange(
s3.headObject(
HeadObjectRequest.builder()
.bucket(bucket)
.key(key)
.build()
).get().contentLength()
)
// Response of a HeadObject request for the target bucket/key. The trailing
// get() waits for the returned future, so this lookup completes while the
// instance is being constructed, before any chunk is requested.
private val s3Object = s3.headObject(
HeadObjectRequest.builder()
.bucket(bucket)
.key(key)
.build()
).get()
// Ranges derived from the reported content length. byteRange is defined outside
// this view; its results are destructured as (begin, end) pairs when the
// per-chunk downloads are set up.
private val parts = byteRange(s3Object.contentLength())
private val streams = parts.mapIndexed { i, (begin, end) ->
scope.async(CoroutineName("chunk-${i + 1}"), CoroutineStart.LAZY) {
s3.getObject(
Expand All @@ -56,26 +56,38 @@ public class S3InputStream(
).await().asInputStream()
}
}.toMutableList()
private val buffer = SequenceInputStream(
object : Enumeration<InputStream> {
private val iterator = streams.iterator()
private val buffer: SequenceInputStream by lazy {
SequenceInputStream(
object : Enumeration<InputStream> {
private val iterator = streams.iterator()

override fun hasMoreElements(): Boolean {
// Starts downloading the next chunks ahead.
streams.take(parallelism).forEach { it.start() }
return iterator.hasNext()
}
override fun hasMoreElements(): Boolean {
// Starts downloading the next chunks ahead.
streams.take(parallelism).forEach { it.start() }
return iterator.hasNext()
}

override fun nextElement(): InputStream = runBlocking {
iterator.use { it.await() }
override fun nextElement(): InputStream = runBlocking {
iterator.use { it.await() }
}
}
}
)

override fun read(): Int {
return buffer.read()
)
}

// Read-only object metadata. Each property is copied, at initialization time,
// from the accessor of the same name on the HeadObject response held in
// s3Object; nothing here is refreshed afterwards.
public val eTag: String? = s3Object.eTag()
public val contentLength: Long? = s3Object.contentLength()
public val lastModified: Instant? = s3Object.lastModified()
public val metadata: Map<String, String> = s3Object.metadata()
public val contentType: String? = s3Object.contentType()
public val contentEncoding: String? = s3Object.contentEncoding()
public val contentDisposition: String? = s3Object.contentDisposition()
public val contentLanguage: String? = s3Object.contentLanguage()
public val versionId: String? = s3Object.versionId()
public val cacheControl: String? = s3Object.cacheControl()
public val expires: Instant? = s3Object.expires()

/**
 * Reads one byte by delegating to [buffer] and returns whatever it returns.
 *
 * `buffer` is declared with `by lazy`, so the first access made here is what
 * initializes it.
 */
override fun read(): Int = buffer.read()

/**
 * Closes this stream by closing the underlying [buffer].
 *
 * NOTE(review): `buffer` is a lazy delegate, so closing a stream that was never
 * read appears to initialize it first. Confirm that this is intended.
 */
override fun close(): Unit = buffer.close()
Expand Down
24 changes: 21 additions & 3 deletions src/test/kotlin/com/lapanthere/signals/S3InputStreamTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import java.util.concurrent.CompletableFuture
import kotlin.test.Test
import kotlin.test.assertFailsWith

class S3InputStreamTest {
internal class S3InputStreamTest {
private val bucket = "bucket"
private val key = "key"
private val s3: S3AsyncClient = mockk {
Expand All @@ -30,6 +30,8 @@ class S3InputStreamTest {
} returns CompletableFuture.completedFuture(
HeadObjectResponse.builder()
.contentLength(6_291_456)
.contentEncoding("application/json")
.eTag("d41d8cd98f00b204e9800998ecf8427e-2")
.build()
)
every {
Expand Down Expand Up @@ -59,7 +61,7 @@ class S3InputStreamTest {
}

@Test
fun testDownload() {
fun `download a file`() {
ByteArrayOutputStream().use { target ->
S3InputStream(bucket = bucket, key = key, s3 = s3).use { stream ->
stream.copyTo(target)
Expand Down Expand Up @@ -96,7 +98,23 @@ class S3InputStreamTest {
}

@Test
fun testFailure() {
fun `downloading starts when reading starts`() {
    // Only construction is exercised here; nothing is ever read from the stream.
    S3InputStream(bucket = bucket, key = key, s3 = s3)

    val expectedHead = HeadObjectRequest.builder().bucket(bucket).key(key).build()

    // Construction performs exactly one metadata lookup ...
    verify(exactly = 1) { s3.headObject(expectedHead) }
    // ... and has not requested any chunk yet.
    verify(exactly = 0) {
        s3.getObject(any<GetObjectRequest>(), any<ByteArrayAsyncResponseTransformer<GetObjectResponse>>())
    }
}

@Test
fun `handle exception on failure`() {
every {
s3.getObject(
GetObjectRequest.builder()
Expand Down
6 changes: 3 additions & 3 deletions src/test/kotlin/com/lapanthere/signals/S3OutputStreamTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.util.concurrent.CompletableFuture
import kotlin.test.Test
import kotlin.test.assertFailsWith

class S3OutputStreamTest {
internal class S3OutputStreamTest {
private val uploadID = "upload-id"
private val bucket = "bucket"
private val key = "key"
Expand Down Expand Up @@ -96,7 +96,7 @@ class S3OutputStreamTest {
}

@Test
fun testUpload() {
fun `uploads a file`() {
ByteArrayInputStream(ByteArray(32)).use { target ->
S3OutputStream(bucket = bucket, key = key, s3 = s3).use { stream ->
target.copyTo(stream)
Expand Down Expand Up @@ -154,7 +154,7 @@ class S3OutputStreamTest {
}

@Test
fun testFailure() {
fun `handle exception on failure`() {
every {
s3.uploadPart(
UploadPartRequest.builder()
Expand Down

0 comments on commit a558ad2

Please sign in to comment.