Skip to content

Commit

Permalink
Make sure we support keyed services (dotnet#20014)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattleibow authored Jan 22, 2024
1 parent 5519ab3 commit 7dbe853
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 1 deletion.
28 changes: 27 additions & 1 deletion src/Core/src/MauiContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ public MauiContext(IServiceProvider services, Android.Content.Context context)

public MauiContext(IServiceProvider services)
{
_services = new WrappedServiceProvider(services ?? throw new ArgumentNullException(nameof(services)));
_ = services ?? throw new ArgumentNullException(nameof(services));
_services = services is IKeyedServiceProvider
? new KeyedWrappedServiceProvider(services)
: new WrappedServiceProvider(services);

_handlers = new Lazy<IMauiHandlersFactory>(() => _services.GetRequiredService<IMauiHandlersFactory>());
#if ANDROID
_context = new Lazy<Android.Content.Context?>(() => _services.GetService<Android.Content.Context>());
Expand Down Expand Up @@ -73,5 +77,27 @@ public void AddSpecific(Type type, Func<object, object?> getter, object state)
_scopeStatic[type] = (state, getter);
}
}

class KeyedWrappedServiceProvider : WrappedServiceProvider, IKeyedServiceProvider
{
public KeyedWrappedServiceProvider(IServiceProvider serviceProvider)
: base(serviceProvider)
{
}

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (Inner is IKeyedServiceProvider provider)
return provider.GetKeyedService(serviceType, serviceKey);

// we know this won't work, but we need to call it to throw the right exception
return Inner.GetRequiredKeyedService(serviceType, serviceKey);
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
return Inner.GetRequiredKeyedService(serviceType, serviceKey);
}
}
}
}
122 changes: 122 additions & 0 deletions src/Core/tests/UnitTests/MauiContextTests.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Maui.Hosting;
using Microsoft.Maui.Hosting.Internal;
using Xunit;

Expand Down Expand Up @@ -99,6 +101,126 @@ public void CloneCanOverrideIncludeService()
Assert.Same(obj2, second.Services.GetService<TestThing>());
}

[Fact]
public void MauiContextSupportsKeyedServices()
{
var collection = new ServiceCollection();
collection.AddKeyedTransient<IFooService, FooService>("foo");
collection.AddKeyedTransient<IFooService, FooService2>("foo2");
var services = collection.BuildServiceProvider();

var context = new MauiContext(services);

var foo = context.Services.GetRequiredKeyedService<IFooService>("foo");
Assert.IsType<FooService>(foo);

var foo2 = context.Services.GetRequiredKeyedService<IFooService>("foo2");
Assert.IsType<FooService2>(foo2);
}

[Fact]
public void MauiContextSupportsKeyedServicesUsingAttributes()
{
var collection = new ServiceCollection();
collection.AddKeyedTransient<IFooService, FooService>("foo");
collection.AddKeyedTransient<IBarService, BarService>("bar");
collection.AddTransient<IFooBarService, FooBarKeyedService>();
var services = collection.BuildServiceProvider();

var context = new MauiContext(services);

var foobar = context.Services.GetRequiredService<IFooBarService>();
var keyed = Assert.IsType<FooBarKeyedService>(foobar);
Assert.NotNull(keyed.Foo);
Assert.NotNull(keyed.Bar);
}
[Fact]
public void NonKeyedProviderStaysNonKeyed()
{
var builder = MauiApp.CreateBuilder(useDefaults: false);
builder.ConfigureContainer(new KeyedOrNonKeyedProviderFactory(false));
var mauiApp = builder.Build();

var context = new MauiContext(mauiApp.Services);

Assert.IsAssignableFrom<IServiceProvider>(context.Services);
Assert.IsNotAssignableFrom<IKeyedServiceProvider>(context.Services);

var context2 = new MauiContext(context.Services);

Assert.IsAssignableFrom<IServiceProvider>(context2.Services);
Assert.IsNotAssignableFrom<IKeyedServiceProvider>(context2.Services);
}

[Fact]
public void KeyedProviderStaysKeyed()
{
var builder = MauiApp.CreateBuilder(useDefaults: false);
builder.ConfigureContainer(new KeyedOrNonKeyedProviderFactory(true));
var mauiApp = builder.Build();

var context = new MauiContext(mauiApp.Services);

Assert.IsAssignableFrom<IServiceProvider>(context.Services);
Assert.IsAssignableFrom<IKeyedServiceProvider>(context.Services);

var context2 = new MauiContext(context.Services);

Assert.IsAssignableFrom<IServiceProvider>(context2.Services);
Assert.IsAssignableFrom<IKeyedServiceProvider>(context2.Services);
}

private class KeyedOrNonKeyedProviderFactory : IServiceProviderFactory<ServiceCollection>
{
public KeyedOrNonKeyedProviderFactory(bool keyed)
{
Keyed = keyed;
}

public bool Keyed { get; }

public ServiceCollection CreateBuilder(IServiceCollection services) =>
new() { services };

public IServiceProvider CreateServiceProvider(ServiceCollection containerBuilder)
{
var real = containerBuilder.BuildServiceProvider();
return Keyed ? new KeyedProvider(real) : new NonKeyedProvider(real);
}
}

private class NonKeyedProvider : IServiceProvider
{
public NonKeyedProvider(ServiceProvider provider)
{
Provider = provider;
}

public ServiceProvider Provider { get; }

public object GetService(Type serviceType) =>
Provider.GetService(serviceType);
}

private class KeyedProvider : IServiceProvider, IKeyedServiceProvider
{
public KeyedProvider(ServiceProvider provider)
{
Provider = provider;
}

public ServiceProvider Provider { get; }

public object GetKeyedService(Type serviceType, object serviceKey) =>
Provider.GetKeyedService(serviceType, serviceKey);

public object GetRequiredKeyedService(Type serviceType, object serviceKey) =>
Provider.GetRequiredKeyedService(serviceType, serviceKey);

public object GetService(Type serviceType) =>
Provider.GetService(serviceType);
}

class TestThing
{
}
Expand Down
14 changes: 14 additions & 0 deletions src/Core/tests/UnitTests/TestClasses/TestServices.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections.Generic;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.Maui.UnitTests
{
Expand Down Expand Up @@ -64,6 +65,19 @@ public FooBarService(IFooService foo, IBarService bar)
public IBarService Bar { get; }
}

class FooBarKeyedService : IFooBarService
{
public FooBarKeyedService([FromKeyedServices("foo")] IFooService foo, [FromKeyedServices("bar")] IBarService bar)
{
Foo = foo;
Bar = bar;
}

public IFooService Foo { get; }

public IBarService Bar { get; }
}

class FooTrioConstructor : IFooBarService
{
public FooTrioConstructor()
Expand Down

0 comments on commit 7dbe853

Please sign in to comment.