Skip to content

Commit

Permalink
Shift authority to the new pub/sub client presence system
Browse files Browse the repository at this point in the history
  • Loading branch information
jon-signal authored Nov 6, 2024
1 parent aad1267 commit 9d19fc9
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
storageServiceExecutor, storageServiceRetryExecutor, config.getSecureStorageServiceConfiguration());
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster, recurringJobExecutor,
keyspaceNotificationDispatchExecutor);
PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor, experimentEnrollmentManager);
PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionAsyncExecutor, clock, dynamicConfigurationManager);
Expand Down Expand Up @@ -678,7 +678,7 @@ public void run(WhisperServerConfiguration config, Environment environment) thro
final AccountAuthenticator accountAuthenticator = new AccountAuthenticator(accountsManager);

final MessageSender messageSender =
new MessageSender(clientPresenceManager, pubSubClientEventManager, messagesManager, pushNotificationManager);
new MessageSender(pubSubClientEventManager, messagesManager, pushNotificationManager);
final ReceiptSender receiptSender = new ReceiptSender(accountsManager, messageSender, receiptSenderExecutor);
final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager,
config.getTurnConfiguration().secret().value());
Expand Down Expand Up @@ -1018,7 +1018,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
webSocketEnvironment.jersey().register(new RequestStatisticsFilter(TrafficSource.WEBSOCKET));
webSocketEnvironment.jersey().register(MultiRecipientMessageProvider.class);
webSocketEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
webSocketEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
webSocketEnvironment.jersey().register(new KeepAliveController(pubSubClientEventManager));
webSocketEnvironment.jersey().register(new TimestampResponseFilter());

final List<SpamFilter> spamFilters = ServiceLoader.load(SpamFilter.class)
Expand Down Expand Up @@ -1159,7 +1159,7 @@ protected void configureServer(final ServerBuilder<?> serverBuilder) {
pubSubClientEventManager));
provisioningEnvironment.setConnectListener(new ProvisioningConnectListener(provisioningManager));
provisioningEnvironment.jersey().register(new MetricsApplicationEventListener(TrafficSource.WEBSOCKET, clientReleaseManager));
provisioningEnvironment.jersey().register(new KeepAliveController(clientPresenceManager));
provisioningEnvironment.jersey().register(new KeepAliveController(pubSubClientEventManager));
provisioningEnvironment.jersey().register(new TimestampResponseFilter());

