diff --git a/examples/AI/ConversationalAI/Program.cs b/examples/AI/ConversationalAI/Program.cs
index bd3dc906a..6315db87a 100644
--- a/examples/AI/ConversationalAI/Program.cs
+++ b/examples/AI/ConversationalAI/Program.cs
@@ -3,7 +3,7 @@
var builder = WebApplication.CreateBuilder(args);
-builder.Services.AddDaprAiConversation();
+builder.Services.AddDaprConversationClient();
var app = builder.Build();
diff --git a/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs b/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs
index 902fd82a3..2f049a906 100644
--- a/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs
+++ b/src/Dapr.AI/Conversation/Extensions/DaprAiConversationBuilderExtensions.cs
@@ -26,7 +26,7 @@ public static class DaprAiConversationBuilderExtensions
/// Registers the necessary functionality for the Dapr AI conversation functionality.
///
///
- public static IDaprAiConversationBuilder AddDaprAiConversation(this IServiceCollection services, Action? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton)
+ public static IDaprAiConversationBuilder AddDaprConversationClient(this IServiceCollection services, Action? configure = null, ServiceLifetime lifetime = ServiceLifetime.Singleton)
{
ArgumentNullException.ThrowIfNull(services, nameof(services));
diff --git a/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs b/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs
index 95a8e1e8c..2ee321895 100644
--- a/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs
+++ b/test/Dapr.AI.Test/Conversation/Extensions/DaprAiConversationBuilderExtensionsTest.cs
@@ -13,7 +13,9 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Net.Http;
+using System.Threading.Tasks;
using Dapr.AI.Conversation;
using Dapr.AI.Conversation.Extensions;
using Microsoft.Extensions.Configuration;
@@ -34,7 +36,7 @@ public void AddDaprConversationClient_FromIConfiguration()
var services = new ServiceCollection();
services.AddSingleton(configuration);
- services.AddDaprAiConversation();
+ services.AddDaprConversationClient();
var app = services.BuildServiceProvider();
@@ -45,18 +47,66 @@ public void AddDaprConversationClient_FromIConfiguration()
}
[Fact]
- public void AddDaprAiConversation_WithoutConfigure_ShouldAddServices()
+ public void AddDaprConversationClient_RegistersDaprClientOnlyOnce()
{
var services = new ServiceCollection();
- var builder = services.AddDaprAiConversation();
+
+ var clientBuilder = new Action((sp, builder) =>
+ {
+ builder.UseDaprApiToken("abc");
+ });
+
+ services.AddDaprConversationClient(); //Sets a default API token value of an empty string
+ services.AddDaprConversationClient(clientBuilder); //Sets the API token value
+
+ var serviceProvider = services.BuildServiceProvider();
+ var daprConversationClient = serviceProvider.GetService();
+
+ Assert.NotNull(daprConversationClient!.HttpClient);
+ Assert.False(daprConversationClient.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var _));
+ }
+
+ [Fact]
+ public void AddDaprConversationClient_RegistersUsingDependencyFromIServiceProvider()
+ {
+ var services = new ServiceCollection();
+ services.AddSingleton();
+ services.AddDaprConversationClient((provider, builder) =>
+ {
+ var configProvider = provider.GetRequiredService();
+ var apiToken = configProvider.GetApiTokenValue();
+ builder.UseDaprApiToken(apiToken);
+ });
+
+ var serviceProvider = services.BuildServiceProvider();
+ var client = serviceProvider.GetRequiredService();
+
+ //Validate it's set on the GrpcClient - note that it doesn't get set on the HttpClient
+ Assert.NotNull(client);
+ Assert.NotNull(client.DaprApiToken);
+ Assert.Equal("abcdef", client.DaprApiToken);
+ Assert.NotNull(client.HttpClient);
+
+ if (!client.HttpClient.DefaultRequestHeaders.TryGetValues("dapr-api-token", out var daprApiToken))
+ {
+ Assert.Fail();
+ }
+ Assert.Equal("abcdef", daprApiToken.FirstOrDefault());
+ }
+
+ [Fact]
+ public void AddDaprConversationClient_WithoutConfigure_ShouldAddServices()
+ {
+ var services = new ServiceCollection();
+ var builder = services.AddDaprConversationClient();
Assert.NotNull(builder);
}
[Fact]
- public void AddDaprAiConversation_RegistersIHttpClientFactory()
+ public void AddDaprConversationClient_RegistersIHttpClientFactory()
{
var services = new ServiceCollection();
- services.AddDaprAiConversation();
+ services.AddDaprConversationClient();
var serviceProvider = services.BuildServiceProvider();
var httpClientFactory = serviceProvider.GetService();
@@ -67,9 +117,66 @@ public void AddDaprAiConversation_RegistersIHttpClientFactory()
}
[Fact]
- public void AddDaprAiConversation_NullServices_ShouldThrowException()
+ public void AddDaprConversationClient_NullServices_ShouldThrowException()
{
IServiceCollection services = null;
- Assert.Throws(() => services.AddDaprAiConversation());
+ Assert.Throws(() => services.AddDaprConversationClient());
+ }
+
+ [Fact]
+ public void AddDaprConversationClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton()
+ {
+ var services = new ServiceCollection();
+
+ services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Singleton);
+ var serviceProvider = services.BuildServiceProvider();
+
+ var daprConversationClient1 = serviceProvider.GetService();
+ var daprConversationClient2 = serviceProvider.GetService();
+
+ Assert.NotNull(daprConversationClient1);
+ Assert.NotNull(daprConversationClient2);
+
+ Assert.Same(daprConversationClient1, daprConversationClient2);
+ }
+
+ [Fact]
+ public async Task AddDaprConversationClient_ShouldRegisterScoped_WhenLifetimeIsScoped()
+ {
+ var services = new ServiceCollection();
+
+ services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Scoped);
+ var serviceProvider = services.BuildServiceProvider();
+
+ await using var scope1 = serviceProvider.CreateAsyncScope();
+ var daprConversationClient1 = scope1.ServiceProvider.GetService();
+
+ await using var scope2 = serviceProvider.CreateAsyncScope();
+ var daprConversationClient2 = scope2.ServiceProvider.GetService();
+
+ Assert.NotNull(daprConversationClient1);
+ Assert.NotNull(daprConversationClient2);
+ Assert.NotSame(daprConversationClient1, daprConversationClient2);
+ }
+
+ [Fact]
+ public void AddDaprConversationClient_ShouldRegisterTransient_WhenLifetimeIsTransient()
+ {
+ var services = new ServiceCollection();
+
+ services.AddDaprConversationClient((_, _) => { }, ServiceLifetime.Transient);
+ var serviceProvider = services.BuildServiceProvider();
+
+ var daprConversationClient1 = serviceProvider.GetService();
+ var daprConversationClient2 = serviceProvider.GetService();
+
+ Assert.NotNull(daprConversationClient1);
+ Assert.NotNull(daprConversationClient2);
+ Assert.NotSame(daprConversationClient1, daprConversationClient2);
+ }
+
+ private class TestSecretRetriever
+ {
+ public string GetApiTokenValue() => "abcdef";
}
}
diff --git a/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs b/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs
index bd5e4acd0..3b2c5f990 100644
--- a/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs
+++ b/test/Dapr.Jobs.Test/Extensions/DaprJobsServiceCollectionExtensionsTests.cs
@@ -89,7 +89,7 @@ public void AddDaprJobsClient_RegistersUsingDependencyFromIServiceProvider()
services.AddDaprJobsClient((provider, builder) =>
{
var configProvider = provider.GetRequiredService();
- var apiToken = TestSecretRetriever.GetApiTokenValue();
+ var apiToken = configProvider.GetApiTokenValue();
builder.UseDaprApiToken(apiToken);
});
@@ -114,7 +114,7 @@ public void RegisterJobsClient_ShouldRegisterSingleton_WhenLifetimeIsSingleton()
{
var services = new ServiceCollection();
- services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Singleton);
+ services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Singleton);
var serviceProvider = services.BuildServiceProvider();
var daprJobsClient1 = serviceProvider.GetService();
@@ -131,7 +131,7 @@ public async Task RegisterJobsClient_ShouldRegisterScoped_WhenLifetimeIsScoped()
{
var services = new ServiceCollection();
- services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Scoped);
+ services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Scoped);
var serviceProvider = services.BuildServiceProvider();
await using var scope1 = serviceProvider.CreateAsyncScope();
@@ -150,7 +150,7 @@ public void RegisterJobsClient_ShouldRegisterTransient_WhenLifetimeIsTransient()
{
var services = new ServiceCollection();
- services.AddDaprJobsClient((serviceProvider, options) => { }, ServiceLifetime.Transient);
+ services.AddDaprJobsClient((_, _) => { }, ServiceLifetime.Transient);
var serviceProvider = services.BuildServiceProvider();
var daprJobsClient1 = serviceProvider.GetService();
@@ -163,6 +163,6 @@ public void RegisterJobsClient_ShouldRegisterTransient_WhenLifetimeIsTransient()
private class TestSecretRetriever
{
- public static string GetApiTokenValue() => "abcdef";
+ public string GetApiTokenValue() => "abcdef";
}
}