Skip to content

Commit

Permalink
Add API endpoints for waiting for transfer archives
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal committed Oct 15, 2024
1 parent 7ff4815 commit 73fb1fc
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import java.time.Duration;
import java.time.Instant;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;
import javax.validation.Valid;
Expand Down Expand Up @@ -58,7 +60,9 @@
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.PreKeySignatureValidator;
import org.whispersystems.textsecuregcm.entities.ProvisioningMessage;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
Expand Down Expand Up @@ -318,9 +322,9 @@ public LinkDeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) Bas
Waits for a new device to be linked to an account and returns basic information about the new device when
available.
""")
@ApiResponse(responseCode = "200", description = "The specified was linked to an account",
@ApiResponse(responseCode = "200", description = "A device was linked to an account using the token associated with the given token identifier",
content = @Content(schema = @Schema(implementation = DeviceInfo.class)))
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed")
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed; clients may repeat the call to continue waiting")
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletableFuture<Response> waitForLinkedDevice(
Expand Down Expand Up @@ -432,4 +436,66 @@ private static boolean isCapabilityDowngrade(Account account, DeviceCapabilities
isDowngrade |= account.isVersionedExpirationTimerSupported() && !capabilities.versionedExpirationTimer();
return isDowngrade;
}

@PUT
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
@Path("/transfer_archive")
@Operation(
summary = "Signals that a transfer archive has been uploaded for a specific linked device",
description = """
Signals that a transfer archive has been uploaded for a specific linked device. Devices waiting via the "wait
for transfer archive" endpoint will be notified that the new archive is available.
""")
@ApiResponse(responseCode = "204", description = "Success")
@ApiResponse(responseCode = "422", description = "The request object could not be parsed or was otherwise invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Void> recordTransferArchiveUploaded(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,
@NotNull @Valid final TransferArchiveUploadedRequest transferArchiveUploadedRequest) {

return rateLimiters.getUploadTransferArchiveLimiter().validateAsync(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI))
.thenCompose(ignored -> accounts.recordTransferArchiveUpload(authenticatedDevice.getAccount(),
transferArchiveUploadedRequest.destinationDeviceId(),
Instant.ofEpochMilli(transferArchiveUploadedRequest.destinationDeviceCreated()),
transferArchiveUploadedRequest.transferArchive()));
}

@GET
@Produces(MediaType.APPLICATION_JSON)
@Path("/transfer_archive")
@Operation(summary = "Wait for a new transfer archive to be uploaded",
description = """
Waits for a new transfer archive to be uploaded for the authenticated device and returns the location of the
archive when available.
""")
@ApiResponse(responseCode = "200", description = "A new transfer archive was uploaded for the authenticated device",
content = @Content(schema = @Schema(implementation = RemoteAttachment.class)))
@ApiResponse(responseCode = "204", description = "No transfer archive was uploaded before the call completed; clients may repeat the call to continue waiting")
@ApiResponse(responseCode = "400", description = "The given timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
public CompletionStage<Response> waitForTransferArchive(@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,

@QueryParam("timeout")
@DefaultValue("30")
@Min(1)
@Max(3600)
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED,
minimum = "1",
maximum = "3600",
description = """
The amount of time (in seconds) to wait for a response. If a transfer archive for the authenticated
device is not available within the given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds) {

final String rateLimiterKey = authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI) +
":" + authenticatedDevice.getAuthenticatedDevice().getId();

