Skip to content

Commit

Permalink
Add API endpoints for waiting for newly-linked devices
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal authored Oct 10, 2024
1 parent 087c2b6 commit 8c30a35
Show file tree
Hide file tree
Showing 16 changed files with 793 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
ClientPublicKeysManager clientPublicKeysManager =
new ClientPublicKeysManager(clientPublicKeys, accountLockManager, accountLockExecutor);
AccountsManager accountsManager = new AccountsManager(accounts, phoneNumberIdentifiers, cacheCluster,
accountLockManager, keysManager, messagesManager, profilesManager,
pubsubClient, accountLockManager, keysManager, messagesManager, profilesManager,
secureStorageClient, secureValueRecovery2Client,
clientPresenceManager,
registrationRecoveryPasswordsManager, clientPublicKeysManager, accountLockExecutor, clientPresenceExecutor,
Expand Down Expand Up @@ -764,6 +764,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
environment.lifecycle().manage(keyTransparencyServiceClient);
environment.lifecycle().manage(clientReleaseManager);
environment.lifecycle().manage(virtualThreadPinEventMonitor);
environment.lifecycle().manage(accountsManager);

final RegistrationCaptchaManager registrationCaptchaManager = new RegistrationCaptchaManager(captchaChecker);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,49 @@
*/
package org.whispersystems.textsecuregcm.controllers;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.net.HttpHeaders;
import io.dropwizard.auth.Auth;
import io.lettuce.core.RedisException;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.headers.Header;
import io.swagger.v3.oas.annotations.media.Schema;
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.util.LinkedList;
import java.time.Duration;
import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.validation.Valid;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotNull;
import javax.validation.constraints.Size;
import javax.ws.rs.Consumes;
import javax.ws.rs.DELETE;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.ForbiddenException;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.PathParam;
import javax.ws.rs.Produces;
import javax.ws.rs.QueryParam;
import javax.ws.rs.WebApplicationException;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.glassfish.jersey.server.ContainerRequest;
import org.whispersystems.textsecuregcm.auth.LinkedDeviceRefreshRequirementProvider;
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
Expand All @@ -47,14 +63,20 @@
import org.whispersystems.textsecuregcm.entities.SetPublicKeyRequest;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.AccountsManager;
import org.whispersystems.textsecuregcm.storage.ClientPublicKeysManager;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.Device.DeviceCapabilities;
import org.whispersystems.textsecuregcm.storage.DeviceSpec;
import org.whispersystems.textsecuregcm.storage.LinkDeviceTokenAlreadyUsedException;
import org.whispersystems.textsecuregcm.util.VerificationCode;
import org.whispersystems.textsecuregcm.util.ExceptionUtils;
import org.whispersystems.textsecuregcm.util.LinkDeviceToken;
import org.whispersystems.textsecuregcm.util.ua.ClientPlatform;
import org.whispersystems.textsecuregcm.util.ua.UnrecognizedUserAgentException;
import org.whispersystems.textsecuregcm.util.ua.UserAgentUtil;
import org.whispersystems.websocket.auth.Mutable;
import org.whispersystems.websocket.auth.ReadOnly;

Expand All @@ -69,6 +91,21 @@ public class DeviceController {
private final RateLimiters rateLimiters;
private final Map<String, Integer> maxDeviceConfiguration;

private final EnumMap<ClientPlatform, AtomicInteger> linkedDeviceListenersByPlatform;
private final AtomicInteger linkedDeviceListenersForUnrecognizedPlatforms;

private static final String LINKED_DEVICE_LISTENER_GAUGE_NAME =
MetricsUtil.name(DeviceController.class, "linkedDeviceListeners");

private static final String WAIT_FOR_LINKED_DEVICE_TIMER_NAME =
MetricsUtil.name(DeviceController.class, "waitForLinkedDeviceDuration");

@VisibleForTesting
static final int MIN_TOKEN_IDENTIFIER_LENGTH = 32;

@VisibleForTesting
static final int MAX_TOKEN_IDENTIFIER_LENGTH = 64;

public DeviceController(final AccountsManager accounts,
final ClientPublicKeysManager clientPublicKeysManager,
final RateLimiters rateLimiters,
Expand All @@ -78,19 +115,32 @@ public DeviceController(final AccountsManager accounts,
this.clientPublicKeysManager = clientPublicKeysManager;
this.rateLimiters = rateLimiters;
this.maxDeviceConfiguration = maxDeviceConfiguration;

linkedDeviceListenersByPlatform = Arrays.stream(ClientPlatform.values())
.collect(Collectors.toMap(
Function.identity(),
clientPlatform -> buildGauge(clientPlatform.name().toLowerCase()),
(a, b) -> {
throw new AssertionError("Duplicate client platform enumeration key");
},
() -> new EnumMap<>(ClientPlatform.class)
));

linkedDeviceListenersForUnrecognizedPlatforms = buildGauge("unknown");
}

private static AtomicInteger buildGauge(final String clientPlatformName) {
return Metrics.gauge(LINKED_DEVICE_LISTENER_GAUGE_NAME,
Tags.of(io.micrometer.core.instrument.Tag.of(UserAgentTagUtil.PLATFORM_TAG, clientPlatformName)),
new AtomicInteger(0));
}

@GET
@Produces(MediaType.APPLICATION_JSON)
public DeviceInfoList getDevices(@ReadOnly @Auth AuthenticatedDevice auth) {
List<DeviceInfo> devices = new LinkedList<>();

for (Device device : auth.getAccount().getDevices()) {
devices.add(new DeviceInfo(device.getId(), device.getName(),
device.getLastSeen(), device.getCreated()));
}

return new DeviceInfoList(devices);
return new DeviceInfoList(auth.getAccount().getDevices().stream()
.map(DeviceInfo::forDevice)
.toList());
}

@DELETE
Expand Down Expand Up @@ -138,7 +188,7 @@ public void removeDevice(@Mutable @Auth AuthenticatedDevice auth, @PathParam("de
@ApiResponse(responseCode = "429", description = "Too many attempts", headers = @Header(
name = "Retry-After",
description = "If present, an positive integer indicating the number of seconds before a subsequent attempt could succeed"))
public VerificationCode createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
public LinkDeviceToken createDeviceToken(@ReadOnly @Auth AuthenticatedDevice auth)
throws RateLimitExceededException, DeviceLimitExceededException {

final Account account = auth.getAccount();
Expand All @@ -159,7 +209,9 @@ public VerificationCode createDeviceToken(@ReadOnly @Auth AuthenticatedDevice au
throw new WebApplicationException(Response.Status.UNAUTHORIZED);
}

return new VerificationCode(accounts.generateDeviceLinkingToken(account.getUuid()));
final String token = accounts.generateLinkDeviceToken(account.getUuid());

return new LinkDeviceToken(token, AccountsManager.getLinkDeviceTokenIdentifier(token));
}

@PUT
Expand Down Expand Up @@ -266,6 +318,83 @@ public DeviceResponse linkDevice(@HeaderParam(HttpHeaders.AUTHORIZATION) BasicAu
}
}

@GET
@Path("/wait_for_linked_device/{tokenIdentifier}")
@Produces(MediaType.APPLICATION_JSON)
@Operation(summary = "Wait for a new device to be linked to an account",
description = """
Waits for a new device to be linked to an account and returns basic information about the new device when
available.
""")
@ApiResponse(responseCode = "200", description = "The specified was linked to an account")
@ApiResponse(responseCode = "204", description = "No device was linked to the account before the call completed")
@ApiResponse(responseCode = "400", description = "The given token identifier or timeout was invalid")
@ApiResponse(responseCode = "429", description = "Rate-limited; try again after the prescribed delay")
@Schema(description = "Basic information about the linked device", implementation = DeviceInfo.class)
public CompletableFuture<Response> waitForLinkedDevice(
@ReadOnly @Auth final AuthenticatedDevice authenticatedDevice,

@PathParam("tokenIdentifier")
@Schema(description = "A 'link device' token identifier provided by the 'create link device token' endpoint")
@Size(min = MIN_TOKEN_IDENTIFIER_LENGTH, max = MAX_TOKEN_IDENTIFIER_LENGTH)
final String tokenIdentifier,

@QueryParam("timeout")
@DefaultValue("30")
@Min(1)
@Max(3600)
@Schema(requiredMode = Schema.RequiredMode.NOT_REQUIRED,
minimum = "1",
maximum = "3600",
description = """
The amount of time (in seconds) to wait for a response. If the expected device is not linked within the
given amount of time, this endpoint will return a status of HTTP/204.
""") final int timeoutSeconds,

@HeaderParam(HttpHeaders.USER_AGENT) String userAgent) throws RateLimitExceededException {

rateLimiters.getWaitForLinkedDeviceLimiter().validate(authenticatedDevice.getAccount().getIdentifier(IdentityType.ACI));

final AtomicInteger linkedDeviceListenerCounter = getCounterForLinkedDeviceListeners(userAgent);
linkedDeviceListenerCounter.incrementAndGet();

final Timer.Sample sample = Timer.start();

try {
return accounts.waitForNewLinkedDevice(tokenIdentifier, Duration.ofSeconds(timeoutSeconds))
.thenApply(maybeDeviceInfo -> maybeDeviceInfo
.map(deviceInfo -> Response.status(Response.Status.OK).entity(deviceInfo).build())
.orElseGet(() -> Response.status(Response.Status.NO_CONTENT).build()))
.exceptionally(ExceptionUtils.exceptionallyHandler(IllegalArgumentException.class,
e -> Response.status(Response.Status.BAD_REQUEST).build()))
.whenComplete((response, throwable) -> {
linkedDeviceListenerCounter.decrementAndGet();

if (response != null) {
sample.stop(Timer.builder(WAIT_FOR_LINKED_DEVICE_TIMER_NAME)
.publishPercentileHistogram(true)
.tags(Tags.of(UserAgentTagUtil.getPlatformTag(userAgent),
io.micrometer.core.instrument.Tag.of("deviceFound",
String.valueOf(response.getStatus() == Response.Status.OK.getStatusCode()))))
.register(Metrics.globalRegistry));
}
});
} catch (final RedisException e) {
// `waitForNewLinkedDevice` could fail synchronously if the Redis circuit breaker is open; prevent counter drift
// if that happens
linkedDeviceListenerCounter.decrementAndGet();
throw e;
}
}

private AtomicInteger getCounterForLinkedDeviceListeners(final String userAgent) {
try {
return linkedDeviceListenersByPlatform.get(UserAgentUtil.parseUserAgentString(userAgent).getPlatform());
} catch (final UnrecognizedUserAgentException ignored) {
return linkedDeviceListenersForUnrecognizedPlatforms;
}
}

@PUT
@Produces(MediaType.APPLICATION_JSON)
@Path("/unauthenticated_delivery")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.ByteArrayBase64WithPaddingAdapter;

public record DeviceInfo(long id,
Expand All @@ -17,4 +18,8 @@ public record DeviceInfo(long id,

long lastSeen,
long created) {

public static DeviceInfo forDevice(final Device device) {
return new DeviceInfo(device.getId(), device.getName(), device.getLastSeen(), device.getCreated());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public enum For implements RateLimiterDescriptor {
EXTERNAL_SERVICE_CREDENTIALS("externalServiceCredentials", true, new RateLimiterConfig(100, Duration.ofMinutes(15))),
KEY_TRANSPARENCY_SEARCH_PER_IP("keyTransparencySearch", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
KEY_TRANSPARENCY_MONITOR_PER_IP("keyTransparencyMonitor", true, new RateLimiterConfig(100, Duration.ofSeconds(15))),
WAIT_FOR_LINKED_DEVICE("waitForLinkedDevice", true, new RateLimiterConfig(10, Duration.ofSeconds(30))),
;

private final String id;
Expand Down Expand Up @@ -205,4 +206,8 @@ public RateLimiter getInboundMessageBytes() {
public RateLimiter getStoriesLimiter() {
return forDescriptor(For.STORIES);
}

public RateLimiter getWaitForLinkedDeviceLimiter() {
return forDescriptor(For.WAIT_FOR_LINKED_DEVICE);
}
}
Loading

0 comments on commit 8c30a35

Please sign in to comment.