diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java index ba469054a..f2d7abddb 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/DeviceController.java @@ -17,11 +17,13 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.responses.ApiResponse; import java.time.Duration; +import java.time.Instant; import java.util.EnumMap; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.Nullable; import javax.validation.Valid; @@ -58,7 +60,9 @@ import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator; import org.whispersystems.textsecuregcm.entities.ProvisioningMessage; +import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest; +import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.RateLimiters; import org.whispersystems.textsecuregcm.metrics.MetricsUtil; @@ -318,9 +322,9 @@ public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) Bas Waits for a new device to be linked to an account and returns basic information about the new device when available. """) - @ApiResponse(responseCode = "200", description = "The specified was linked to an account", + @ApiResponse(responseCode = "200", description = "A device was linked to an account using the token associated with the given token identifier", content = @Content(schema = @Schema(implementation = DeviceInfo.class))) - @ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed") + @ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed; clients may repeat the call to continue waiting") @ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid") @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") public CompletableFuture waitForLinkedDevice( @@ -432,4 +436,66 @@ private static boolean isCapabilityDowngrade(Account account, DeviceCapabilities isDowngrade |= account.isVersionedExpirationTimerSupported() && !capabilities.versionedExpirationTimer(); return isDowngrade; } + + @PUT + @Consumes(MediaType.APPLICATION_JSON) + @Produces(MediaType.APPLICATION_JSON) + @Path("/transfer_archive") + @Operation( + summary = "Signals that a transfer archive has been uploaded for a specific linked device", + description = """ + Signals that a transfer archive has been uploaded for a specific linked device. Devices waiting via the "wait + for transfer archive" endpoint will be notified that the new archive is available. + """) + @ApiResponse(responseCode = "204", description = "Success") + @ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid") + @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") + public CompletionStage recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + @NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) { + + return rateLimiters.getUploadTransferArchiveLimiter().validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI)) + .thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(), + transferArchiveUploadedRequest.destinationDeviceId(), + Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()), + transferArchiveUploadedRequest.transferArchive())); + } + + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/transfer_archive") + @Operation(summary = "Wait for a new transfer archive to be uploaded", + description = """ + Waits for a new transfer archive to be uploaded for the authenticated device and returns the location of the + archive when available. + """) + @ApiResponse(responseCode = "200", description = "A new transfer archive was uploaded for the authenticated device", + content = @Content(schema = @Schema(implementation = RemoteAttachment.class))) + @ApiResponse(responseCode = "204", description = "No transfer archive was uploaded before the call completed; clients may repeat the call to continue waiting") + @ApiResponse(responseCode = "400", description = "The given timeout was invalid") + @ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay") + public CompletionStage waitForTransferArchive(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice, + + @QueryParam("timeout") + @DefaultValue("30") + @Min(1) + @Max(3600) + @Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED, + minimum = "1", + maximum = "3600", + description = """ + The amount of time (in seconds) to wait for a response. If a transfer archive for the authenticated + device is not available within the given amount of time, this endpoint will return a status of HTTP/204. + """) final int timeoutSeconds) { + + final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) + + ":" + authenticatedDevice.getAuthenticatedDevice().getId(); + + return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey) + .thenCompose(ignored -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(), + authenticatedDevice.getAuthenticatedDevice(), + Duration.ofSeconds(timeoutSeconds))) + .thenApply(maybeTransferArchive -> maybeTransferArchive + .map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build()) + .orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build())); + } } diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java b/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java new file mode 100644 index 000000000..60f27be89 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/entities/TransferArchiveUploadedRequest.java @@ -0,0 +1,27 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.entities; + +import io.swagger.v3.oas.annotations.media.Schema; +import org.whispersystems.textsecuregcm.storage.Device; +import javax.validation.Valid; +import javax.validation.constraints.Max; +import javax.validation.constraints.Min; +import javax.validation.constraints.Positive; + +public record TransferArchiveUploadedRequest(@Min(1) + @Max(Device.MAXIMUM_DEVICE_ID) + @Schema(description = "The ID of the device for which the transfer archive has been prepared") + byte destinationDeviceId, + + @Positive + @Schema(description = "The timestamp, in milliseconds since the epoch, at which the destination device was created") + long destinationDeviceCreated, + + @Schema(description = "The location of the transfer archive") + @Valid + RemoteAttachment transferArchive) { +} diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java index 02bb63cb5..aa07085e6 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/limits/RateLimiters.java @@ -51,6 +51,8 @@ public enum For implements RateLimiterDescriptor { KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))), WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), + UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))), + WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))), ; private final String id; @@ -210,4 +212,12 @@ public RateLimiter getStoriesLimiter() { public RateLimiter getWaitForLinkedDeviceLimiter() { return forDescriptor(For.WAIT_FOR_LINKED_DEVICE); } + + public RateLimiter getUploadTransferArchiveLimiter() { + return forDescriptor(For.UPLOAD_TRANSFER_ARCHIVE); + } + + public RateLimiter getWaitForTransferArchiveLimiter() { + return forDescriptor(For.WAIT_FOR_TRANSFER_ARCHIVE); + } } diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java index 319e64df7..adf75bed0 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/DeviceControllerTest.java @@ -20,7 +20,6 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.amazonaws.util.Base64; import com.google.common.net.HttpHeaders; import io.dropwizard.auth.AuthValueFactoryProvider; import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; @@ -28,6 +27,9 @@ import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands; import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -64,7 +66,9 @@ import org.whispersystems.textsecuregcm.entities.GcmRegistrationId; import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey; import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest; +import org.whispersystems.textsecuregcm.entities.RemoteAttachment; import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest; +import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest; import org.whispersystems.textsecuregcm.identity.IdentityType; import org.whispersystems.textsecuregcm.limits.RateLimiter; import org.whispersystems.textsecuregcm.limits.RateLimiters; @@ -132,6 +136,8 @@ void setup() { when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter); when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getUploadTransferArchiveLimiter()).thenReturn(rateLimiter); + when(rateLimiters.getWaitForTransferArchiveLimiter()).thenReturn(rateLimiter); when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID); @@ -957,7 +963,7 @@ void waitForLinkedDevice() { System.currentTimeMillis(), System.currentTimeMillis()); - final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo))); @@ -980,7 +986,7 @@ void waitForLinkedDevice() { @Test void waitForLinkedDeviceNoDeviceLinked() { - final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.completedFuture(Optional.empty())); @@ -997,7 +1003,7 @@ void waitForLinkedDeviceNoDeviceLinked() { @Test void waitForLinkedDeviceBadTokenIdentifier() { - final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any())) .thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException())); @@ -1015,7 +1021,7 @@ void waitForLinkedDeviceBadTokenIdentifier() { @ParameterizedTest @MethodSource void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) { - final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); try (final Response response = resources.getJerseyTest() .target("/v1/devices/wait_for_linked_device/" + tokenIdentifier) @@ -1052,7 +1058,7 @@ private static List waitForLinkedDeviceBadTokenIdentifierLength() { @Test void waitForLinkedDeviceRateLimited() throws RateLimitExceededException { - final String tokenIdentifier = Base64.encodeAsString(new byte[32]); + final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]); doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID); @@ -1065,4 +1071,158 @@ void waitForLinkedDeviceRateLimited() throws RateLimitExceededException { assertEquals(429, response.getStatus()); } } + + @Test + void recordTransferArchiveUploaded() { + final byte deviceId = Device.PRIMARY_ID + 1; + final Instant deviceCreated = Instant.now().truncatedTo(ChronoUnit.MILLIS); + final RemoteAttachment transferArchive = + new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); + + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive)) + .thenReturn(CompletableFuture.completedFuture(null)); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, deviceCreated.toEpochMilli(), transferArchive), + MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(204, response.getStatus()); + + verify(accountsManager) + .recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive); + } + } + + @ParameterizedTest + @MethodSource + void recordTransferArchiveUploadedBadRequest(final TransferArchiveUploadedRequest request) { + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null)); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(422, response.getStatus()); + + verify(accountsManager, never()) + .recordTransferArchiveUpload(any(), anyByte(), any(), any()); + } + } + + @SuppressWarnings("DataFlowIssue") + private static List recordTransferArchiveUploadedBadRequest() { + final RemoteAttachment validTransferArchive = + new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8))); + + return List.of( + // Invalid device ID + new TransferArchiveUploadedRequest((byte) -1, System.currentTimeMillis(), validTransferArchive), + + // Invalid "created at" timestamp + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, -1, validTransferArchive), + + // Missing CDN number + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(), + new RemoteAttachment(null, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8)))), + + // Bad attachment key + new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(), + new RemoteAttachment(3, "This is not a valid base64 string")) + ); + } + + @Test + void recordTransferArchiveRateLimited() { + when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .put(Entity.entity(new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(), + new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)))), + MediaType.APPLICATION_JSON_TYPE))) { + + assertEquals(429, response.getStatus()); + + verify(accountsManager, never()) + .recordTransferArchiveUpload(any(), anyByte(), any(), any()); + } + } + + @Test + void waitForTransferArchive() { + final RemoteAttachment transferArchive = + new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8))); + + when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive/") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(200, response.getStatus()); + assertEquals(transferArchive, response.readEntity(RemoteAttachment.class)); + } + } + + @Test + void waitForTransferArchiveNoArchiveUploaded() { + when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null)); + when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any())) + .thenReturn(CompletableFuture.completedFuture(Optional.empty())); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive/") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(204, response.getStatus()); + } + } + + @ParameterizedTest + @MethodSource + void waitForTransferArchiveBadTimeout(final int timeoutSeconds) { + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive/") + .queryParam("timeout", timeoutSeconds) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(400, response.getStatus()); + } + } + + private static List waitForTransferArchiveBadTimeout() { + return List.of(0, -1, 3601); + } + + @Test + void waitForTransferArchiveRateLimited() { + when(rateLimiter.validateAsync(anyString())) + .thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null))); + + try (final Response response = resources.getJerseyTest() + .target("/v1/devices/transfer_archive/") + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertEquals(429, response.getStatus()); + } + } }