Skip to content

Commit

Permalink
Use AIJsonSchemaCreateOptions for excluding, and recognize FromKeyedS…
Browse files Browse the repository at this point in the history
…ervices
  • Loading branch information
stephentoub committed Mar 1, 2025
1 parent 687818c commit 08c4d4b
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Reflection;
using System.Text.Json.Nodes;

#pragma warning disable S1067 // Expressions should not be too complex
Expand All @@ -23,6 +24,17 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
/// </summary>
public Func<AIJsonSchemaCreateContext, JsonNode, JsonNode>? TransformSchemaNode { get; init; }

/// <summary>
/// Gets a callback that is invoked for every parameter in the <see cref="MethodBase"/> provided to
/// <see cref="AIJsonUtilities.CreateFunctionJsonSchema"/> in order to determine whether it should
/// be included in the generated schema.
/// </summary>
/// <remarks>
/// By default, when <see cref="IncludeParameter"/> is <see langword="null"/>,
/// all parameters are included in the generated schema.
/// </remarks>
public Func<ParameterInfo, bool>? IncludeParameter { get; init; }

/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
Expand All @@ -44,19 +56,24 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
public bool RequireAllProperties { get; init; } = true;

/// <inheritdoc/>
public bool Equals(AIJsonSchemaCreateOptions? other)
{
return other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;
}
public bool Equals(AIJsonSchemaCreateOptions? other) =>
other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeParameter == other.IncludeParameter &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;

/// <inheritdoc />
public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);

/// <inheritdoc />
public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode();
public override int GetHashCode() =>
(TransformSchemaNode,
IncludeParameter,
IncludeTypeInEnumSchemas,
DisallowAdditionalProperties,
IncludeSchemaKeyword,
RequireAllProperties).GetHashCode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
using System.Text.Json.Nodes;
using System.Text.Json.Schema;
using System.Text.Json.Serialization;
using System.Threading;
using Microsoft.Shared.Diagnostics;

#pragma warning disable S1121 // Assignments should not be made from within sub-expressions
Expand Down Expand Up @@ -77,17 +76,11 @@ public static JsonElement CreateFunctionJsonSchema(
Throw.ArgumentException(nameof(parameter), "Parameter is missing a name.");
}

if (parameter.ParameterType == typeof(CancellationToken))
if (inferenceOptions.IncludeParameter is { } includeParameter &&
!includeParameter(parameter))
{
// CancellationToken is a special case that, by convention, we don't want to include in the schema.
// Invocations of methods that include a CancellationToken argument should also special-case CancellationToken
// to pass along what relevant token into the method's invocation.
continue;
}

