Skip to content

Commit

Permalink
Enable AIFunctionFactory to resolve parameters from an IServiceProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Feb 24, 2025
1 parent 8ea523c commit 170d874
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ public static JsonElement CreateFunctionJsonSchema(
continue;
}

if (parameter.GetCustomAttribute<SkipJsonFunctionSchemaParameterAttribute>(inherit: true) is not null)
{
// Skip anything explicitly requested to not be included in the schema.
continue;
}

JsonNode parameterSchema = CreateJsonSchemaCore(
type: parameter.ParameterType,
parameterName: parameter.Name,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;

#pragma warning disable CA1813 // Avoid unsealed attributes

namespace Microsoft.Extensions.AI;

/// <summary>Indicates that a parameter to a method should not be included in a generated JSON schema by <see cref="AIJsonUtilities.CreateFunctionJsonSchema"/>.</summary>
[AttributeUsage(AttributeTargets.Parameter)]
public class SkipJsonFunctionSchemaParameterAttribute : Attribute
{
/// <summary>Initializes a new instance of the <see cref="SkipJsonFunctionSchemaParameterAttribute"/> class.</summary>
public SkipJsonFunctionSchemaParameterAttribute()
{
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using Microsoft.Shared.Diagnostics;

#pragma warning disable CA1710 // Identifiers should have correct suffix

namespace Microsoft.Extensions.AI;

/// <summary>Represents arguments to be used with <see cref="AIFunction.InvokeAsync"/>.</summary>
/// <remarks>
/// <see cref="AIFunction.InvokeAsync"/> may be invoked with arbitary <see cref="IEnumerable{T}"/>
/// implementations. However, some <see cref="AIFunction"/> implementations may dynamically check
/// the type of the arguments, and if it's an <see cref="AIFunctionArguments"/>, use it to access
/// an <see cref="IServiceProvider"/> that's passed in separately from the arguments enumeration.
/// </remarks>
public class AIFunctionArguments : IEnumerable<KeyValuePair<string, object?>>
{
/// <summary>The arguments represented by this instance.</summary>
private readonly IEnumerable<KeyValuePair<string, object?>> _arguments;

/// <summary>Initializes a new instance of the <see cref="AIFunctionArguments"/> class.</summary>
/// <param name="arguments">The arguments represented by this instance.</param>
/// <param name="serviceProvider">Options services associated with these arguments.</param>
public AIFunctionArguments(IEnumerable<KeyValuePair<string, object?>>? arguments, IServiceProvider? serviceProvider = null)
{
_arguments = Throw.IfNull(arguments);
ServiceProvider = serviceProvider;
}

/// <summary>Gets the services associated with these arguments.</summary>
public IServiceProvider? ServiceProvider { get; }

/// <inheritdoc />
public IEnumerator<KeyValuePair<string, object?>> GetEnumerator() => _arguments.GetEnumerator();

/// <inheritdoc />
IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)_arguments).GetEnumerator();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;

namespace Microsoft.Extensions.AI;

/// <summary>Indicates that a parameter to an <see cref="AIFunction"/> should be sourced from an associated <see cref="IServiceProvider"/>.</summary>
[AttributeUsage(AttributeTargets.Parameter)]
public sealed class FromServiceProviderAttribute : SkipJsonFunctionSchemaParameterAttribute
{
/// <summary>Initializes a new instance of the <see cref="FromServiceProviderAttribute"/> class.</summary>
/// <param name="serviceKey">Optional key to use when resolving the service.</param>
public FromServiceProviderAttribute(object? serviceKey = null)
{
ServiceKey = serviceKey;
}

/// <summary>Gets the key to use when resolving the service.</summary>
public object? ServiceKey { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
/// </summary>
/// <param name="innerClient">The underlying <see cref="IChatClient"/>, or the next instance in a chain of clients.</param>
/// <param name="logger">An <see cref="ILogger"/> to use for logging information about function invocation.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null)
/// <param name="services">An optional <see cref="IServiceProvider"/> to use for resolving services required by the <see cref="AIFunction"/> instances being invoked.</param>
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null)
: base(innerClient)
{
_logger = logger ?? NullLogger.Instance;
_logger = logger ?? (ILogger?)services?.GetService(typeof(ILogger<FunctionInvokingChatClient>)) ?? NullLogger.Instance;
_activitySource = innerClient.GetService<ActivitySource>();
Services = services;
}

/// <summary>
Expand All @@ -77,6 +79,9 @@ public static FunctionInvocationContext? CurrentContext
protected set => _currentContext.Value = value;
}

/// <summary>Gets the <see cref="IServiceProvider"/> used for resolving services required by the <see cref="AIFunction"/> instances being invoked.</summary>
public IServiceProvider? Services { get; }

/// <summary>
/// Gets or sets a value indicating whether to handle exceptions that occur during function calls.
/// </summary>
Expand Down Expand Up @@ -687,8 +692,14 @@ FunctionResultContent CreateFunctionResultContent(FunctionInvocationResult resul
object? result = null;
try
{
IEnumerable<KeyValuePair<string, object?>>? arguments = context.CallContent.Arguments;
if (Services is not null)
{
arguments = new AIFunctionArguments(arguments, Services);
}

CurrentContext = context;
result = await context.Function.InvokeAsync(context.CallContent.Arguments, cancellationToken).ConfigureAwait(false);
result = await context.Function.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false);
}
catch (Exception e)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public static ChatClientBuilder UseFunctionInvocation(
{
loggerFactory ??= services.GetService<ILoggerFactory>();

var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)));
var chatClient = new FunctionInvokingChatClient(innerClient, loggerFactory?.CreateLogger(typeof(FunctionInvokingChatClient)), services);
configure?.Invoke(chatClient);
return chatClient;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Collections;
using Microsoft.Shared.Diagnostics;

#pragma warning disable CA1031 // Do not catch general exception types
#pragma warning disable S2302 // "nameof" should be used
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields

namespace Microsoft.Extensions.AI;

/// <summary>Provides factory methods for creating commonly used implementations of <see cref="AIFunction"/>.</summary>
Expand Down Expand Up @@ -325,11 +330,31 @@ static bool IsAsyncMethod(MethodInfo method)
// Resolve the contract used to marshal the value from JSON -- can throw if not supported or not found.
Type parameterType = parameter.ParameterType;
JsonTypeInfo typeInfo = serializerOptions.GetTypeInfo(parameterType);
FromServiceProviderAttribute? fspAttr = parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true);

