diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java index 6c9956d81..fe795d11a 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnection.java @@ -35,7 +35,18 @@ protected FaultTolerantPubSubClusterConnection(final String name, public void subscribeToClusterTopologyChangedEvents(final Runnable eventHandler) { usePubSubConnection(connection -> connection.getResources().eventBus().get() - .filter(event -> event instanceof ClusterTopologyChangedEvent) + .filter(event -> { + // If we use shared `ClientResources` for multiple clients, we may receive topology change events for clusters + // other than our own. Filter for clusters that have at least one node in common with our current view of our + // partitions. + if (event instanceof ClusterTopologyChangedEvent clusterTopologyChangedEvent) { + return withPubSubConnection(c -> c.getPartitions().stream().anyMatch(redisClusterNode -> + clusterTopologyChangedEvent.before().contains(redisClusterNode) || + clusterTopologyChangedEvent.after().contains(redisClusterNode))); + } + + return false; + }) .subscribeOn(topologyChangedEventScheduler) .subscribe(event -> { logger.info("Got topology change event for {}, resubscribing all keyspace notifications", getName()); diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java index f66cc10d4..cd5e60779 100644 --- a/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java +++ b/service/src/test/java/org/whispersystems/textsecuregcm/redis/FaultTolerantPubSubClusterConnectionTest.java @@ -5,6 +5,7 @@ package org.whispersystems.textsecuregcm.redis; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.atLeastOnce; @@ -19,17 +20,18 @@ import io.github.resilience4j.retry.RetryConfig; import io.lettuce.core.RedisException; import io.lettuce.core.cluster.event.ClusterTopologyChangedEvent; +import io.lettuce.core.cluster.models.partitions.Partitions; +import io.lettuce.core.cluster.models.partitions.RedisClusterNode; import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; import io.lettuce.core.cluster.pubsub.api.sync.RedisClusterPubSubCommands; import io.lettuce.core.event.Event; import io.lettuce.core.event.EventBus; import io.lettuce.core.resource.ClientResources; -import java.util.Collections; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.whispersystems.textsecuregcm.configuration.RetryConfiguration; import reactor.core.publisher.Flux; @@ -42,15 +44,31 @@ class FaultTolerantPubSubClusterConnectionTest { private RedisClusterPubSubCommands pubSubCommands; private FaultTolerantPubSubClusterConnection faultTolerantPubSubConnection; + private TestPublisher eventPublisher; + + private Runnable resubscribe; + + private AtomicInteger resubscribeCounter; + private CountDownLatch resubscribeFailure; + private CountDownLatch resubscribeSuccess; + + private RedisClusterNode nodeInCluster; @SuppressWarnings("unchecked") @BeforeEach public void setUp() { pubSubConnection = mock(StatefulRedisClusterPubSubConnection.class); - pubSubCommands = mock(RedisClusterPubSubCommands.class); + nodeInCluster = mock(RedisClusterNode.class); + + final ClientResources clientResources = mock(ClientResources.class); + + final Partitions partitions = new Partitions(); + partitions.add(nodeInCluster); when(pubSubConnection.sync()).thenReturn(pubSubCommands); + when(pubSubConnection.getResources()).thenReturn(clientResources); + when(pubSubConnection.getPartitions()).thenReturn(partitions); final RetryConfiguration retryConfiguration = new RetryConfiguration(); retryConfiguration.setMaxAttempts(3); @@ -64,108 +82,100 @@ public void setUp() { faultTolerantPubSubConnection = new FaultTolerantPubSubClusterConnection<>("test", pubSubConnection, resubscribeRetry, Schedulers.newSingle("test")); - } - @Nested - class ClusterTopologyChangedEventTest { + eventPublisher = TestPublisher.createCold(); - private TestPublisher eventPublisher; + final EventBus eventBus = mock(EventBus.class); + when(clientResources.eventBus()).thenReturn(eventBus); - private Runnable resubscribe; + final Flux eventFlux = Flux.from(eventPublisher); + when(eventBus.get()).thenReturn(eventFlux); - private AtomicInteger resubscribeCounter; - private CountDownLatch resubscribeFailure; - private CountDownLatch resubscribeSuccess; + resubscribeCounter = new AtomicInteger(); - @BeforeEach - @SuppressWarnings("unchecked") - void setup() { - // ignore inherited stubbing - reset(pubSubConnection); + resubscribe = () -> { + try { + resubscribeCounter.incrementAndGet(); + pubSubConnection.sync().nodes((ignored) -> true); + resubscribeSuccess.countDown(); + } catch (final RuntimeException e) { + resubscribeFailure.countDown(); + throw e; + } + }; - eventPublisher = TestPublisher.createCold(); + resubscribeSuccess = new CountDownLatch(1); + resubscribeFailure = new CountDownLatch(1); + } - final ClientResources clientResources = mock(ClientResources.class); - when(pubSubConnection.getResources()) - .thenReturn(clientResources); - final EventBus eventBus = mock(EventBus.class); - when(clientResources.eventBus()) - .thenReturn(eventBus); + @SuppressWarnings("unchecked") + @Test + void testSubscribeToClusterTopologyChangedEvents() throws Exception { - final Flux eventFlux = Flux.from(eventPublisher); - when(eventBus.get()).thenReturn(eventFlux); + when(pubSubConnection.sync()) + .thenThrow(new RedisException("Cluster unavailable")); - resubscribeCounter = new AtomicInteger(); + eventPublisher.next(new ClusterTopologyChangedEvent(List.of(nodeInCluster), List.of(nodeInCluster))); - resubscribe = () -> { - try { - resubscribeCounter.incrementAndGet(); - pubSubConnection.sync().nodes((ignored) -> true); - resubscribeSuccess.countDown(); - } catch (final RuntimeException e) { - resubscribeFailure.countDown(); - throw e; - } - }; + faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(resubscribe); - resubscribeSuccess = new CountDownLatch(1); - resubscribeFailure = new CountDownLatch(1); - } + assertTrue(resubscribeFailure.await(1, TimeUnit.SECONDS)); - @SuppressWarnings("unchecked") - @Test - void testSubscribeToClusterTopologyChangedEvents() throws Exception { + // simulate cluster recovery - no more exceptions, run the retry + reset(pubSubConnection); + clearInvocations(pubSubCommands); + when(pubSubConnection.sync()) + .thenReturn(pubSubCommands); - when(pubSubConnection.sync()) - .thenThrow(new RedisException("Cluster unavailable")); + assertTrue(resubscribeSuccess.await(1, TimeUnit.SECONDS)); - eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + assertTrue(resubscribeCounter.get() >= 2, String.format("resubscribe called %d times", resubscribeCounter.get())); + verify(pubSubCommands).nodes(any()); + } - faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(resubscribe); + @Test + void testFilterClusterTopologyChangeEvents() throws InterruptedException { + final CountDownLatch topologyEventLatch = new CountDownLatch(1); - assertTrue(resubscribeFailure.await(1, TimeUnit.SECONDS)); + faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(topologyEventLatch::countDown); - // simulate cluster recovery - no more exceptions, run the retry - reset(pubSubConnection); - clearInvocations(pubSubCommands); - when(pubSubConnection.sync()) - .thenReturn(pubSubCommands); + final RedisClusterNode nodeFromDifferentCluster = mock(RedisClusterNode.class); - assertTrue(resubscribeSuccess.await(1, TimeUnit.SECONDS)); + eventPublisher.next(new ClusterTopologyChangedEvent(List.of(nodeFromDifferentCluster), List.of(nodeFromDifferentCluster))); - assertTrue(resubscribeCounter.get() >= 2, String.format("resubscribe called %d times", resubscribeCounter.get())); - verify(pubSubCommands).nodes(any()); - } + assertFalse(topologyEventLatch.await(1, TimeUnit.SECONDS)); + } - @Test - @SuppressWarnings("unchecked") - void testMultipleEventsWithPendingRetries() throws Exception { - // more complicated scenario: multiple events while retries are pending + @Test + @SuppressWarnings("unchecked") + void testMultipleEventsWithPendingRetries() throws Exception { + // more complicated scenario: multiple events while retries are pending - // cluster is down - when(pubSubConnection.sync()) - .thenThrow(new RedisException("Cluster unavailable")); + // cluster is down + when(pubSubConnection.sync()) + .thenThrow(new RedisException("Cluster unavailable")); - // publish multiple topology changed events - eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); - eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); - eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); - eventPublisher.next(new ClusterTopologyChangedEvent(Collections.emptyList(), Collections.emptyList())); + // publish multiple topology changed events + final ClusterTopologyChangedEvent clusterTopologyChangedEvent = + new ClusterTopologyChangedEvent(List.of(nodeInCluster), List.of(nodeInCluster)); - faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(resubscribe); + eventPublisher.next(clusterTopologyChangedEvent); + eventPublisher.next(clusterTopologyChangedEvent); + eventPublisher.next(clusterTopologyChangedEvent); + eventPublisher.next(clusterTopologyChangedEvent); - assertTrue(resubscribeFailure.await(1, TimeUnit.SECONDS)); + faultTolerantPubSubConnection.subscribeToClusterTopologyChangedEvents(resubscribe); - // simulate cluster recovery - no more exceptions, run the retry - reset(pubSubConnection); - clearInvocations(pubSubCommands); - when(pubSubConnection.sync()) - .thenReturn(pubSubCommands); + assertTrue(resubscribeFailure.await(1, TimeUnit.SECONDS)); - assertTrue(resubscribeSuccess.await(1, TimeUnit.SECONDS)); + // simulate cluster recovery - no more exceptions, run the retry + reset(pubSubConnection); + clearInvocations(pubSubCommands); + when(pubSubConnection.sync()) + .thenReturn(pubSubCommands); - verify(pubSubCommands, atLeastOnce()).nodes(any()); - } - } + assertTrue(resubscribeSuccess.await(1, TimeUnit.SECONDS)); + verify(pubSubCommands, atLeastOnce()).nodes(any()); + } }