if (parameter.GetCustomAttribute<SkipJsonFunctionSchemaParameterAttribute>(inherit: true) is not null)
{
// Skip anything explicitly requested to not be included in the schema.
// Skip parameters that should not be included in the schema.
// By default, all parameters are included.
continue;
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,10 @@ namespace Microsoft.Extensions.AI;

/// <summary>Indicates that a parameter to an <see cref="AIFunction"/> should be sourced from an associated <see cref="IServiceProvider"/>.</summary>
[AttributeUsage(AttributeTargets.Parameter)]
public sealed class FromServiceProviderAttribute : SkipJsonFunctionSchemaParameterAttribute
public sealed class FromServicesAttribute : Attribute
{
/// <summary>Initializes a new instance of the <see cref="FromServiceProviderAttribute"/> class.</summary>
/// <param name="serviceKey">Optional key to use when resolving the service.</param>
public FromServiceProviderAttribute(object? serviceKey = null)
/// <summary>Initializes a new instance of the <see cref="FromServicesAttribute"/> class.</summary>
public FromServicesAttribute()
{
ServiceKey = serviceKey;
}

/// <summary>Gets the key to use when resolving the service.</summary>
public object? ServiceKey { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Shared.Diagnostics;
Expand Down Expand Up @@ -62,7 +63,7 @@ public partial class FunctionInvokingChatClient : DelegatingChatClient
public FunctionInvokingChatClient(IChatClient innerClient, ILogger? logger = null, IServiceProvider? services = null)
: base(innerClient)
{
_logger = logger ?? (ILogger?)services?.GetService(typeof(ILogger<FunctionInvokingChatClient>)) ?? NullLogger.Instance;
_logger = logger ?? (ILogger?)services?.GetService<ILogger<FunctionInvokingChatClient>>() ?? NullLogger.Instance;
_activitySource = innerClient.GetService<ActivitySource>();
Services = services;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,32 @@ public static ReflectionAIFunctionDescriptor GetOrCreate(MethodInfo method, AIFu

private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions serializerOptions)
{
AIJsonSchemaCreateOptions schemaOptions = new()
{
// This needs to be kept in sync with the shape of AIJsonSchemaCreateOptions.
TransformSchemaNode = key.SchemaOptions.TransformSchemaNode,
IncludeParameter = parameterInfo =>
{
// Explicitly exclude from the schema CancellationToken parameters as well
// as those annotated as [FromServices] or [FromKeyedServices]. These will be satisfied
// from sources other than arguments to InvokeAsync.
if (parameterInfo.ParameterType == typeof(CancellationToken) ||
parameterInfo.GetCustomAttribute<FromServicesAttribute>(inherit: true) is not null ||
parameterInfo.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is not null)
{
return false;
}

// For all other parameters, delegate to whatever behavior is specified in the options.
// If none is specified, include the parameter.
return key.SchemaOptions.IncludeParameter?.Invoke(parameterInfo) ?? true;
},
IncludeTypeInEnumSchemas = key.SchemaOptions.IncludeTypeInEnumSchemas,
DisallowAdditionalProperties = key.SchemaOptions.DisallowAdditionalProperties,
IncludeSchemaKeyword = key.SchemaOptions.IncludeSchemaKeyword,
RequireAllProperties = key.SchemaOptions.RequireAllProperties,
};

// Get marshaling delegates for parameters.
ParameterInfo[] parameters = key.Method.GetParameters();
ParameterMarshallers = new Func<IReadOnlyDictionary<string, object?>, CancellationToken, object?>[parameters.Length];
Expand All @@ -269,7 +295,7 @@ private ReflectionAIFunctionDescriptor(DescriptorKey key, JsonSerializerOptions
Name,
Description,
serializerOptions,
key.SchemaOptions);
schemaOptions);
}

public string Name { get; }
Expand Down Expand Up @@ -343,33 +369,36 @@ static bool IsAsyncMethod(MethodInfo method)
}

// For DI-based parameters, try to resolve from the service provider.
if (parameter.GetCustomAttribute<FromServiceProviderAttribute>(inherit: true) is FromServiceProviderAttribute fspAttr)
if (parameter.GetCustomAttribute<FromServicesAttribute>(inherit: true) is { } fsAttr)
{
return (arguments, _) =>
{
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services)
if ((arguments as AIFunctionArguments)?.ServiceProvider is IServiceProvider services &&
services.GetService(parameterType) is object service)
{
if (fspAttr.ServiceKey is object serviceKey)
{
if ((services as IKeyedServiceProvider)?.GetKeyedService(parameterType, serviceKey) is object keyedService)
{
return keyedService;
}
}
else if (services.GetService(parameterType) is object service)
{
return service;
}
return service;
}

// No service could be resolved. Does it have a default value?
if (parameter.HasDefaultValue)
// No service could be resolved. Return a default value if it's optional, otherwise throw.
return parameter.HasDefaultValue ?
parameter.DefaultValue :
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
};
}
else if (parameter.GetCustomAttribute<FromKeyedServicesAttribute>(inherit: true) is { } fksAttr)
{
return (arguments, _) =>
{
if ((arguments as AIFunctionArguments)?.ServiceProvider is IKeyedServiceProvider services &&
services.GetKeyedService(parameterType, fksAttr.Key) is object service)
{
return parameter.DefaultValue;
return service;
}

// It's a required argument, and we couldn't resolve a service. Throw.
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' for parameter '{parameter.Name}'.");
// No service could be resolved. Return a default value if it's optional, otherwise throw.
return parameter.HasDefaultValue ?
parameter.DefaultValue :
throw new InvalidOperationException($"Unable to resolve service of type '{parameterType}' with key '{fksAttr.Key}' for parameter '{parameter.Name}'.");
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ public static void AIJsonSchemaCreateOptions_UsesStructuralEquality()
property.SetValue(options2, transformer);
break;
case null when property.PropertyType == typeof(Func<ParameterInfo, bool>):
Func<ParameterInfo, bool> includeParameter = static (parameter) => true;
property.SetValue(options1, includeParameter);
property.SetValue(options2, includeParameter);
break;
default:
Assert.Fail($"Unexpected property type: {property.PropertyType}");
break;
Expand Down

0 comments on commit 08c4d4b

Please sign in to comment.