return rateLimiters.getWaitForTransferArchiveLimiter().validateAsync(rateLimiterKey)
.thenCompose(ignored -> accounts.waitForTransferArchive(authenticatedDevice.getAccount(),
authenticatedDevice.getAuthenticatedDevice(),
Duration.ofSeconds(timeoutSeconds)))
.thenApply(maybeTransferArchive -> maybeTransferArchive
.map(transferArchive -> Response.status(Response.Status.OK).entity(transferArchive).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/

package org.whispersystems.textsecuregcm.entities;

import io.swagger.v3.oas.annotations.media.Schema;
import org.whispersystems.textsecuregcm.storage.Device;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.Positive;

public record TransferArchiveUploadedRequest(@Min(1)
@Max(Device.MAXIMUM_DEVICE_ID)
@Schema(description = "The ID of the device for which the transfer archive has been prepared")
byte destinationDeviceId,

@Positive
@Schema(description = "The timestamp, in milliseconds since the epoch, at which the destination device was created")
long destinationDeviceCreated,

@Schema(description = "The location of the transfer archive")
@Valid
RemoteAttachment transferArchive) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public enum For implements RateLimiterDescriptor {
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
UPLOAD_TRANSFER_ARCHIVE("uploadTransferArchive", true, new RateLimiterConfig(10, Duration.ofMinutes(1))),
WAIT_FOR_TRANSFER_ARCHIVE("waitForTransferArchive", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
;

private final String id;
Expand Down Expand Up @@ -210,4 +212,12 @@ public RateLimiter getStoriesLimiter() {
public RateLimiter getWaitForLinkedDeviceLimiter() {
return forDescriptor(For.WAIT_FOR_LINKED_DEVICE);
}

public RateLimiter getUploadTransferArchiveLimiter() {
return forDescriptor(For.UPLOAD_TRANSFER_ARCHIVE);
}

public RateLimiter getWaitForTransferArchiveLimiter() {
return forDescriptor(For.WAIT_FOR_TRANSFER_ARCHIVE);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.amazonaws.util.Base64;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.AuthValueFactoryProvider;
import io.dropwizard.testing.junit5.DropwizardExtensionsSupport;
import io.dropwizard.testing.junit5.ResourceExtension;
import io.lettuce.core.cluster.api.async.RedisAdvancedClusterAsyncCommands;
import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -64,7 +66,9 @@
import org.whispersystems.textsecuregcm.entities.GcmRegistrationId;
import org.whispersystems.textsecuregcm.entities.KEMSignedPreKey;
import org.whispersystems.textsecuregcm.entities.LinkDeviceRequest;
import org.whispersystems.textsecuregcm.entities.RemoteAttachment;
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.entities.TransferArchiveUploadedRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiter;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
Expand Down Expand Up @@ -132,6 +136,8 @@ void setup() {
when(rateLimiters.getAllocateDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getVerifyDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForLinkedDeviceLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getUploadTransferArchiveLimiter()).thenReturn(rateLimiter);
when(rateLimiters.getWaitForTransferArchiveLimiter()).thenReturn(rateLimiter);

when(primaryDevice.getId()).thenReturn(Device.PRIMARY_ID);

Expand Down Expand Up @@ -957,7 +963,7 @@ void waitForLinkedDevice() {
System.currentTimeMillis(),
System.currentTimeMillis());

final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(deviceInfo)));
Expand All @@ -980,7 +986,7 @@ void waitForLinkedDevice() {

@Test
void waitForLinkedDeviceNoDeviceLinked() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));
Expand All @@ -997,7 +1003,7 @@ void waitForLinkedDeviceNoDeviceLinked() {

@Test
void waitForLinkedDeviceBadTokenIdentifier() {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

when(accountsManager.waitForNewLinkedDevice(eq(tokenIdentifier), any()))
.thenReturn(CompletableFuture.failedFuture(new IllegalArgumentException()));
Expand All @@ -1015,7 +1021,7 @@ void waitForLinkedDeviceBadTokenIdentifier() {
@ParameterizedTest
@MethodSource
void waitForLinkedDeviceBadTimeout(final int timeoutSeconds) {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/wait_for_linked_device/" + tokenIdentifier)
Expand Down Expand Up @@ -1052,7 +1058,7 @@ private static List<String> waitForLinkedDeviceBadTokenIdentifierLength() {

@Test
void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
final String tokenIdentifier = Base64.encodeAsString(new byte[32]);
final String tokenIdentifier = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);

doThrow(new RateLimitExceededException(null)).when(rateLimiter).validate(AuthHelper.VALID_UUID);

Expand All @@ -1065,4 +1071,158 @@ void waitForLinkedDeviceRateLimited() throws RateLimitExceededException {
assertEquals(429, response.getStatus());
}
}

@Test
void recordTransferArchiveUploaded() {
final byte deviceId = Device.PRIMARY_ID + 1;
final Instant deviceCreated = Instant.now().truncatedTo(ChronoUnit.MILLIS);
final RemoteAttachment transferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));

when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive))
.thenReturn(CompletableFuture.completedFuture(null));

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new TransferArchiveUploadedRequest(deviceId, deviceCreated.toEpochMilli(), transferArchive),
MediaType.APPLICATION_JSON_TYPE))) {

assertEquals(204, response.getStatus());

verify(accountsManager)
.recordTransferArchiveUpload(AuthHelper.VALID_ACCOUNT, deviceId, deviceCreated, transferArchive);
}
}