registerCorsFilter(environment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice;
import org.whispersystems.textsecuregcm.metrics.UserAgentTagUtil;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.websocket.auth.ReadOnly;
import org.whispersystems.websocket.session.WebSocketSession;
import org.whispersystems.websocket.session.WebSocketSessionContext;
Expand All @@ -34,22 +35,22 @@ public class KeepAliveController {

private final Logger logger = LoggerFactory.getLogger(KeepAliveController.class);

private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;

private static final String CLOSED_CONNECTION_AGE_DISTRIBUTION_NAME = name(KeepAliveController.class,
"closedConnectionAge");


public KeepAliveController(final ClientPresenceManager clientPresenceManager) {
this.clientPresenceManager = clientPresenceManager;
public KeepAliveController(final PubSubClientEventManager pubSubClientEventManager) {
this.pubSubClientEventManager = pubSubClientEventManager;
}

@GET
public Response getKeepAlive(@ReadOnly @Auth Optional<AuthenticatedDevice> maybeAuth,
@WebSocketSession WebSocketSessionContext context) {

maybeAuth.ifPresent(auth -> {
if (!clientPresenceManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) {
if (!pubSubClientEventManager.isLocallyPresent(auth.getAccount().getUuid(), auth.getAuthenticatedDevice().getId())) {

final Duration age = Duration.between(context.getClient().getCreated(), Instant.now());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
import static com.codahale.metrics.MetricRegistry.name;
import static org.whispersystems.textsecuregcm.entities.MessageProtos.Envelope;

import com.google.common.annotations.VisibleForTesting;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import org.whispersystems.textsecuregcm.identity.IdentityType;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.storage.Account;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.storage.MessagesManager;
import java.util.Objects;
import org.whispersystems.textsecuregcm.util.Util;
import java.util.concurrent.CompletableFuture;

/**
* A MessageSender sends Signal messages to destination devices. Messages may be "normal" user-to-user messages,
Expand All @@ -29,7 +33,6 @@
*/
public class MessageSender {

private final ClientPresenceManager clientPresenceManager;
private final PubSubClientEventManager pubSubClientEventManager;
private final MessagesManager messagesManager;
private final PushNotificationManager pushNotificationManager;
Expand All @@ -38,71 +41,68 @@ public class MessageSender {
private static final String CHANNEL_TAG_NAME = "channel";
private static final String EPHEMERAL_TAG_NAME = "ephemeral";
private static final String CLIENT_ONLINE_TAG_NAME = "clientOnline";
private static final String PUB_SUB_CLIENT_ONLINE_TAG_NAME = "pubSubClientOnline";
private static final String URGENT_TAG_NAME = "urgent";
private static final String STORY_TAG_NAME = "story";
private static final String SEALED_SENDER_TAG_NAME = "sealedSender";

public MessageSender(final ClientPresenceManager clientPresenceManager,
final PubSubClientEventManager pubSubClientEventManager,
private static final Counter CLIENT_PRESENCE_ERROR =
Metrics.counter(MetricsUtil.name(MessageSender.class, "clientPresenceError"));

public MessageSender(final PubSubClientEventManager pubSubClientEventManager,
final MessagesManager messagesManager,
final PushNotificationManager pushNotificationManager) {

this.clientPresenceManager = clientPresenceManager;
this.pubSubClientEventManager = pubSubClientEventManager;
this.messagesManager = messagesManager;
this.pushNotificationManager = pushNotificationManager;
}

public void sendMessage(final Account account, final Device device, final Envelope message, final boolean online) {

final String channel;
public CompletableFuture<Void> sendMessage(final Account account, final Device device, final Envelope message, final boolean online) {
messagesManager.insert(account.getUuid(),
device.getId(),
online ? message.toBuilder().setEphemeral(true).build() : message);

return pubSubClientEventManager.handleNewMessageAvailable(account.getIdentifier(IdentityType.ACI), device.getId())
.exceptionally(throwable -> {
// It's unlikely that the message insert (synchronous) would succeed and sending a "new message available"
// event would fail since both things happen in the same cluster, but just in case, we should "fail open" and
// act as if the client wasn't present if this happens. This is a conservative measure that biases toward
// sending more push notifications, though again, it shouldn't happen often.
CLIENT_PRESENCE_ERROR.increment();
return false;
})
.thenApply(clientPresent -> {
if (!clientPresent && !online) {
try {
pushNotificationManager.sendNewMessageNotification(account, device.getId(), message.getUrgent());
} catch (final NotPushRegisteredException ignored) {
}
}

return clientPresent;
})
.whenComplete((clientPresent, throwable) -> Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, getDeliveryChannelName(device),
EPHEMERAL_TAG_NAME, String.valueOf(online),
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment())
.thenRun(Util.NOOP)
.toCompletableFuture();
}

@VisibleForTesting
static String getDeliveryChannelName(final Device device) {
if (device.getGcmId() != null) {
channel = "gcm";
return "gcm";
} else if (device.getApnId() != null) {
channel = "apn";
return "apn";
} else if (device.getFetchesMessages()) {
channel = "websocket";
return "websocket";
} else {
channel = "none";
return "none";
}

final boolean clientPresent;

if (online) {
clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());

if (clientPresent) {
messagesManager.insert(account.getUuid(), device.getId(), message.toBuilder().setEphemeral(true).build());
} else {
messagesManager.removeRecipientViewFromMrmData(device.getId(), message);
}
} else {
messagesManager.insert(account.getUuid(), device.getId(), message);

// We check for client presence after inserting the message to take a conservative view of notifications. If the
// client wasn't present at the time of insertion but is now, they'll retrieve the message. If they were present
// but disconnected before the message was delivered, we should send a notification.
clientPresent = clientPresenceManager.isPresent(account.getUuid(), device.getId());

if (!clientPresent) {
try {
pushNotificationManager.sendNewMessageNotification(account, device.getId(), message.getUrgent());
} catch (final NotPushRegisteredException ignored) {
}
}
}

pubSubClientEventManager.handleNewMessageAvailable(account.getIdentifier(IdentityType.ACI), device.getId())
.whenComplete((present, throwable) -> Metrics.counter(SEND_COUNTER_NAME,
CHANNEL_TAG_NAME, channel,
EPHEMERAL_TAG_NAME, String.valueOf(online),
CLIENT_ONLINE_TAG_NAME, String.valueOf(clientPresent),
PUB_SUB_CLIENT_ONLINE_TAG_NAME, String.valueOf(Objects.requireNonNullElse(present, false)),
URGENT_TAG_NAME, String.valueOf(message.getUrgent()),
STORY_TAG_NAME, String.valueOf(message.getStory()),
SEALED_SENDER_TAG_NAME, String.valueOf(!message.hasSourceServiceId()))
.increment());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,28 @@
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Metrics;
import io.micrometer.core.instrument.Tags;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.metrics.MetricsUtil;
import org.whispersystems.textsecuregcm.redis.FaultTolerantPubSubClusterConnection;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.storage.Device;
import org.whispersystems.textsecuregcm.util.RedisClusterUtil;
import org.whispersystems.textsecuregcm.util.UUIDUtil;
import org.whispersystems.textsecuregcm.util.Util;
import javax.annotation.Nullable;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;

/**
* The pub/sub-based client presence manager uses the Redis 7 sharded pub/sub system to notify connected clients that
Expand All @@ -54,9 +58,6 @@ public class PubSubClientEventManager extends RedisClusterPubSubAdapter<byte[],
.build()
.toByteArray();

private final ExperimentEnrollmentManager experimentEnrollmentManager;
static final String EXPERIMENT_NAME = "pubSubPresenceManager";

@Nullable
private FaultTolerantPubSubClusterConnection<byte[], byte[]> pubSubConnection;

Expand Down Expand Up @@ -90,12 +91,10 @@ private record AccountAndDeviceIdentifier(UUID accountIdentifier, byte deviceId)
}

public PubSubClientEventManager(final FaultTolerantRedisClusterClient clusterClient,
final Executor listenerEventExecutor,
final ExperimentEnrollmentManager experimentEnrollmentManager) {
final Executor listenerEventExecutor) {

this.clusterClient = clusterClient;
this.listenerEventExecutor = listenerEventExecutor;
this.experimentEnrollmentManager = experimentEnrollmentManager;

this.listenersByAccountAndDeviceIdentifier =
Metrics.gaugeMapSize(LISTENER_GAUGE_NAME, Tags.empty(), new ConcurrentHashMap<>());
Expand Down Expand Up @@ -140,10 +139,6 @@ public CompletionStage<UUID> handleClientConnected(final UUID accountIdentifier,
throw new IllegalStateException("Presence manager not started");
}

if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(UUID.randomUUID());
}

final UUID connectionId = UUID.randomUUID();
final byte[] clientPresenceKey = getClientPresenceKey(accountIdentifier, deviceId);
final AtomicReference<ClientEventListener> displacedListener = new AtomicReference<>();
Expand Down Expand Up @@ -201,10 +196,6 @@ public CompletionStage<Void> handleClientDisconnected(final UUID accountIdentifi
throw new IllegalStateException("Presence manager not started");
}

if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(null);
}

final AtomicReference<CompletionStage<Void>> unsubscribeFuture = new AtomicReference<>();

// Note that we're relying on some specific implementation details of `ConcurrentHashMap#compute(...)`. In
Expand Down Expand Up @@ -248,10 +239,6 @@ public CompletionStage<Boolean> handleNewMessageAvailable(final UUID accountIden
throw new IllegalStateException("Presence manager not started");
}

if (!experimentEnrollmentManager.isEnrolled(accountIdentifier, EXPERIMENT_NAME)) {
return CompletableFuture.completedFuture(false);
}

return pubSubConnection.withPubSubConnection(connection ->
connection.async().spublish(getClientPresenceKey(accountIdentifier, deviceId), NEW_MESSAGE_EVENT_BYTES))
.thenApply(listeners -> listeners > 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,17 @@ public void handleDisplacement(final boolean connectedElsewhere) {
);

Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
}

@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
Tag.of(PRESENCE_MANAGER_TAG, "pubsub")
);

Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();

final int code;
final String message;
Expand All @@ -534,17 +545,6 @@ public void handleDisplacement(final boolean connectedElsewhere) {
}
}

@Override
public void handleConnectionDisplaced(final boolean connectedElsewhere) {
final Tags tags = Tags.of(
UserAgentTagUtil.getPlatformTag(client.getUserAgent()),
Tag.of("connectedElsewhere", String.valueOf(connectedElsewhere)),
Tag.of(PRESENCE_MANAGER_TAG, "pubsub")
);

Metrics.counter(DISPLACEMENT_COUNTER_NAME, tags).increment();
}

private record StoredMessageInfo(UUID guid, long serverTimestamp) {

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,17 @@
import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration;
import org.whispersystems.textsecuregcm.controllers.SecureStorageController;
import org.whispersystems.textsecuregcm.controllers.SecureValueRecovery2Controller;
import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager;
import org.whispersystems.textsecuregcm.experiment.PushNotificationExperimentSamples;
import org.whispersystems.textsecuregcm.limits.RateLimiters;
import org.whispersystems.textsecuregcm.metrics.MicrometerAwsSdkMetricPublisher;
import org.whispersystems.textsecuregcm.push.APNSender;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.push.ClientPresenceManager;
import org.whispersystems.textsecuregcm.push.FcmSender;
import org.whispersystems.textsecuregcm.push.PubSubClientEventManager;
import org.whispersystems.textsecuregcm.push.PushNotificationManager;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.push.PushNotificationScheduler;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClient;
import org.whispersystems.textsecuregcm.redis.FaultTolerantRedisClusterClient;
import org.whispersystems.textsecuregcm.securestorage.SecureStorageClient;
import org.whispersystems.textsecuregcm.securevaluerecovery.SecureValueRecovery2Client;
import org.whispersystems.textsecuregcm.storage.AccountLockManager;
Expand Down Expand Up @@ -218,9 +217,7 @@ static CommandDependencies build(
storageServiceExecutor, storageServiceRetryExecutor, configuration.getSecureStorageServiceConfiguration());
ClientPresenceManager clientPresenceManager = new ClientPresenceManager(clientPresenceCluster,
recurringJobExecutor, keyspaceNotificationDispatchExecutor);
ExperimentEnrollmentManager experimentEnrollmentManager = new ExperimentEnrollmentManager(
dynamicConfigurationManager);
PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor, experimentEnrollmentManager);
PubSubClientEventManager pubSubClientEventManager = new PubSubClientEventManager(messagesCluster, clientEventExecutor);
MessagesCache messagesCache = new MessagesCache(messagesCluster, keyspaceNotificationDispatchExecutor,
messageDeliveryScheduler, messageDeletionExecutor, Clock.systemUTC(), dynamicConfigurationManager);
ProfilesManager profilesManager = new ProfilesManager(profiles, cacheCluster);
Expand Down
Loading

0 comments on commit 9d19fc9

Please sign in to comment.