Skip to content

Commit

Permalink
fix: Race condition between GrpcWorkerConnection open and agent type …
Browse files Browse the repository at this point in the history
…registration (#5521)

This finishes the fix for the race condition between opening a
GrpcWorkerConnection and registering agent types on that worker. Now,
instead of failing to register, we return from the call (with the
expectation that we will finish registration as we set up the
connection)

Part 1: #5494 
Part 2: #5514

---------

Co-authored-by: Ryan Sweet <[email protected]>
  • Loading branch information
lokitoth and rysweet authored Feb 13, 2025
1 parent 7f0acd7 commit 62954ea
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,21 @@ public async ValueTask<RegisterAgentTypeResponse> RegisterAgentTypeAsync(Registe
{
var clientId = context.RequestHeaders.Get("client-id")?.Value ??
throw new RpcException(new Status(StatusCode.InvalidArgument, "Grpc Client ID is required."));
if (!_workers.TryGetValue(clientId, out var connection))

Func<ValueTask> registerLambda = async () =>
{
throw new RpcException(new Status(StatusCode.InvalidArgument, $"Grpc Worker Connection not found for ClientId {clientId}."));
}
connection.AddSupportedType(request.Type);
_supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection);
if (!_workers.TryGetValue(clientId, out var connection))
{
throw new RpcException(new Status(StatusCode.InvalidArgument, $"Grpc Worker Connection not found for ClientId {clientId}. Retry after you call OpenChannel() first."));
}
connection.AddSupportedType(request.Type);
_supportedAgentTypes.GetOrAdd(request.Type, _ => []).Add(connection);

await _gatewayRegistry.RegisterAgentTypeAsync(request, clientId, _reference).ConfigureAwait(true);
};

await InvokeOrDeferRegistrationAction(clientId, registerLambda).ConfigureAwait(true);

await _gatewayRegistry.RegisterAgentTypeAsync(request, clientId, _reference).ConfigureAwait(true);
return new RegisterAgentTypeResponse { };
}
catch (Exception ex)
Expand All @@ -138,6 +145,8 @@ public async ValueTask<AddSubscriptionResponse> SubscribeAsync(AddSubscriptionRe
{
try
{
// We do not actually need to defer these, since we do not listen to ClientId on this for some reason...
// TODO: Fix this
await _gatewayRegistry.SubscribeAsync(request).ConfigureAwait(true);
return new AddSubscriptionResponse { };
}
Expand All @@ -157,6 +166,8 @@ public async ValueTask<RemoveSubscriptionResponse> UnsubscribeAsync(RemoveSubscr
{
try
{
// We do not need to defer here because we will never have a guid to send to this unless the deferred
// AddSubscription calls were run after a client connection was established.
await _gatewayRegistry.UnsubscribeAsync(request).ConfigureAwait(true);
return new RemoveSubscriptionResponse { };
}
Expand Down Expand Up @@ -216,9 +227,38 @@ internal async Task ConnectToWorkerProcess(IAsyncStreamReader<Message> requestSt
throw new RpcException(new Status(StatusCode.InvalidArgument, "Client ID is required."));
var workerProcess = new GrpcWorkerConnection(this, requestStream, responseStream, context);
_workers.GetOrAdd(clientId, workerProcess);

await this.AttachDanglingRegistrations(clientId).ConfigureAwait(false);

await workerProcess.Connect().ConfigureAwait(false);
}

private ConcurrentDictionary<string, ConcurrentQueue<Func<ValueTask>>> _danglingRequests = new();
private async Task InvokeOrDeferRegistrationAction(string clientId, Func<ValueTask> action)
{
if (_workers.TryGetValue(clientId, out var _))
{
await action().ConfigureAwait(false);
}
else
{
ConcurrentQueue<Func<ValueTask>> danglingRequestQueue = _danglingRequests.GetOrAdd(clientId, _ => new ConcurrentQueue<Func<ValueTask>>());
danglingRequestQueue.Enqueue(action);
}
}

private async Task AttachDanglingRegistrations(string clientId)
{
_logger.LogInformation("Attaching dangling registrations for {ClientId}.", clientId);
if (_danglingRequests.TryRemove(clientId, out var requests))
{
foreach (var request in requests)
{
await request().ConfigureAwait(false);
}
}
}

/// <summary>
/// Handles received messages from a worker connection.
/// </summary>
Expand Down
10 changes: 0 additions & 10 deletions python/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 62954ea

Please sign in to comment.