// Create a marshaller for the parameter. This produces a value for the parameter based on an ordered
// collection of rules.
return (arguments, cancellationToken) =>
{
// If the parameter is [FromServiceProvider], try to satisfy it from the service provider
// provided via arguments.
if (fspAttr is not null &&
(arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
{
if (fspAttr.ServiceKey is object serviceKey)
{
if (services is IKeyedServiceProvider ksp &&
ksp.GetKeyedService(parameterType, serviceKey) is object keyedService)
{
return keyedService;
}
}
else if (services.GetService(parameterType) is object service)
{
return service;
}
}

// If the parameter has an argument specified in the dictionary, return that argument.
if (arguments.TryGetValue(parameter.Name, out object? value))
{
Expand All @@ -345,7 +370,6 @@ static bool IsAsyncMethod(MethodInfo method)

object? MarshallViaJsonRoundtrip(object value)
{
#pragma warning disable CA1031 // Do not catch general exception types
try
{
string json = JsonSerializer.Serialize(value, serializerOptions.GetTypeInfo(value.GetType()));
Expand All @@ -356,7 +380,6 @@ static bool IsAsyncMethod(MethodInfo method)
// Eat any exceptions and fall back to the original value to force a cast exception later on.
return value;
}
#pragma warning restore CA1031
}
}

Expand Down Expand Up @@ -476,9 +499,7 @@ private static MethodInfo GetMethodFromGenericMethodDefinition(Type specializedT
#if NET
return (MethodInfo)specializedType.GetMemberWithSameMetadataDefinitionAs(genericMethodDefinition);
#else
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance;
#pragma warning restore S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
return specializedType.GetMethods(All).First(m => m.MetadataToken == genericMethodDefinition.MetadataToken);
#endif
}
Expand Down

0 comments on commit 170d874

Please sign in to comment.