@ParameterizedTest
@MethodSource
void recordTransferArchiveUploadedBadRequest(final TransferArchiveUploadedRequest request) {
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID)).thenReturn(CompletableFuture.completedFuture(null));

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(request, MediaType.APPLICATION_JSON_TYPE))) {

assertEquals(422, response.getStatus());

verify(accountsManager, never())
.recordTransferArchiveUpload(any(), anyByte(), any(), any());
}
}

@SuppressWarnings("DataFlowIssue")
private static List<TransferArchiveUploadedRequest> recordTransferArchiveUploadedBadRequest() {
final RemoteAttachment validTransferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8)));

return List.of(
// Invalid device ID
new TransferArchiveUploadedRequest((byte) -1, System.currentTimeMillis(), validTransferArchive),

// Invalid "created at" timestamp
new TransferArchiveUploadedRequest(Device.PRIMARY_ID, -1, validTransferArchive),

// Missing CDN number
new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(),
new RemoteAttachment(null, Base64.getUrlEncoder().encodeToString("archive".getBytes(StandardCharsets.UTF_8)))),

// Bad attachment key
new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(),
new RemoteAttachment(3, "This is not a valid base64 string"))
);
}

@Test
void recordTransferArchiveRateLimited() {
when(rateLimiter.validateAsync(AuthHelper.VALID_UUID))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.put(Entity.entity(new TransferArchiveUploadedRequest(Device.PRIMARY_ID, System.currentTimeMillis(),
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)))),
MediaType.APPLICATION_JSON_TYPE))) {

assertEquals(429, response.getStatus());

verify(accountsManager, never())
.recordTransferArchiveUpload(any(), anyByte(), any(), any());
}
}

@Test
void waitForTransferArchive() {
final RemoteAttachment transferArchive =
new RemoteAttachment(3, Base64.getUrlEncoder().encodeToString("test".getBytes(StandardCharsets.UTF_8)));

when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.of(transferArchive)));

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {

assertEquals(200, response.getStatus());
assertEquals(transferArchive, response.readEntity(RemoteAttachment.class));
}
}

@Test
void waitForTransferArchiveNoArchiveUploaded() {
when(rateLimiter.validateAsync(anyString())).thenReturn(CompletableFuture.completedFuture(null));
when(accountsManager.waitForTransferArchive(eq(AuthHelper.VALID_ACCOUNT), eq(AuthHelper.VALID_DEVICE), any()))
.thenReturn(CompletableFuture.completedFuture(Optional.empty()));

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {

assertEquals(204, response.getStatus());
}
}

@ParameterizedTest
@MethodSource
void waitForTransferArchiveBadTimeout(final int timeoutSeconds) {
try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.queryParam("timeout", timeoutSeconds)
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {

assertEquals(400, response.getStatus());
}
}

private static List<Integer> waitForTransferArchiveBadTimeout() {
return List.of(0, -1, 3601);
}

@Test
void waitForTransferArchiveRateLimited() {
when(rateLimiter.validateAsync(anyString()))
.thenReturn(CompletableFuture.failedFuture(new RateLimitExceededException(null)));

try (final Response response = resources.getJerseyTest()
.target("/v1/devices/transfer_archive/")
.request()
.header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD))
.get()) {

assertEquals(429, response.getStatus());
}
}
}

0 comments on commit 73fb1fc

Please sign in to comment.