Skip to content

Commit

Permalink
Add SSE-C support to GCS file system
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Sep 24, 2024
1 parent 4b0fbd8 commit b57552d
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 27 deletions.
4 changes: 2 additions & 2 deletions lib/trino-filesystem-gcs/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<excludes>
<exclude>**/TestGcsFileSystem.java</exclude>
<exclude>**/TestGcsFileSystem*.java</exclude>
</excludes>
</configuration>
</plugin>
Expand All @@ -211,7 +211,7 @@
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<includes>
<include>**/TestGcsFileSystem.java</include>
<include>**/TestGcsFileSystem*.java</include>
</includes>
</configuration>
</plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.Storage.BlobListOption;
import com.google.cloud.storage.StorageBatch;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.Futures;
Expand All @@ -33,6 +35,7 @@
import io.trino.filesystem.TrinoInputFile;
import io.trino.filesystem.TrinoOutputFile;
import io.trino.filesystem.UriLocation;
import io.trino.filesystem.encryption.EncryptionKey;

import java.io.IOException;
import java.net.URISyntaxException;
Expand All @@ -51,12 +54,16 @@
import static com.google.cloud.storage.Storage.BlobListOption.currentDirectory;
import static com.google.cloud.storage.Storage.BlobListOption.matchGlob;
import static com.google.cloud.storage.Storage.BlobListOption.pageSize;
import static com.google.cloud.storage.Storage.SignUrlOption.withExtHeaders;
import static com.google.cloud.storage.Storage.SignUrlOption.withV4Signature;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Iterables.partition;
import static io.airlift.concurrent.MoreFutures.getFutureValue;
import static io.trino.filesystem.gcs.GcsUtils.encodedKey;
import static io.trino.filesystem.gcs.GcsUtils.getBlob;
import static io.trino.filesystem.gcs.GcsUtils.handleGcsException;
import static io.trino.filesystem.gcs.GcsUtils.keySha256Checksum;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

Expand Down Expand Up @@ -85,31 +92,63 @@ public TrinoInputFile newInputFile(Location location)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.empty(), Optional.empty());
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.empty(), Optional.empty(), Optional.empty());
}

@Override
public TrinoInputFile newEncryptedInputFile(Location location, EncryptionKey key)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.empty(), Optional.empty(), Optional.of(key));
}

@Override
public TrinoInputFile newInputFile(Location location, long length)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.of(length), Optional.empty());
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.of(length), Optional.empty(), Optional.empty());
}

@Override
public TrinoInputFile newEncryptedInputFile(Location location, long length, EncryptionKey key)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.of(length), Optional.empty(), Optional.of(key));
}

@Override
public TrinoInputFile newInputFile(Location location, long length, Instant lastModified)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.of(length), Optional.of(lastModified));
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.of(length), Optional.of(lastModified), Optional.empty());
}

@Override
public TrinoInputFile newEncryptedInputFile(Location location, long length, Instant lastModified, EncryptionKey key)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsInputFile(gcsLocation, storage, readBlockSizeBytes, OptionalLong.of(length), Optional.of(lastModified), Optional.of(key));
}

@Override
public TrinoOutputFile newOutputFile(Location location)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsOutputFile(gcsLocation, storage, writeBlockSizeBytes);
return new GcsOutputFile(gcsLocation, storage, writeBlockSizeBytes, Optional.empty());
}

@Override
public TrinoOutputFile newEncryptedOutputFile(Location location, EncryptionKey key)
{
GcsLocation gcsLocation = new GcsLocation(location);
checkIsValidFile(gcsLocation);
return new GcsOutputFile(gcsLocation, storage, writeBlockSizeBytes, Optional.of(key));
}

