diff --git a/changelog/@unreleased/pr-794.v2.yml b/changelog/@unreleased/pr-794.v2.yml new file mode 100644 index 000000000..013c4ed00 --- /dev/null +++ b/changelog/@unreleased/pr-794.v2.yml @@ -0,0 +1,6 @@ +type: feature +feature: + description: Balanced channel now biases towards whichever node has the lowest latency, + which should reduce AWS spend by routing requests within AZ. + links: + - https://github.com/palantir/dialogue/pull/794 diff --git a/dialogue-core/src/main/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannel.java b/dialogue-core/src/main/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannel.java index a0ed1c400..d75f9a15c 100644 --- a/dialogue-core/src/main/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannel.java +++ b/dialogue-core/src/main/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannel.java @@ -41,12 +41,14 @@ import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Chooses nodes based on stats about each channel, i.e. how many requests are currently - * being served and also how many failures have been seen in the last few seconds. + * being served, how many failures have been seen in the last few seconds and (optionally) also what the best latency + * to each node is. Use {@link RttSampling#ENABLED} to switch this on. * * This is intended to be a strict improvement over Round Robin and Random Selection which can leave fast servers * underutilized, as it sends the same number to both a slow and fast node. It is *not* appropriate for transactional @@ -60,19 +62,34 @@ final class BalancedNodeSelectionStrategyChannel implements LimitedChannel { private static final Duration FAILURE_MEMORY = Duration.ofSeconds(30); private static final double FAILURE_WEIGHT = 10; + /** + * RTT_WEIGHT determines how sticky we are to the physically nearest node (as measured by RTT). If this is set + * too high, then we may deliver suboptimal perf by sending all requests to a slow node that is physically nearby, + * when there's actually a faster one further away. + * If this is too low, then we may prematurely spill across AZs and pay the $ cost. Keep this lower than + * {@link #FAILURE_WEIGHT} to ensure that a single 5xx makes a nearby node less attractive than a faraway node + * that exhibited zero failures. + */ + private static final double RTT_WEIGHT = 3; + private final ImmutableList channels; private final Random random; private final Ticker clock; + @Nullable + private final RttSampler rttSampler; + BalancedNodeSelectionStrategyChannel( ImmutableList channels, Random random, Ticker ticker, TaggedMetricRegistry taggedMetrics, - String channelName) { + String channelName, + RttSampling samplingEnabled) { Preconditions.checkState(channels.size() >= 2, "At least two channels required"); this.random = random; this.clock = ticker; + this.rttSampler = samplingEnabled == RttSampling.DEFAULT_OFF ? null : new RttSampler(channels, clock); this.channels = IntStream.range(0, channels.size()) .mapToObj(index -> new MutableChannelWithStats( channels.get(index), @@ -84,6 +101,11 @@ final class BalancedNodeSelectionStrategyChannel implements LimitedChannel { log.debug("Initialized", SafeArg.of("count", channels.size()), UnsafeArg.of("channels", channels)); } + enum RttSampling { + DEFAULT_OFF, + ENABLED + } + @Override public Optional> maybeExecute(Endpoint endpoint, Request request) { // pre-shuffling is pretty important here, otherwise when there are no requests in flight, we'd @@ -92,11 +114,15 @@ public Optional> maybeExecute(Endpoint endpoint, Requ // TODO(dfox): P2C optimization when we have high number of nodes to save CPU? // http://www.eecs.harvard.edu/~michaelm/NEWWORK/postscripts/twosurvey.pdf - SortableChannel[] sortedChannels = sortByScore(preShuffled); + SortableChannel[] sortableChannels = computeScores(rttSampler, preShuffled); + Arrays.sort(sortableChannels, BY_SCORE); - for (SortableChannel channel : sortedChannels) { + for (SortableChannel channel : sortableChannels) { Optional> maybe = channel.delegate.maybeExecute(endpoint, request); if (maybe.isPresent()) { + if (rttSampler != null) { + rttSampler.maybeSampleRtts(); + } return maybe; } } @@ -104,13 +130,19 @@ public Optional> maybeExecute(Endpoint endpoint, Requ return Optional.empty(); } - private static SortableChannel[] sortByScore(List preShuffled) { - SortableChannel[] sorted = new SortableChannel[preShuffled.size()]; - for (int i = 0; i < preShuffled.size(); i++) { - sorted[i] = preShuffled.get(i).computeScore(); + private static SortableChannel[] computeScores( + @Nullable RttSampler rttSampler, List chans) { + // if the feature is disabled (i.e. RttSampling.DEFAULT_OFF), then we just consider every host to have a + // rttSpectrum of '0' + float[] rttSpectrums = rttSampler != null ? rttSampler.computeRttSpectrums() : new float[chans.size()]; + + SortableChannel[] snapshotArray = new SortableChannel[chans.size()]; + for (int i = 0; i < chans.size(); i++) { + MutableChannelWithStats channel = chans.get(i); + float rttSpectrum = rttSpectrums[i]; + snapshotArray[i] = channel.computeScore(rttSpectrum); } - Arrays.sort(sorted, BY_SCORE); - return sorted; + return snapshotArray; } /** Returns a new shuffled list, without mutating the input list (which may be immutable). */ @@ -121,8 +153,8 @@ private static List shuffleImmutableList(ImmutableList sourceList, Ran } @VisibleForTesting - IntStream getScores() { - return channels.stream().mapToInt(c -> c.computeScore().score); + IntStream getScoresForTesting() { + return Arrays.stream(computeScores(rttSampler, channels)).mapToInt(SortableChannel::getScore); } @Override @@ -189,30 +221,32 @@ public Optional> maybeExecute(Endpoint endpoint, Requ return maybe; } - SortableChannel computeScore() { + SortableChannel computeScore(float rttSpectrum) { int requestsInflight = inflight.get(); double failureReservoir = recentFailuresReservoir.get(); // it's important that scores are integers because if we kept the full double precision, then a single 4xx // would end up influencing host selection long beyond its intended lifespan in the absence of other data. - int score = requestsInflight + Ints.saturatedCast(Math.round(failureReservoir)); + int score = requestsInflight + + Ints.saturatedCast(Math.round(failureReservoir)) + + Ints.saturatedCast(Math.round(rttSpectrum * RTT_WEIGHT)); - observability.debugLogComputedScore(requestsInflight, failureReservoir, score); + observability.debugLogComputedScore(requestsInflight, failureReservoir, rttSpectrum, score); return new SortableChannel(score, this); } @Override public String toString() { - return "MutableChannelWithStats{score=" + computeScore().score + return "MutableChannelWithStats{" + + "delegate=" + delegate + ", inflight=" + inflight + ", recentFailures=" + recentFailuresReservoir - + ", delegate=" + delegate + '}'; } } /** - * A dedicated immutable class ensures safe sorting, as otherwise there's a risk that the inflight AtomicInteger + * A dedicated value class ensures safe sorting, as otherwise there's a risk that the inflight AtomicInteger * might change mid-sort, leading to undefined behaviour. */ private static final class SortableChannel { @@ -252,7 +286,7 @@ private static void registerGauges( DialogueInternalWeakReducingGauge.getOrCreate( taggedMetrics, metricName, - c -> c.computeScore().getScore(), + c -> c.computeScore(0).getScore(), longStream -> { long[] longs = longStream.toArray(); if (log.isInfoEnabled() && longs.length > 1) { @@ -301,7 +335,7 @@ void debugLogStatusFailure(Response response) { } } - void debugLogComputedScore(int inflight, double failures, int score) { + void debugLogComputedScore(int inflight, double failures, float rttSpectrum, int score) { if (log.isDebugEnabled()) { log.debug( "Computed score", @@ -309,7 +343,8 @@ void debugLogComputedScore(int inflight, double failures, int score) { hostIndex, SafeArg.of("score", score), SafeArg.of("inflight", inflight), - SafeArg.of("failures", failures)); + SafeArg.of("failures", failures), + SafeArg.of("rttSpectrum", rttSpectrum)); } } diff --git a/dialogue-core/src/main/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategy.java b/dialogue-core/src/main/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategy.java index 5261d51aa..94218ec43 100644 --- a/dialogue-core/src/main/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategy.java +++ b/dialogue-core/src/main/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategy.java @@ -16,6 +16,7 @@ package com.palantir.dialogue.core; +import com.google.common.annotations.Beta; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -35,6 +36,8 @@ enum DialogueNodeSelectionStrategy { PIN_UNTIL_ERROR, PIN_UNTIL_ERROR_WITHOUT_RESHUFFLE, BALANCED, + @Beta + BALANCED_RTT, UNKNOWN; private static final Logger log = LoggerFactory.getLogger(DialogueNodeSelectionStrategy.class); @@ -45,17 +48,25 @@ static List fromHeader(String header) { Lists.transform(SPLITTER.splitToList(header), DialogueNodeSelectionStrategy::safeValueOf)); } - private static DialogueNodeSelectionStrategy safeValueOf(String value) { - String normalizedValue = value.toUpperCase(); - if (PIN_UNTIL_ERROR.name().equals(normalizedValue)) { - return PIN_UNTIL_ERROR; - } else if (PIN_UNTIL_ERROR_WITHOUT_RESHUFFLE.name().equals(normalizedValue)) { - return PIN_UNTIL_ERROR_WITHOUT_RESHUFFLE; - } else if (BALANCED.name().equals(normalizedValue)) { - return BALANCED; + /** + * We allow server-determined headers to access some incubating dialogue-specific strategies (e.g. BALANCED_RTT) + * which users can't normally configure. + */ + private static DialogueNodeSelectionStrategy safeValueOf(String string) { + String uppercaseString = string.toUpperCase(); + + switch (uppercaseString) { + case "PIN_UNTIL_ERROR": + return PIN_UNTIL_ERROR; + case "PIN_UNTIL_ERROR_WITHOUT_RESHUFFLE": + return PIN_UNTIL_ERROR_WITHOUT_RESHUFFLE; + case "BALANCED": + return BALANCED; + case "BALANCED_RTT": + return BALANCED_RTT; } - log.info("Received unknown selection strategy {}", SafeArg.of("strategy", value)); + log.info("Received unknown selection strategy {}", SafeArg.of("strategy", string)); return UNKNOWN; } diff --git a/dialogue-core/src/main/java/com/palantir/dialogue/core/NodeSelectionStrategyChannel.java b/dialogue-core/src/main/java/com/palantir/dialogue/core/NodeSelectionStrategyChannel.java index 461e56c39..aad2ee31c 100644 --- a/dialogue-core/src/main/java/com/palantir/dialogue/core/NodeSelectionStrategyChannel.java +++ b/dialogue-core/src/main/java/com/palantir/dialogue/core/NodeSelectionStrategyChannel.java @@ -24,6 +24,7 @@ import com.palantir.dialogue.Endpoint; import com.palantir.dialogue.Request; import com.palantir.dialogue.Response; +import com.palantir.dialogue.core.BalancedNodeSelectionStrategyChannel.RttSampling; import com.palantir.logsafe.SafeArg; import com.palantir.logsafe.exceptions.SafeRuntimeException; import com.palantir.tritium.metrics.registry.TaggedMetricRegistry; @@ -140,7 +141,13 @@ private NodeSelectionChannel createNodeSelectionChannel( // When people ask for 'ROUND_ROBIN', they usually just want something to load balance better. // We used to have a naive RoundRobinChannel, then tried RandomSelection and now use this heuristic: return channelBuilder - .channel(new BalancedNodeSelectionStrategyChannel(channels, random, tick, metrics, channelName)) + .channel(new BalancedNodeSelectionStrategyChannel( + channels, random, tick, metrics, channelName, RttSampling.DEFAULT_OFF)) + .build(); + case BALANCED_RTT: + return channelBuilder + .channel(new BalancedNodeSelectionStrategyChannel( + channels, random, tick, metrics, channelName, RttSampling.ENABLED)) .build(); case UNKNOWN: } diff --git a/dialogue-core/src/main/java/com/palantir/dialogue/core/RttSampler.java b/dialogue-core/src/main/java/com/palantir/dialogue/core/RttSampler.java new file mode 100644 index 000000000..237c985b2 --- /dev/null +++ b/dialogue-core/src/main/java/com/palantir/dialogue/core/RttSampler.java @@ -0,0 +1,281 @@ +/* + * (c) Copyright 2020 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.dialogue.core; + +import com.github.benmanes.caffeine.cache.Ticker; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.palantir.conjure.java.api.config.service.UserAgent; +import com.palantir.conjure.java.api.config.service.UserAgent.Agent; +import com.palantir.conjure.java.api.config.service.UserAgents; +import com.palantir.dialogue.Endpoint; +import com.palantir.dialogue.HttpMethod; +import com.palantir.dialogue.Request; +import com.palantir.dialogue.UrlBuilder; +import com.palantir.logsafe.Preconditions; +import com.palantir.logsafe.SafeArg; +import java.io.Closeable; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import javax.annotation.concurrent.ThreadSafe; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class RttSampler { + private static final Logger log = LoggerFactory.getLogger(RttSampler.class); + private static final String USER_AGENT = + UserAgents.format(UserAgent.of(Agent.of(RttEndpoint.INSTANCE.serviceName(), RttEndpoint.INSTANCE.version())) + .addAgent(UserAgentEndpointChannel.DIALOGUE_AGENT)); + + private final ImmutableList channels; + private final RttMeasurement[] rtts; + private final RttMeasurementRateLimiter rateLimiter; + private final Ticker clock; + + RttSampler(ImmutableList channels, Ticker clock) { + this.channels = channels; + this.rateLimiter = new RttMeasurementRateLimiter(clock); + this.clock = clock; + this.rtts = IntStream.range(0, channels.size()) + .mapToObj(_i -> new RttMeasurement()) + .toArray(RttMeasurement[]::new); + } + + /** + * Latency (rtt) is measured in nanos, which is a tricky unit to include in our 'score' because adding + * it would dominate all the other data (which has the unit of 'num requests'). To avoid the need for a + * conversion fudge-factor, we instead figure out where each rtt lies on the spectrum from bestRttNanos + * to worstRttNanos, with 0 being best and 1 being worst. This ensures that if several nodes are all + * within the same AZ and can return in ~1 ms but others return in ~5ms, the 1ms nodes will all have + * a similar rttScore (near zero). Note, this can only be computed when we have all the snapshots in + * front of us. + */ + float[] computeRttSpectrums() { + long bestRttNanos = Long.MAX_VALUE; + long worstRttNanos = 0; + + // first we take a snapshot of all channels' RTT + OptionalLong[] snapshots = new OptionalLong[rtts.length]; + for (int i = 0; i < rtts.length; i++) { + OptionalLong rtt = rtts[i].getRttNanos(); + snapshots[i] = rtt; + + if (rtt.isPresent()) { + bestRttNanos = Math.min(bestRttNanos, rtt.getAsLong()); + worstRttNanos = Math.max(worstRttNanos, rtt.getAsLong()); + } + } + + // given the best & worst values, we can then compute the spectrums + float[] spectrums = new float[rtts.length]; + long rttRange = worstRttNanos - bestRttNanos; + if (rttRange <= 0) { + return spectrums; + } + + for (int i = 0; i < channels.size(); i++) { + OptionalLong rtt = snapshots[i]; + float rttSpectrum = rtt.isPresent() ? ((float) rtt.getAsLong() - bestRttNanos) / rttRange : 0; + Preconditions.checkState( + 0 <= rttSpectrum && rttSpectrum <= 1, + "rttSpectrum must be between 0 and 1", + SafeArg.of("value", rttSpectrum), + SafeArg.of("hostIndex", i)); + spectrums[i] = rttSpectrum; + } + + return spectrums; + } + + /** + * Non-blocking - should return pretty much instantly. + */ + void maybeSampleRtts() { + Optional maybePermit = rateLimiter.tryAcquire(); + if (!maybePermit.isPresent()) { + return; + } + + Request rttRequest = Request.builder() + // necessary as we've already gone through the UserAgentEndpointChannel + .putHeaderParams("user-agent", USER_AGENT) + .build(); + + List> futures = IntStream.range(0, channels.size()) + .mapToObj(i -> { + long before = clock.read(); + return channels.get(i) + .maybeExecute(RttEndpoint.INSTANCE, rttRequest) + .map(future -> Futures.transform( + future, + response -> { + long durationNanos = clock.read() - before; + rtts[i].addMeasurement(durationNanos); + response.close(); + return durationNanos; + }, + MoreExecutors.directExecutor())) + .orElseGet(() -> Futures.immediateFuture(Long.MAX_VALUE)); + }) + .collect(ImmutableList.toImmutableList()); + + DialogueFutures.addDirectCallback(Futures.allAsList(futures), new FutureCallback>() { + @Override + public void onSuccess(List result) { + maybePermit.get().close(); + + if (log.isDebugEnabled()) { + List millis = + result.stream().map(TimeUnit.NANOSECONDS::toMillis).collect(Collectors.toList()); + long[] best = Arrays.stream(rtts) + .mapToLong(rtt -> rtt.getRttNanos().orElse(Long.MAX_VALUE)) + .toArray(); + log.debug( + "RTTs {} {} {}", + SafeArg.of("nanos", result), + SafeArg.of("millis", millis), + SafeArg.of("best", Arrays.toString(best))); + } + } + + @Override + public void onFailure(Throwable throwable) { + maybePermit.get().close(); + log.info("Failed to sample RTT for channels", throwable); + } + }); + } + + @VisibleForTesting + enum RttEndpoint implements Endpoint { + INSTANCE; + + @Override + public void renderPath(Map _params, UrlBuilder _url) {} + + @Override + public HttpMethod httpMethod() { + return HttpMethod.OPTIONS; + } + + @Override + public String serviceName() { + return "RttSampler"; + } + + @Override + public String endpointName() { + return "rtt"; + } + + @Override + public String version() { + return UserAgentEndpointChannel.dialogueVersion(); + } + } + + /** + * Always returns the *minimum* value from the last few samples, so that we exclude slow calls that might include + * TLS handshakes. + */ + @VisibleForTesting + @ThreadSafe + static final class RttMeasurement { + private static final int NUM_MEASUREMENTS = 5; + + private final long[] samples; + private volatile long bestRttNanos = Long.MAX_VALUE; + + RttMeasurement() { + samples = new long[NUM_MEASUREMENTS]; + Arrays.fill(samples, Long.MAX_VALUE); + } + + public OptionalLong getRttNanos() { + return bestRttNanos == Long.MAX_VALUE ? OptionalLong.empty() : OptionalLong.of(bestRttNanos); + } + + synchronized void addMeasurement(long newMeasurement) { + Preconditions.checkArgument(newMeasurement > 0, "Must be greater than zero"); + Preconditions.checkArgument(newMeasurement < Long.MAX_VALUE, "Must be less than MAX_VALUE"); + + if (samples[0] == Long.MAX_VALUE) { + Arrays.fill(samples, newMeasurement); + bestRttNanos = newMeasurement; + } else { + System.arraycopy(samples, 1, samples, 0, NUM_MEASUREMENTS - 1); + samples[NUM_MEASUREMENTS - 1] = newMeasurement; + bestRttNanos = Arrays.stream(samples).min().getAsLong(); + } + } + + @Override + public String toString() { + return "RttMeasurement{" + "bestRttNanos=" + bestRttNanos + ", samples=" + Arrays.toString(samples) + '}'; + } + } + + private static final class RttMeasurementRateLimiter { + private static final long BETWEEN_SAMPLES = Duration.ofSeconds(1).toNanos(); + + private final Ticker clock; + private final AtomicBoolean currentlySampling = new AtomicBoolean(false); + private volatile long lastMeasured = 0; + + @SuppressWarnings("UnnecessaryLambda") // just let me avoid allocations + private final RttMeasurementPermit finishedSampling = () -> currentlySampling.set(false); + + private RttMeasurementRateLimiter(Ticker clock) { + this.clock = clock; + } + + /** + * The RttSamplePermit ensures that if a server black-holes one of our OPTIONS requests, we don't kick off + * more and more and more requests and eventually exhaust a threadpool. Permit is released in a future callback. + */ + Optional tryAcquire() { + if (lastMeasured + BETWEEN_SAMPLES > clock.read()) { + return Optional.empty(); + } + + if (!currentlySampling.get() && currentlySampling.compareAndSet(false, true)) { + lastMeasured = clock.read(); + return Optional.of(finishedSampling); + } else { + log.warn("Wanted to sample RTTs but an existing sample was still in progress"); + return Optional.empty(); + } + } + } + + private interface RttMeasurementPermit extends Closeable { + @Override + void close(); + } +} diff --git a/dialogue-core/src/main/java/com/palantir/dialogue/core/UserAgentEndpointChannel.java b/dialogue-core/src/main/java/com/palantir/dialogue/core/UserAgentEndpointChannel.java index 0a116dd79..46a5a5f88 100644 --- a/dialogue-core/src/main/java/com/palantir/dialogue/core/UserAgentEndpointChannel.java +++ b/dialogue-core/src/main/java/com/palantir/dialogue/core/UserAgentEndpointChannel.java @@ -31,7 +31,7 @@ * {@link Endpoint}'s target service and endpoint. */ final class UserAgentEndpointChannel implements EndpointChannel { - private static final UserAgent.Agent DIALOGUE_AGENT = extractDialogueAgent(); + static final UserAgent.Agent DIALOGUE_AGENT = extractDialogueAgent(); private final EndpointChannel delegate; private final String userAgent; @@ -62,8 +62,13 @@ private static UserAgent augmentUserAgent(UserAgent baseAgent, Endpoint endpoint } private static UserAgent.Agent extractDialogueAgent() { + String version = dialogueVersion(); + return UserAgent.Agent.of("dialogue", version); + } + + static String dialogueVersion() { String maybeDialogueVersion = Channel.class.getPackage().getImplementationVersion(); - return UserAgent.Agent.of("dialogue", maybeDialogueVersion != null ? maybeDialogueVersion : "0.0.0"); + return maybeDialogueVersion != null ? maybeDialogueVersion : "0.0.0"; } @Override diff --git a/dialogue-core/src/test/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannelTest.java b/dialogue-core/src/test/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannelTest.java index 9086c290c..e645978cc 100644 --- a/dialogue-core/src/test/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannelTest.java +++ b/dialogue-core/src/test/java/com/palantir/dialogue/core/BalancedNodeSelectionStrategyChannelTest.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -32,6 +33,7 @@ import com.palantir.dialogue.Response; import com.palantir.dialogue.TestEndpoint; import com.palantir.dialogue.TestResponse; +import com.palantir.dialogue.core.BalancedNodeSelectionStrategyChannel.RttSampling; import com.palantir.tritium.metrics.registry.DefaultTaggedMetricRegistry; import java.time.Duration; import java.util.Optional; @@ -57,6 +59,7 @@ class BalancedNodeSelectionStrategyChannelTest { private Endpoint endpoint = TestEndpoint.GET; private BalancedNodeSelectionStrategyChannel channel; + private BalancedNodeSelectionStrategyChannel rttChannel; @Mock Ticker clock; @@ -64,7 +67,19 @@ class BalancedNodeSelectionStrategyChannelTest { @BeforeEach public void before() { channel = new BalancedNodeSelectionStrategyChannel( - ImmutableList.of(chan1, chan2), random, clock, new DefaultTaggedMetricRegistry(), "channelName"); + ImmutableList.of(chan1, chan2), + random, + clock, + new DefaultTaggedMetricRegistry(), + "channelName", + RttSampling.DEFAULT_OFF); + rttChannel = new BalancedNodeSelectionStrategyChannel( + ImmutableList.of(chan1, chan2), + random, + clock, + new DefaultTaggedMetricRegistry(), + "channelName", + RttSampling.ENABLED); } @Test @@ -76,8 +91,8 @@ void when_one_channel_is_in_use_prefer_the_other() { for (int i = 0; i < 200; i++) { channel.maybeExecute(endpoint, request); } - verify(chan1, times(199)).maybeExecute(any(), any()); - verify(chan2, times(1)).maybeExecute(any(), any()); + verify(chan1, times(199)).maybeExecute(eq(endpoint), any()); + verify(chan2, times(1)).maybeExecute(eq(endpoint), any()); } @Test @@ -88,8 +103,8 @@ void when_both_channels_are_free_we_get_roughly_fair_tiebreaking() { for (int i = 0; i < 200; i++) { channel.maybeExecute(endpoint, request); } - verify(chan1, times(99)).maybeExecute(any(), any()); - verify(chan2, times(101)).maybeExecute(any(), any()); + verify(chan1, times(99)).maybeExecute(eq(endpoint), any()); + verify(chan2, times(101)).maybeExecute(eq(endpoint), any()); } @Test @@ -98,8 +113,8 @@ void when_channels_refuse_try_all_then_give_up() { when(chan2.maybeExecute(any(), any())).thenReturn(Optional.empty()); assertThat(channel.maybeExecute(endpoint, request)).isNotPresent(); - verify(chan1, times(1)).maybeExecute(any(), any()); - verify(chan2, times(1)).maybeExecute(any(), any()); + verify(chan1, times(1)).maybeExecute(eq(endpoint), any()); + verify(chan2, times(1)).maybeExecute(eq(endpoint), any()); } @Test @@ -111,13 +126,13 @@ void a_single_4xx_doesnt_move_the_needle() { clock.read() < start + Duration.ofSeconds(10).toNanos(); incrementClockBy(Duration.ofMillis(50))) { channel.maybeExecute(endpoint, request); - assertThat(channel.getScores()) + assertThat(channel.getScoresForTesting()) .describedAs("A single 400 at the beginning isn't enough to impact scores", channel) .containsExactly(0, 0); } - verify(chan1, times(99)).maybeExecute(any(), any()); - verify(chan2, times(101)).maybeExecute(any(), any()); + verify(chan1, times(99)).maybeExecute(eq(endpoint), any()); + verify(chan2, times(101)).maybeExecute(eq(endpoint), any()); } @Test @@ -126,26 +141,96 @@ void constant_4xxs_do_eventually_move_the_needle_but_we_go_back_to_fair_distribu when(chan2.maybeExecute(any(), any())).thenReturn(http(200)); for (int i = 0; i < 11; i++) { - channel.maybeExecute(endpoint, request); - assertThat(channel.getScores()) - .describedAs("%s %s: Scores not affected yet %s", i, Duration.ofNanos(clock.read()), channel) + rttChannel.maybeExecute(endpoint, request); + assertThat(rttChannel.getScoresForTesting()) + .describedAs("%s %s: Scores not affected yet %s", i, Duration.ofNanos(clock.read()), rttChannel) .containsExactly(0, 0); incrementClockBy(Duration.ofMillis(50)); } - channel.maybeExecute(endpoint, request); - assertThat(channel.getScores()) - .describedAs("%s: Constant 4xxs did move the needle %s", Duration.ofNanos(clock.read()), channel) + rttChannel.maybeExecute(endpoint, request); + assertThat(rttChannel.getScoresForTesting()) + .describedAs("%s: Constant 4xxs did move the needle %s", Duration.ofNanos(clock.read()), rttChannel) .containsExactly(1, 0); incrementClockBy(Duration.ofSeconds(5)); - assertThat(channel.getScores()) + assertThat(rttChannel.getScoresForTesting()) .describedAs( "%s: We quickly forget about 4xxs and go back to fair shuffling %s", - Duration.ofNanos(clock.read()), channel) + Duration.ofNanos(clock.read()), rttChannel) .containsExactly(0, 0); } + @Test + void rtt_is_measured_and_can_influence_choices() { + incrementClockBy(Duration.ofHours(1)); + + // when(chan1.maybeExecute(eq(endpoint), any())).thenReturn(http(200)); + when(chan2.maybeExecute(eq(endpoint), any())).thenReturn(http(200)); + + SettableFuture chan1OptionsResponse = SettableFuture.create(); + SettableFuture chan2OptionsResponse = SettableFuture.create(); + RttSampler.RttEndpoint rttEndpoint = RttSampler.RttEndpoint.INSTANCE; + when(chan1.maybeExecute(eq(rttEndpoint), any())).thenReturn(Optional.of(chan1OptionsResponse)); + when(chan2.maybeExecute(eq(rttEndpoint), any())).thenReturn(Optional.of(chan2OptionsResponse)); + + rttChannel.maybeExecute(endpoint, request); + + incrementClockBy(Duration.ofNanos(123)); + chan1OptionsResponse.set(new TestResponse().code(200)); + + incrementClockBy(Duration.ofNanos(456)); + chan2OptionsResponse.set(new TestResponse().code(200)); + + assertThat(rttChannel.getScoresForTesting()) + .describedAs("The poor latency of channel2 imposes a small constant penalty in the score") + .containsExactly(0, 3); + + for (int i = 0; i < 500; i++) { + incrementClockBy(Duration.ofMillis(10)); + rttChannel.maybeExecute(endpoint, request); + } + // rate limiter ensures a sensible amount of rtt sampling + verify(chan1, times(6)).maybeExecute(eq(rttEndpoint), any()); + verify(chan2, times(6)).maybeExecute(eq(rttEndpoint), any()); + } + + @Test + void when_rtt_measurements_are_limited_dont_freak_out() { + incrementClockBy(Duration.ofHours(1)); + + // when(chan1.maybeExecute(eq(endpoint), any())).thenReturn(http(200)); + when(chan2.maybeExecute(eq(endpoint), any())).thenReturn(http(200)); + + RttSampler.RttEndpoint rttEndpoint = RttSampler.RttEndpoint.INSTANCE; + when(chan1.maybeExecute(eq(rttEndpoint), any())).thenReturn(Optional.empty()); + when(chan2.maybeExecute(eq(rttEndpoint), any())).thenReturn(Optional.empty()); + + rttChannel.maybeExecute(endpoint, request); + + assertThat(channel.getScoresForTesting()).containsExactly(0, 0); + } + + @Test + void when_rtt_measurements_havent_returned_yet_consider_both_far_away() { + incrementClockBy(Duration.ofHours(1)); + // when(chan1.maybeExecute(eq(endpoint), any())).thenReturn(http(200)); + when(chan2.maybeExecute(eq(endpoint), any())).thenReturn(http(200)); + + RttSampler.RttEndpoint rttEndpoint = RttSampler.RttEndpoint.INSTANCE; + when(chan1.maybeExecute(eq(rttEndpoint), any())).thenReturn(Optional.of(SettableFuture.create())); + when(chan2.maybeExecute(eq(rttEndpoint), any())).thenReturn(Optional.of(SettableFuture.create())); + + for (int i = 0; i < 20; i++) { + incrementClockBy(Duration.ofSeconds(5)); + rttChannel.maybeExecute(endpoint, request); + } + + assertThat(rttChannel.getScoresForTesting()).containsExactly(0, 0); + verify(chan1, times(1)).maybeExecute(eq(rttEndpoint), any()); + verify(chan2, times(1)).maybeExecute(eq(rttEndpoint), any()); + } + private static void set200(LimitedChannel chan) { when(chan.maybeExecute(any(), any())).thenReturn(http(200)); } diff --git a/dialogue-core/src/test/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategyTest.java b/dialogue-core/src/test/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategyTest.java index 91b11a2fb..757ebcb9e 100644 --- a/dialogue-core/src/test/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategyTest.java +++ b/dialogue-core/src/test/java/com/palantir/dialogue/core/DialogueNodeSelectionStrategyTest.java @@ -32,6 +32,10 @@ void parses_single_strategy() { void parses_multiple_strategies() { assertThat(DialogueNodeSelectionStrategy.fromHeader("BALANCED, PIN_UNTIL_ERROR")) .containsExactly(DialogueNodeSelectionStrategy.BALANCED, DialogueNodeSelectionStrategy.PIN_UNTIL_ERROR); + assertThat(DialogueNodeSelectionStrategy.fromHeader("BALANCED_RTT, BALANCED")) + .containsExactly(DialogueNodeSelectionStrategy.BALANCED_RTT, DialogueNodeSelectionStrategy.BALANCED); + assertThat(DialogueNodeSelectionStrategy.fromHeader("BALANCED_FUTURE_EXPERIMENT, BALANCED")) + .containsExactly(DialogueNodeSelectionStrategy.UNKNOWN, DialogueNodeSelectionStrategy.BALANCED); } @Test diff --git a/dialogue-core/src/test/java/com/palantir/dialogue/core/RttSamplerTest.java b/dialogue-core/src/test/java/com/palantir/dialogue/core/RttSamplerTest.java new file mode 100644 index 000000000..00ccb1be9 --- /dev/null +++ b/dialogue-core/src/test/java/com/palantir/dialogue/core/RttSamplerTest.java @@ -0,0 +1,43 @@ +/* + * (c) Copyright 2020 Palantir Technologies Inc. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.dialogue.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class RttSamplerTest { + + @Test + void rtt_returns_the_min_of_the_last_5_measurements() { + RttSampler.RttMeasurement rtt = new RttSampler.RttMeasurement(); + rtt.addMeasurement(3); + assertThat(rtt.getRttNanos()).describedAs("%s", rtt).hasValue(3); + rtt.addMeasurement(1); + rtt.addMeasurement(2); + assertThat(rtt.getRttNanos()).describedAs("%s", rtt).hasValue(1); + + rtt.addMeasurement(500); + assertThat(rtt.getRttNanos()).describedAs("%s", rtt).hasValue(1); + rtt.addMeasurement(500); + rtt.addMeasurement(500); + rtt.addMeasurement(500); + assertThat(rtt.getRttNanos()).describedAs("%s", rtt).hasValue(2); + rtt.addMeasurement(500); + assertThat(rtt.getRttNanos()).describedAs("%s", rtt).hasValue(500); + } +} diff --git a/dialogue-jmh/src/main/java/com/palantir/dialogue/core/NodeSelectionBenchmark.java b/dialogue-jmh/src/main/java/com/palantir/dialogue/core/NodeSelectionBenchmark.java index 8f3b0076a..34c4f9176 100644 --- a/dialogue-jmh/src/main/java/com/palantir/dialogue/core/NodeSelectionBenchmark.java +++ b/dialogue-jmh/src/main/java/com/palantir/dialogue/core/NodeSelectionBenchmark.java @@ -26,6 +26,7 @@ import com.palantir.dialogue.Response; import com.palantir.dialogue.TestEndpoint; import com.palantir.dialogue.TestResponse; +import com.palantir.dialogue.core.BalancedNodeSelectionStrategyChannel.RttSampling; import com.palantir.logsafe.exceptions.SafeIllegalArgumentException; import com.palantir.random.SafeThreadLocalRandom; import com.palantir.tritium.metrics.registry.DefaultTaggedMetricRegistry; @@ -123,8 +124,8 @@ public void before() { "channelName"); break; case ROUND_ROBIN: - channel = - new BalancedNodeSelectionStrategyChannel(channels, random, ticker, metrics, "channelName"); + channel = new BalancedNodeSelectionStrategyChannel( + channels, random, ticker, metrics, "channelName", RttSampling.DEFAULT_OFF); break; default: throw new SafeIllegalArgumentException("Unsupported"); diff --git a/simulation/src/main/java/com/palantir/dialogue/core/SimulationServer.java b/simulation/src/main/java/com/palantir/dialogue/core/SimulationServer.java index f7fedf8bf..452b53f3c 100644 --- a/simulation/src/main/java/com/palantir/dialogue/core/SimulationServer.java +++ b/simulation/src/main/java/com/palantir/dialogue/core/SimulationServer.java @@ -25,6 +25,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.palantir.dialogue.Channel; import com.palantir.dialogue.Endpoint; +import com.palantir.dialogue.HttpMethod; import com.palantir.dialogue.Request; import com.palantir.dialogue.Response; import com.palantir.dialogue.TestResponse; @@ -74,6 +75,10 @@ public static Builder builder() { @Override public ListenableFuture execute(Endpoint endpoint, Request request) { + if (endpoint.httpMethod() == HttpMethod.OPTIONS) { + return Futures.immediateFuture(new TestResponse().code(204)); + } + Meter perEndpointRequests = MetricNames.requestMeter(simulation.taggedMetrics(), serverName, endpoint); activeRequests.inc(); diff --git a/simulation/src/test/resources/all_nodes_500[CONCURRENCY_LIMITER_ROUND_ROBIN].png b/simulation/src/test/resources/all_nodes_500[CONCURRENCY_LIMITER_ROUND_ROBIN].png index 870cd83ac..d8a479f77 100644 --- a/simulation/src/test/resources/all_nodes_500[CONCURRENCY_LIMITER_ROUND_ROBIN].png +++ b/simulation/src/test/resources/all_nodes_500[CONCURRENCY_LIMITER_ROUND_ROBIN].png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5acd44d3f56e2a84023c8cda0f6f1da0e0848dd9c3b2b7bd0df936f6166cd27b -size 113221 +oid sha256:17e4b03f7618600213e834e93b52bc594fa9e96f5cf6220030e284d6250a6176 +size 113249 diff --git a/simulation/src/test/resources/one_big_spike[CONCURRENCY_LIMITER_ROUND_ROBIN].png b/simulation/src/test/resources/one_big_spike[CONCURRENCY_LIMITER_ROUND_ROBIN].png index a7e4baf1c..0972c02e4 100644 --- a/simulation/src/test/resources/one_big_spike[CONCURRENCY_LIMITER_ROUND_ROBIN].png +++ b/simulation/src/test/resources/one_big_spike[CONCURRENCY_LIMITER_ROUND_ROBIN].png @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2863197850e4d4573230ae2b4420e31baadf10421546cba011a3d271ca63834c -size 78310 +oid sha256:76b09c3a5e31760626407eb8f6c1f50b2091a4d2ca7efe5e7d7a84cd099f7601 +size 83448