From 5a7afbb621f2273a05be4f5df6defc5666597e42 Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov <11927660+injectives@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:23:46 +0000 Subject: [PATCH] Update RoutedBoltConnectionProvider (#1582) --- .../bolt/routedimpl/RoutedBoltConnection.java | 2 +- .../RoutedBoltConnectionProvider.java | 89 ++++++++++--------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java index 446862c01..b080c1de7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java @@ -267,7 +267,7 @@ public CompletionStage forceClose(String reason) { @Override public CompletionStage close() { - provider.decreaseCount(serverAddress()); + provider.decrementInUseCount(serverAddress()); return delegate.close(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnectionProvider.java b/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnectionProvider.java index 665017d75..9820b04e3 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnectionProvider.java +++ b/driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnectionProvider.java @@ -17,7 +17,6 @@ package org.neo4j.driver.internal.bolt.routedimpl; import static java.lang.String.format; -import static org.neo4j.driver.internal.bolt.routedimpl.util.LockUtil.executeWithLock; import java.time.Clock; import java.util.ArrayList; @@ -30,7 +29,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -71,7 +69,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider { "Failed to obtain a connection towards address %s, will try other addresses if available. Complete failure is reported separately from this entry."; private final LoggingProvider logging; private final System.Logger log; - private final ReentrantLock lock = new ReentrantLock(); private final Supplier boltConnectionProviderSupplier; private final Map addressToProvider = new HashMap<>(); @@ -85,8 +82,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider { private Rediscovery rediscovery; private RoutingTableRegistry registry; - private BoltServerAddress address; - private RoutingContext routingContext; private BoltAgent boltAgent; private String userAgent; @@ -107,13 +102,7 @@ public RoutedBoltConnectionProvider( this.resolver = Objects.requireNonNull(resolver); this.logging = Objects.requireNonNull(logging); this.log = logging.getLog(getClass()); - this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy( - (addr) -> { - synchronized (this) { - return addressToInUseCount.getOrDefault(address, 0); - } - }, - logging); + this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(this::getInUseCount, logging); this.domainNameResolver = Objects.requireNonNull(domainNameResolver); this.routingTablePurgeDelayMs = routingTablePurgeDelayMs; this.rediscovery = rediscovery; @@ -121,14 +110,13 @@ public RoutedBoltConnectionProvider( } @Override - public CompletionStage init( + public synchronized CompletionStage init( BoltServerAddress address, RoutingContext routingContext, BoltAgent boltAgent, String userAgent, int connectTimeoutMillis, MetricsListener metricsListener) { - this.address = address; this.routingContext = routingContext; this.boltAgent = boltAgent; this.userAgent = userAgent; @@ -154,10 +142,12 @@ public CompletionStage connect( BoltProtocolVersion minVersion, NotificationConfig notificationConfig, Consumer databaseNameConsumer) { + RoutingTableRegistry registry; synchronized (this) { if (closeFuture != null) { return CompletableFuture.failedFuture(new IllegalStateException("Connection provider is closed.")); } + registry = this.registry; } var handlerRef = new AtomicReference(); @@ -196,6 +186,10 @@ public CompletionStage connect( @Override public CompletionStage verifyConnectivity(SecurityPlan securityPlan, Map authMap) { + RoutingTableRegistry registry; + synchronized (this) { + registry = this.registry; + } return supportsMultiDb(securityPlan, authMap) .thenCompose(supports -> registry.ensureRoutingTable( securityPlan, @@ -244,7 +238,7 @@ private synchronized void shutdownUnusedProviders(Set address while (iterator.hasNext()) { var entry = iterator.next(); var address = entry.getKey(); - if (!addressesToRetain.contains(address) && addressToInUseCount.getOrDefault(address, 0) == 0) { + if (!addressesToRetain.contains(address) && getInUseCount(address) == 0) { entry.getValue().close(); iterator.remove(); } @@ -256,8 +250,12 @@ private CompletionStage detectFeature( Map authMap, String baseErrorMessagePrefix, Function featureDetectionFunction) { - List addresses; + Rediscovery rediscovery; + synchronized (this) { + rediscovery = this.rediscovery; + } + List addresses; try { addresses = rediscovery.resolve(); } catch (Throwable error) { @@ -390,11 +388,7 @@ private void acquire( result.completeExceptionally(error); } } else { - synchronized (this) { - var inUse = addressToInUseCount.getOrDefault(address, 0); - inUse++; - addressToInUseCount.put(address, inUse); - } + incrementInUseCount(address); result.complete(connection); } }); @@ -414,16 +408,23 @@ private static List getAddressesByMode(AccessMode mode, Routi }; } - synchronized void decreaseCount(BoltServerAddress address) { - var inUse = addressToInUseCount.get(address); - if (inUse != null) { - inUse--; - if (inUse <= 0) { - addressToInUseCount.remove(address); + private synchronized int getInUseCount(BoltServerAddress address) { + return addressToInUseCount.getOrDefault(address, 0); + } + + private synchronized void incrementInUseCount(BoltServerAddress address) { + addressToInUseCount.merge(address, 1, Integer::sum); + } + + synchronized void decrementInUseCount(BoltServerAddress address) { + addressToInUseCount.compute(address, (ignored, value) -> { + if (value == null) { + return null; } else { - addressToInUseCount.put(address, inUse); + value--; + return value > 0 ? value : null; } - } + }); } @Override @@ -431,10 +432,14 @@ public CompletionStage close() { CompletableFuture closeFuture; synchronized (this) { if (this.closeFuture == null) { - var futures = executeWithLock(lock, () -> addressToProvider.values().stream() - .map(BoltConnectionProvider::close) - .map(CompletionStage::toCompletableFuture) - .toArray(CompletableFuture[]::new)); + @SuppressWarnings({"rawtypes", "RedundantSuppression"}) + var futures = new CompletableFuture[addressToProvider.size()]; + var iterator = addressToProvider.values().iterator(); + var index = 0; + while (iterator.hasNext()) { + futures[index++] = iterator.next().close().toCompletableFuture(); + iterator.remove(); + } this.closeFuture = CompletableFuture.allOf(futures); } closeFuture = this.closeFuture; @@ -442,15 +447,13 @@ public CompletionStage close() { return closeFuture; } - private BoltConnectionProvider get(BoltServerAddress address) { - return executeWithLock(lock, () -> { - var provider = addressToProvider.get(address); - if (provider == null) { - provider = boltConnectionProviderSupplier.get(); - provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener); - addressToProvider.put(address, provider); - } - return provider; - }); + private synchronized BoltConnectionProvider get(BoltServerAddress address) { + var provider = addressToProvider.get(address); + if (provider == null) { + provider = boltConnectionProviderSupplier.get(); + provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener); + addressToProvider.put(address, provider); + } + return provider; } }