Skip to content

Commit

Permalink
xds: Per-rpc rewriting of the authority header based on the selected …
Browse files Browse the repository at this point in the history
…route. (#11631)

Implementation of A81.
  • Loading branch information
kannanjgithub authored Oct 30, 2024
1 parent 3562380 commit c167ead
Show file tree
Hide file tree
Showing 28 changed files with 875 additions and 309 deletions.
33 changes: 33 additions & 0 deletions api/src/main/java/io/grpc/LoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ public static final class PickResult {
private final Status status;
// True if the result is created by withDrop()
private final boolean drop;
@Nullable private final String authorityOverride;

private PickResult(
@Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
Expand All @@ -560,6 +561,17 @@ private PickResult(
this.streamTracerFactory = streamTracerFactory;
this.status = checkNotNull(status, "status");
this.drop = drop;
this.authorityOverride = null;
}

private PickResult(
@Nullable Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
Status status, boolean drop, @Nullable String authorityOverride) {
this.subchannel = subchannel;
this.streamTracerFactory = streamTracerFactory;
this.status = checkNotNull(status, "status");
this.drop = drop;
this.authorityOverride = authorityOverride;
}

/**
Expand Down Expand Up @@ -639,6 +651,19 @@ public static PickResult withSubchannel(
false);
}

/**
* Same as {@code withSubchannel(subchannel, streamTracerFactory)} but with an authority name
* to override in the host header.
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656")
public static PickResult withSubchannel(
Subchannel subchannel, @Nullable ClientStreamTracer.Factory streamTracerFactory,
@Nullable String authorityOverride) {
return new PickResult(
checkNotNull(subchannel, "subchannel"), streamTracerFactory, Status.OK,
false, authorityOverride);
}

/**
* Equivalent to {@code withSubchannel(subchannel, null)}.
*
Expand Down Expand Up @@ -682,6 +707,13 @@ public static PickResult withNoResult() {
return NO_RESULT;
}

/** Returns the authority override if any. */
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11656")
@Nullable
public String getAuthorityOverride() {
return authorityOverride;
}

/**
* The Subchannel if this result was created by {@link #withSubchannel withSubchannel()}, or
* null otherwise.
Expand Down Expand Up @@ -736,6 +768,7 @@ public String toString() {
.add("streamTracerFactory", streamTracerFactory)
.add("status", status)
.add("drop", drop)
.add("authority-override", authorityOverride)
.toString();
}

Expand Down
12 changes: 11 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedClientTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,17 @@ public final ClientStream newStream(
}
if (state.lastPicker != null) {
PickResult pickResult = state.lastPicker.pickSubchannel(args);
callOptions = args.getCallOptions();
// User code provided authority takes precedence over the LB provided one.
if (callOptions.getAuthority() == null
&& pickResult.getAuthorityOverride() != null) {
callOptions = callOptions.withAuthority(pickResult.getAuthorityOverride());
}
ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult,
callOptions.isWaitForReady());
if (transport != null) {
return transport.newStream(
args.getMethodDescriptor(), args.getHeaders(), args.getCallOptions(),
args.getMethodDescriptor(), args.getHeaders(), callOptions,
tracers);
}
}
Expand Down Expand Up @@ -281,6 +287,10 @@ final void reprocess(@Nullable SubchannelPicker picker) {
for (final PendingStream stream : toProcess) {
PickResult pickResult = picker.pickSubchannel(stream.args);
CallOptions callOptions = stream.args.getCallOptions();
// User code provided authority takes precedence over the LB provided one.
if (callOptions.getAuthority() == null && pickResult.getAuthorityOverride() != null) {
stream.setAuthority(pickResult.getAuthorityOverride());
}
final ClientTransport transport = GrpcUtil.getTransportFromPickResult(pickResult,
callOptions.isWaitForReady());
if (transport != null) {
Expand Down
1 change: 0 additions & 1 deletion core/src/main/java/io/grpc/internal/DelayedStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ private void delayOrExecute(Runnable runnable) {

@Override
public void setAuthority(final String authority) {
checkState(listener == null, "May only be called before start");
checkNotNull(authority, "authority");
preStartPendingCalls.add(new Runnable() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,43 @@ public void uncaughtException(Thread t, Throwable e) {
verify(transportListener).transportTerminated();
}

@Test
public void reprocess_authorityOverridePresentInCallOptions_authorityOverrideFromLbIsIgnored() {
DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream(
method, headers, callOptions, tracers);
delayedStream.start(mock(ClientStreamListener.class));
SubchannelPicker picker = mock(SubchannelPicker.class);
PickResult pickResult = PickResult.withSubchannel(
mockSubchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);

delayedTransport.reprocess(picker);
fakeExecutor.runDueTasks();

verify(mockRealStream, never()).setAuthority("authority-override-hostname-from-lb");
}

@Test
public void
reprocess_authorityOverrideNotInCallOptions_authorityOverrideFromLbIsSetIntoStream() {
DelayedStream delayedStream = (DelayedStream) delayedTransport.newStream(
method, headers, callOptions.withAuthority(null), tracers);
delayedStream.start(mock(ClientStreamListener.class));
SubchannelPicker picker = mock(SubchannelPicker.class);
PickResult pickResult = PickResult.withSubchannel(
mockSubchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
when(mockRealTransport.newStream(
same(method), same(headers), any(CallOptions.class),
ArgumentMatchers.any()))
.thenReturn(mockRealStream);

delayedTransport.reprocess(picker);
fakeExecutor.runDueTasks();

verify(mockRealStream).setAuthority("authority-override-hostname-from-lb");
}

@Test
public void reprocess_NoPendingStream() {
SubchannelPicker picker = mock(SubchannelPicker.class);
Expand All @@ -525,6 +562,55 @@ public void reprocess_NoPendingStream() {
assertSame(mockRealStream, stream);
}

@Test
public void newStream_assignsTransport_authorityFromCallOptionsSupersedesAuthorityFromLB() {
SubchannelPicker picker = mock(SubchannelPicker.class);
AbstractSubchannel subchannel = mock(AbstractSubchannel.class);
when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel);
PickResult pickResult = PickResult.withSubchannel(
subchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
ArgumentCaptor<CallOptions> callOptionsArgumentCaptor =
ArgumentCaptor.forClass(CallOptions.class);
when(mockRealTransport.newStream(
any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(),
ArgumentMatchers.<ClientStreamTracer[]>any()))
.thenReturn(mockRealStream);
delayedTransport.reprocess(picker);
verifyNoMoreInteractions(picker);
verifyNoMoreInteractions(transportListener);

CallOptions callOptions =
CallOptions.DEFAULT.withAuthority("authority-override-hosstname-from-calloptions");
delayedTransport.newStream(method, headers, callOptions, tracers);
assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo(
"authority-override-hosstname-from-calloptions");
}

@Test
public void newStream_assignsTransport_authorityFromLB() {
SubchannelPicker picker = mock(SubchannelPicker.class);
AbstractSubchannel subchannel = mock(AbstractSubchannel.class);
when(subchannel.getInternalSubchannel()).thenReturn(mockInternalSubchannel);
PickResult pickResult = PickResult.withSubchannel(
subchannel, null, "authority-override-hostname-from-lb");
when(picker.pickSubchannel(any(PickSubchannelArgs.class))).thenReturn(pickResult);
ArgumentCaptor<CallOptions> callOptionsArgumentCaptor =
ArgumentCaptor.forClass(CallOptions.class);
when(mockRealTransport.newStream(
any(MethodDescriptor.class), any(Metadata.class), callOptionsArgumentCaptor.capture(),
ArgumentMatchers.<ClientStreamTracer[]>any()))
.thenReturn(mockRealStream);
delayedTransport.reprocess(picker);
verifyNoMoreInteractions(picker);
verifyNoMoreInteractions(transportListener);

CallOptions callOptions = CallOptions.DEFAULT;
delayedTransport.newStream(method, headers, callOptions, tracers);
assertThat(callOptionsArgumentCaptor.getValue().getAuthority()).isEqualTo(
"authority-override-hostname-from-lb");
}

@Test
public void reprocess_newStreamRacesWithReprocess() throws Exception {
final CyclicBarrier barrier = new CyclicBarrier(2);
Expand Down
6 changes: 0 additions & 6 deletions core/src/test/java/io/grpc/internal/DelayedStreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ public void setStream_setAuthority() {
inOrder.verify(realStream).start(any(ClientStreamListener.class));
}

@Test(expected = IllegalStateException.class)
public void setAuthority_afterStart() {
stream.start(listener);
stream.setAuthority("notgonnawork");
}

@Test(expected = IllegalStateException.class)
public void start_afterStart() {
stream.start(listener);
Expand Down
27 changes: 21 additions & 6 deletions xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.ForwardingClientStreamTracer;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.ObjectPool;
import io.grpc.services.MetricReport;
import io.grpc.util.ForwardingLoadBalancerHelper;
Expand Down Expand Up @@ -231,10 +232,16 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) {
args.getAddresses().get(0).getAttributes());
AtomicReference<ClusterLocality> localityAtomicReference = new AtomicReference<>(
clusterLocality);
Attributes attrs = args.getAttributes().toBuilder()
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference)
.build();
args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build();
Attributes.Builder attrsBuilder = args.getAttributes().toBuilder()
.set(ATTR_CLUSTER_LOCALITY, localityAtomicReference);
if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) {
String hostname = args.getAddresses().get(0).getAttributes()
.get(InternalXdsAttributes.ATTR_ADDRESS_NAME);
if (hostname != null) {
attrsBuilder.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, hostname);
}
}
args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build();
final Subchannel subchannel = delegate().createSubchannel(args);

return new ForwardingSubchannel() {
Expand Down Expand Up @@ -389,7 +396,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
Status.UNAVAILABLE.withDescription("Dropped: " + dropOverload.category()));
}
}
final PickResult result = delegate.pickSubchannel(args);
PickResult result = delegate.pickSubchannel(args);
if (result.getStatus().isOk() && result.getSubchannel() != null) {
if (enableCircuitBreaking) {
if (inFlights.get() >= maxConcurrentRequests) {
Expand All @@ -415,9 +422,17 @@ public PickResult pickSubchannel(PickSubchannelArgs args) {
stats, inFlights, result.getStreamTracerFactory());
ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance()
.newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats));
return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory);
result = PickResult.withSubchannel(result.getSubchannel(),
orcaTracerFactory);
}
}
if (args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY) != null
&& args.getCallOptions().getOption(XdsNameResolver.AUTO_HOST_REWRITE_KEY)) {
result = PickResult.withSubchannel(result.getSubchannel(),
result.getStreamTracerFactory(),
result.getSubchannel().getAttributes().get(
InternalXdsAttributes.ATTR_ADDRESS_NAME));
}
}
return result;
}
Expand Down
10 changes: 9 additions & 1 deletion xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ public void run() {
.set(InternalXdsAttributes.ATTR_LOCALITY_WEIGHT,
localityLbInfo.localityWeight())
.set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight)
.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, endpoint.hostname())
.build();
EquivalentAddressGroup eag = new EquivalentAddressGroup(
endpoint.eag().getAddresses(), attr);
Expand Down Expand Up @@ -567,7 +568,7 @@ void start() {
handleEndpointResolutionError();
return;
}
resolver.start(new NameResolverListener());
resolver.start(new NameResolverListener(dnsHostName));
}

void refresh() {
Expand Down Expand Up @@ -606,6 +607,12 @@ public void run() {
}

private class NameResolverListener extends NameResolver.Listener2 {
private final String dnsHostName;

NameResolverListener(String dnsHostName) {
this.dnsHostName = dnsHostName;
}

@Override
public void onResult(final ResolutionResult resolutionResult) {
class NameResolved implements Runnable {
Expand All @@ -625,6 +632,7 @@ public void run() {
Attributes attr = eag.getAttributes().toBuilder()
.set(InternalXdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY)
.set(InternalXdsAttributes.ATTR_LOCALITY_NAME, localityName)
.set(InternalXdsAttributes.ATTR_ADDRESS_NAME, dnsHostName)
.build();
eag = new EquivalentAddressGroup(eag.getAddresses(), attr);
eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName));
Expand Down
10 changes: 6 additions & 4 deletions xds/src/main/java/io/grpc/xds/Endpoints.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,19 @@ abstract static class LbEndpoint {
// Whether the endpoint is healthy.
abstract boolean isHealthy();

abstract String hostname();

static LbEndpoint create(EquivalentAddressGroup eag, int loadBalancingWeight,
boolean isHealthy) {
return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy);
boolean isHealthy, String hostname) {
return new AutoValue_Endpoints_LbEndpoint(eag, loadBalancingWeight, isHealthy, hostname);
}

// Only for testing.
@VisibleForTesting
static LbEndpoint create(
String address, int port, int loadBalancingWeight, boolean isHealthy) {
String address, int port, int loadBalancingWeight, boolean isHealthy, String hostname) {
return LbEndpoint.create(new EquivalentAddressGroup(new InetSocketAddress(address, port)),
loadBalancingWeight, isHealthy);
loadBalancingWeight, isHealthy, hostname);
}
}

Expand Down
5 changes: 5 additions & 0 deletions xds/src/main/java/io/grpc/xds/InternalXdsAttributes.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ public final class InternalXdsAttributes {
static final Attributes.Key<Long> ATTR_SERVER_WEIGHT =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.serverWeight");

/** Name associated with individual address, if available (e.g., DNS name). */
@EquivalentAddressGroup.Attr
static final Attributes.Key<String> ATTR_ADDRESS_NAME =
Attributes.Key.create("io.grpc.xds.InternalXdsAttributes.addressName");

/**
* Filter chain match for network filters.
*/
Expand Down
Loading

0 comments on commit c167ead

Please sign in to comment.