@Override
Expand Down Expand Up @@ -206,7 +245,7 @@ private Page<Blob> getPage(GcsLocation location, BlobListOption... blobListOptio
if (!location.path().isEmpty()) {
optionsBuilder.add(BlobListOption.prefix(location.path()));
}
Arrays.stream(blobListOptions).forEach(optionsBuilder::add);
optionsBuilder.addAll(Arrays.asList(blobListOptions));
optionsBuilder.add(pageSize(this.pageSize));
return storage.list(location.bucket(), optionsBuilder.toArray(BlobListOption[]::new));
}
Expand Down Expand Up @@ -292,21 +331,56 @@ public Optional<Location> createTemporaryDirectory(Location targetPath, String t
@Override
public Optional<UriLocation> preSignedUri(Location location, Duration ttl)
throws IOException
{
return preSignedUri(location, ttl, Optional.empty());
}

@Override
public Optional<UriLocation> encryptedPreSignedUri(Location location, Duration ttl, EncryptionKey key)
throws IOException
{
return preSignedUri(location, ttl, Optional.of(key));
}

private Optional<UriLocation> preSignedUri(Location location, Duration ttl, Optional<EncryptionKey> key)
throws IOException
{
GcsLocation gcsLocation = new GcsLocation(location);
BlobInfo blobInfo = BlobInfo
.newBuilder(BlobId.of(gcsLocation.bucket(), gcsLocation.path()))
.build();

URL url = storage.signUrl(blobInfo, ttl.toMillis(), MILLISECONDS, withV4Signature());
Map<String, String> extHeaders = preSignedHeaders(key);
URL url = storage.signUrl(blobInfo, ttl.toMillis(), MILLISECONDS, withV4Signature(), withExtHeaders(extHeaders));
try {
return Optional.of(new UriLocation(url.toURI(), Map.of()));
return Optional.of(new UriLocation(url.toURI(), toMultiMap(extHeaders)));
}
catch (URISyntaxException e) {
throw new IOException("Error creating URI for location: " + location, e);
}
}

private static Map<String, String> preSignedHeaders(Optional<EncryptionKey> key)
{
if (key.isEmpty()) {
return ImmutableMap.of();
}

EncryptionKey encryption = key.get();
ImmutableMap.Builder<String, String> headers = ImmutableMap.builderWithExpectedSize(3);
headers.put("x-goog-encryption-algorithm", encryption.algorithm());
headers.put("x-goog-encryption-key", encodedKey(encryption));
headers.put("x-goog-encryption-key-sha256", keySha256Checksum(encryption));
return headers.buildOrThrow();
}

private Map<String, List<String>> toMultiMap(Map<String, String> extHeaders)
{
return extHeaders.entrySet().stream()
.map(entry -> Map.entry(entry.getKey(), ImmutableList.of(entry.getValue())))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
}

@SuppressWarnings("ResultOfObjectAllocationIgnored")
private static void validateGcsLocation(Location location)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@
import com.google.cloud.ReadChannel;
import com.google.cloud.storage.Blob;
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.Storage.BlobGetOption;
import io.trino.filesystem.TrinoInput;
import io.trino.filesystem.encryption.EncryptionKey;

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.OptionalLong;

import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.filesystem.gcs.GcsUtils.encodedKey;
import static io.trino.filesystem.gcs.GcsUtils.getBlobOrThrow;
import static io.trino.filesystem.gcs.GcsUtils.getReadChannel;
import static io.trino.filesystem.gcs.GcsUtils.handleGcsException;
Expand All @@ -37,15 +41,17 @@ final class GcsInput
private final Storage storage;
private final int readBlockSize;
private final OptionalLong length;
private final Optional<EncryptionKey> key;
private boolean closed;

public GcsInput(GcsLocation location, Storage storage, int readBlockSize, OptionalLong length)
public GcsInput(GcsLocation location, Storage storage, int readBlockSize, OptionalLong length, Optional<EncryptionKey> key)
{
this.location = requireNonNull(location, "location is null");
this.storage = requireNonNull(storage, "storage is null");
checkArgument(readBlockSize >= 0, "readBlockSize is negative");
this.readBlockSize = readBlockSize;
this.length = requireNonNull(length, "length is null");
this.key = requireNonNull(key, "key is null");
}

@Override
Expand All @@ -61,7 +67,7 @@ public void readFully(long position, byte[] buffer, int bufferOffset, int buffer
return;
}

try (ReadChannel readChannel = getReadChannel(getBlobOrThrow(storage, location), location, position, readBlockSize, length)) {
try (ReadChannel readChannel = getReadChannel(getBlobOrThrow(storage, location, blobGetOptions()), location, position, readBlockSize, length, key)) {
int readSize = readNBytes(readChannel, buffer, bufferOffset, bufferLength);
if (readSize != bufferLength) {
throw new EOFException("End of file reached before reading fully: " + location);
Expand All @@ -78,9 +84,9 @@ public int readTail(byte[] buffer, int bufferOffset, int bufferLength)
{
ensureOpen();
checkFromIndexSize(bufferOffset, bufferLength, buffer.length);
Blob blob = getBlobOrThrow(storage, location);
Blob blob = getBlobOrThrow(storage, location, blobGetOptions());
long offset = Math.max(0, length.orElse(blob.getSize()) - bufferLength);
try (ReadChannel readChannel = getReadChannel(blob, location, offset, readBlockSize, length)) {
try (ReadChannel readChannel = getReadChannel(blob, location, offset, readBlockSize, length, key)) {
return readNBytes(readChannel, buffer, bufferOffset, bufferLength);
}
catch (RuntimeException e) {
Expand Down Expand Up @@ -122,4 +128,11 @@ private int readNBytes(ReadChannel readChannel, byte[] buffer, int bufferOffset,
}
return readSize;
}

private BlobGetOption[] blobGetOptions()
{
return key
.map(encryption -> new BlobGetOption[] {BlobGetOption.decryptionKey(encodedKey(encryption))})
.orElseGet(() -> new BlobGetOption[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import io.trino.filesystem.TrinoInput;
import io.trino.filesystem.TrinoInputFile;
import io.trino.filesystem.TrinoInputStream;
import io.trino.filesystem.encryption.EncryptionKey;

import java.io.IOException;
import java.time.Instant;
import java.util.Optional;
import java.util.OptionalLong;

import static io.trino.filesystem.gcs.GcsUtils.encodedKey;
import static io.trino.filesystem.gcs.GcsUtils.getBlob;
import static io.trino.filesystem.gcs.GcsUtils.getBlobOrThrow;
import static io.trino.filesystem.gcs.GcsUtils.handleGcsException;
Expand All @@ -37,33 +39,35 @@ public class GcsInputFile
private final Storage storage;
private final int readBlockSize;
private final OptionalLong predeclaredLength;
private final Optional<EncryptionKey> key;
private OptionalLong length;
private Optional<Instant> lastModified;

public GcsInputFile(GcsLocation location, Storage storage, int readBockSize, OptionalLong predeclaredLength, Optional<Instant> lastModified)
public GcsInputFile(GcsLocation location, Storage storage, int readBockSize, OptionalLong predeclaredLength, Optional<Instant> lastModified, Optional<EncryptionKey> key)
{
this.location = requireNonNull(location, "location is null");
this.storage = requireNonNull(storage, "storage is null");
this.readBlockSize = readBockSize;
this.predeclaredLength = requireNonNull(predeclaredLength, "length is null");
this.length = OptionalLong.empty();
this.lastModified = requireNonNull(lastModified, "lastModified is null");
this.key = requireNonNull(key, "key is null");
}

@Override
public TrinoInput newInput()
throws IOException
{
// Note: Only pass predeclared length, to keep the contract of TrinoFileSystem.newInputFile
return new GcsInput(location, storage, readBlockSize, predeclaredLength);
return new GcsInput(location, storage, readBlockSize, predeclaredLength, key);
}

@Override
public TrinoInputStream newStream()
throws IOException
{
Blob blob = getBlobOrThrow(storage, location);
return new GcsInputStream(location, blob, readBlockSize, predeclaredLength);
Blob blob = getBlobOrThrow(storage, location, blobGetOptions());
return new GcsInputStream(location, blob, readBlockSize, predeclaredLength, key);
}

@Override
Expand Down Expand Up @@ -93,7 +97,7 @@ public Instant lastModified()
public boolean exists()
throws IOException
{
Optional<Blob> blob = getBlob(storage, location);
Optional<Blob> blob = getBlob(storage, location, blobGetOptions());
return blob.isPresent() && blob.get().exists();
}

Expand All @@ -106,7 +110,7 @@ public Location location()
private void loadProperties()
throws IOException
{
Blob blob = getBlobOrThrow(storage, location);
Blob blob = getBlobOrThrow(storage, location, blobGetOptions());
try {
length = OptionalLong.of(blob.getSize());
if (lastModified.isEmpty()) {
Expand All @@ -117,4 +121,11 @@ private void loadProperties()
throw handleGcsException(e, "fetching properties for file", location);
}
}

private Storage.BlobGetOption[] blobGetOptions()
{
return key
.map(encryption -> new Storage.BlobGetOption[]{Storage.BlobGetOption.decryptionKey(encodedKey(encryption))})
.orElseGet(() -> new Storage.BlobGetOption[0]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
import com.google.cloud.storage.Blob;
import com.google.common.primitives.Ints;
import io.trino.filesystem.TrinoInputStream;
import io.trino.filesystem.encryption.EncryptionKey;

import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Optional;
import java.util.OptionalLong;

import static io.trino.filesystem.gcs.GcsUtils.getReadChannel;
Expand All @@ -37,21 +39,23 @@ public class GcsInputStream
private final int readBlockSizeBytes;
private final long fileSize;
private final OptionalLong predeclaredLength;
private final Optional<EncryptionKey> key;
private ReadChannel readChannel;
// Used for read(). Similar to sun.nio.ch.ChannelInputStream
private final ByteBuffer readBuffer = ByteBuffer.allocate(1);
private long currentPosition;
private long nextPosition;
private boolean closed;

public GcsInputStream(GcsLocation location, Blob blob, int readBlockSizeBytes, OptionalLong predeclaredLength)
public GcsInputStream(GcsLocation location, Blob blob, int readBlockSizeBytes, OptionalLong predeclaredLength, Optional<EncryptionKey> key)
throws IOException
{
this.location = requireNonNull(location, "location is null");
this.blob = requireNonNull(blob, "blob is null");
this.readBlockSizeBytes = readBlockSizeBytes;
this.predeclaredLength = requireNonNull(predeclaredLength, "predeclaredLength is null");
this.fileSize = predeclaredLength.orElse(blob.getSize());
this.key = requireNonNull(key, "key is null");
openStream();
}

Expand Down Expand Up @@ -182,7 +186,7 @@ private void openStream()
throws IOException
{
try {
this.readChannel = getReadChannel(blob, location, 0L, readBlockSizeBytes, predeclaredLength);
this.readChannel = getReadChannel(blob, location, 0L, readBlockSizeBytes, predeclaredLength, key);
}
catch (RuntimeException e) {
throw handleGcsException(e, "reading file", location);
Expand Down
Loading

0 comments on commit b57552d

Please sign in to comment.