Skip to content

Commit

Permalink
Update RoutedBoltConnectionProvider (#1582)
Browse files Browse the repository at this point in the history
  • Loading branch information
injectives authored Nov 12, 2024
1 parent c5b2868 commit 5a7afbb
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ public CompletionStage<Void> forceClose(String reason) {

@Override
public CompletionStage<Void> close() {
provider.decreaseCount(serverAddress());
provider.decrementInUseCount(serverAddress());
return delegate.close();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.neo4j.driver.internal.bolt.routedimpl;

import static java.lang.String.format;
import static org.neo4j.driver.internal.bolt.routedimpl.util.LockUtil.executeWithLock;

import java.time.Clock;
import java.util.ArrayList;
Expand All @@ -30,7 +29,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -71,7 +69,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
"Failed to obtain a connection towards address %s, will try other addresses if available. Complete failure is reported separately from this entry.";
private final LoggingProvider logging;
private final System.Logger log;
private final ReentrantLock lock = new ReentrantLock();
private final Supplier<BoltConnectionProvider> boltConnectionProviderSupplier;

private final Map<BoltServerAddress, BoltConnectionProvider> addressToProvider = new HashMap<>();
Expand All @@ -85,8 +82,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
private Rediscovery rediscovery;
private RoutingTableRegistry registry;

private BoltServerAddress address;

private RoutingContext routingContext;
private BoltAgent boltAgent;
private String userAgent;
Expand All @@ -107,28 +102,21 @@ public RoutedBoltConnectionProvider(
this.resolver = Objects.requireNonNull(resolver);
this.logging = Objects.requireNonNull(logging);
this.log = logging.getLog(getClass());
this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(
(addr) -> {
synchronized (this) {
return addressToInUseCount.getOrDefault(address, 0);
}
},
logging);
this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(this::getInUseCount, logging);
this.domainNameResolver = Objects.requireNonNull(domainNameResolver);
this.routingTablePurgeDelayMs = routingTablePurgeDelayMs;
this.rediscovery = rediscovery;
this.clock = Objects.requireNonNull(clock);
}

@Override
public CompletionStage<Void> init(
public synchronized CompletionStage<Void> init(
BoltServerAddress address,
RoutingContext routingContext,
BoltAgent boltAgent,
String userAgent,
int connectTimeoutMillis,
MetricsListener metricsListener) {
this.address = address;
this.routingContext = routingContext;
this.boltAgent = boltAgent;
this.userAgent = userAgent;
Expand All @@ -154,10 +142,12 @@ public CompletionStage<BoltConnection> connect(
BoltProtocolVersion minVersion,
NotificationConfig notificationConfig,
Consumer<DatabaseName> databaseNameConsumer) {
RoutingTableRegistry registry;
synchronized (this) {
if (closeFuture != null) {
return CompletableFuture.failedFuture(new IllegalStateException("Connection provider is closed."));
}
registry = this.registry;
}

var handlerRef = new AtomicReference<RoutingTableHandler>();
Expand Down Expand Up @@ -196,6 +186,10 @@ public CompletionStage<BoltConnection> connect(

@Override
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<String, Value> authMap) {
RoutingTableRegistry registry;
synchronized (this) {
registry = this.registry;
}
return supportsMultiDb(securityPlan, authMap)
.thenCompose(supports -> registry.ensureRoutingTable(
securityPlan,
Expand Down Expand Up @@ -244,7 +238,7 @@ private synchronized void shutdownUnusedProviders(Set<BoltServerAddress> address
while (iterator.hasNext()) {
var entry = iterator.next();
var address = entry.getKey();
if (!addressesToRetain.contains(address) && addressToInUseCount.getOrDefault(address, 0) == 0) {
if (!addressesToRetain.contains(address) && getInUseCount(address) == 0) {
entry.getValue().close();
iterator.remove();
}
Expand All @@ -256,8 +250,12 @@ private CompletionStage<Boolean> detectFeature(
Map<String, Value> authMap,
String baseErrorMessagePrefix,
Function<BoltConnection, Boolean> featureDetectionFunction) {
List<BoltServerAddress> addresses;
Rediscovery rediscovery;
synchronized (this) {
rediscovery = this.rediscovery;
}

List<BoltServerAddress> addresses;
try {
addresses = rediscovery.resolve();
} catch (Throwable error) {
Expand Down Expand Up @@ -390,11 +388,7 @@ private void acquire(
result.completeExceptionally(error);
}
} else {
synchronized (this) {
var inUse = addressToInUseCount.getOrDefault(address, 0);
inUse++;
addressToInUseCount.put(address, inUse);
}
incrementInUseCount(address);
result.complete(connection);
}
});
Expand All @@ -414,43 +408,52 @@ private static List<BoltServerAddress> getAddressesByMode(AccessMode mode, Routi
};
}

synchronized void decreaseCount(BoltServerAddress address) {
var inUse = addressToInUseCount.get(address);
if (inUse != null) {
inUse--;
if (inUse <= 0) {
addressToInUseCount.remove(address);
private synchronized int getInUseCount(BoltServerAddress address) {
return addressToInUseCount.getOrDefault(address, 0);
}

private synchronized void incrementInUseCount(BoltServerAddress address) {
addressToInUseCount.merge(address, 1, Integer::sum);
}

synchronized void decrementInUseCount(BoltServerAddress address) {
addressToInUseCount.compute(address, (ignored, value) -> {
if (value == null) {
return null;
} else {
addressToInUseCount.put(address, inUse);
value--;
return value > 0 ? value : null;
}
}
});
}

@Override
public CompletionStage<Void> close() {
CompletableFuture<Void> closeFuture;
synchronized (this) {
if (this.closeFuture == null) {
var futures = executeWithLock(lock, () -> addressToProvider.values().stream()
.map(BoltConnectionProvider::close)
.map(CompletionStage::toCompletableFuture)
.toArray(CompletableFuture[]::new));
@SuppressWarnings({"rawtypes", "RedundantSuppression"})
var futures = new CompletableFuture[addressToProvider.size()];
var iterator = addressToProvider.values().iterator();
var index = 0;
while (iterator.hasNext()) {
futures[index++] = iterator.next().close().toCompletableFuture();
iterator.remove();
}
this.closeFuture = CompletableFuture.allOf(futures);
}
closeFuture = this.closeFuture;
}
return closeFuture;
}

private BoltConnectionProvider get(BoltServerAddress address) {
return executeWithLock(lock, () -> {
var provider = addressToProvider.get(address);
if (provider == null) {
provider = boltConnectionProviderSupplier.get();
provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
addressToProvider.put(address, provider);
}
return provider;
});
private synchronized BoltConnectionProvider get(BoltServerAddress address) {
var provider = addressToProvider.get(address);
if (provider == null) {
provider = boltConnectionProviderSupplier.get();
provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
addressToProvider.put(address, provider);
}
return provider;
}
}

0 comments on commit 5a7afbb

Please sign in to comment.