Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Feb 26, 2025
1 parent 83268b1 commit 0f43eed
Showing 1 changed file with 25 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ static bool IsAsyncMethod(MethodInfo method)
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
Type parameterType = parameter.ParameterType;
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
FromServiceProviderAttribute? fspAttr = parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true);

// For CancellationToken parameters, we always bind to the token passed directly to InvokeAsync.
if (parameterType == typeof(CancellationToken))
Expand All @@ -343,28 +342,40 @@ static bool IsAsyncMethod(MethodInfo method)
cancellationToken;
}

// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
return (arguments, _) =>
// For DI-based parameters, try to resolve from the service provider.
if (parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true) is FromServiceProviderAttribute fspAttr)
{
// If the parameter is [FromServiceProvider], try to satisfy it from the service provider
// provided via arguments.
if (fspAttr is not null &&
(arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
return (arguments, _) =>
{
if (fspAttr.ServiceKey is object serviceKey)
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
{
if (services is IKeyedServiceProvider ksp &&
ksp.GetKeyedService(parameterType, serviceKey) is object keyedService)
if (fspAttr.ServiceKey is object serviceKey)
{
if ((services as IKeyedServiceProvider)?.GetKeyedService(parameterType, serviceKey) is object keyedService)
{
return keyedService;
}
}
else if (services.GetService(parameterType) is object service)
{
return keyedService;
return service;
}
}
else if (services.GetService(parameterType) is object service)

// No service could be resolved. Does it have a default value?
if (parameter.HasDefaultValue)
{
return service;
return parameter.DefaultValue;
}
}

// It's a required argument, and we couldn't resolve a service. Throw.
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
};
}

// For all other parameters, create a marshaller that tries to extract the value from the arguments dictionary.
return (arguments, _) =>
{
// If the parameter has an argument specified in the dictionary, return that argument.
if (arguments.TryGetValue(parameter.Name, out object? value))
{
Expand Down

0 comments on commit 0f43eed

Please sign in to comment.