Skip to content

Commit

Permalink
Add support for defining a custom metadata lookup in the servicebinder (
Browse files Browse the repository at this point in the history
#121) (#138)

* Move `GetMetadata` to the `ServiceBinder` to allow override (#121)

* Add singleton with `Authorize` attribute example (#121)
  • Loading branch information
Euan-McVie authored Jan 4, 2021
1 parent 786f359 commit e7d5887
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 71 deletions.
15 changes: 10 additions & 5 deletions examples/pb-net-grpc/Client_CS/Program.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using Grpc.Core;
using System;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
using Grpc.Net.Client;
using MegaCorp;
using ProtoBuf.Grpc;
using ProtoBuf.Grpc.Client;
using Shared_CS;
using System;
using System.Threading;
using System.Threading.Tasks;

namespace Client_CS
{
Expand All @@ -21,6 +21,7 @@ static async Task Main()
Console.WriteLine(result.Result); // 48

var clock = http.CreateGrpcService<ITimeService>();
var counter = http.CreateGrpcService<ICounter>();
using var cancel = new CancellationTokenSource(TimeSpan.FromMinutes(1));
var options = new CallOptions(cancellationToken: cancel.Token);

Expand All @@ -29,10 +30,14 @@ static async Task Main()
await foreach (var time in clock.SubscribeAsync(new CallContext(options)))
{
Console.WriteLine($"The time is now: {time.Time}");
var currentInc = await counter.IncrementAsync(new IncrementRequest { Inc = 1 });
Console.WriteLine($"Time received {currentInc.Result} times");
}
}
catch (RpcException) { }
catch (RpcException ex) { Console.WriteLine(ex); }
catch (OperationCanceledException) { }
Console.WriteLine("Press [Enter] to exit");
Console.ReadLine();
}
}
}
38 changes: 38 additions & 0 deletions examples/pb-net-grpc/Server_CS/FakeAuthenticationHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using System.Security.Claims;
using System.Text.Encodings.Web;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Server_CS
{
class FakeAuthHandler : AuthenticationHandler<FakeAuthOptions>
{
public const string SchemeName = "Fake";

public FakeAuthHandler(
IOptionsMonitor<FakeAuthOptions> options,
ILoggerFactory logger,
UrlEncoder encoder,
ISystemClock clock)
: base(options, logger, encoder, clock)
{
}

protected override Task<AuthenticateResult> HandleAuthenticateAsync()
{
if (!Options.AlwaysAuthenticate)
return Task.FromResult(AuthenticateResult.NoResult());

var claimsIdentity = new ClaimsIdentity(SchemeName);
var ticket = new AuthenticationTicket(new ClaimsPrincipal(claimsIdentity), Scheme.Name);
return Task.FromResult(AuthenticateResult.Success(ticket));
}
}

class FakeAuthOptions : AuthenticationSchemeOptions
{
public bool AlwaysAuthenticate { get; set; } = false;
}
}
23 changes: 23 additions & 0 deletions examples/pb-net-grpc/Server_CS/MyCounter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Shared_CS;

namespace Server_CS
{
[Authorize]
public class MyCounter : ICounter
{
private int counter = 0;
private readonly object counterLock = new object();

ValueTask<IncrementResult> ICounter.IncrementAsync(IncrementRequest request)
{
lock (counterLock)
{
counter += request.Inc;
var result = new IncrementResult { Result = counter };
return new ValueTask<IncrementResult>(result);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
},
"Server_CS": {
"commandName": "Project",
"launchBrowser": true,
"launchBrowser": false,
"applicationUrl": "https://localhost:5001;http://localhost:5000",
"environmentVariables": {
"ASPNETCORE_ENVIRONMENT": "Development"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.Extensions.DependencyInjection;
using ProtoBuf.Grpc.Configuration;

namespace Server_CS
{
internal class ServiceBinderWithServiceResolutionFromServiceCollection : ServiceBinder
{
private readonly IServiceCollection services;

public ServiceBinderWithServiceResolutionFromServiceCollection(IServiceCollection services)
{
this.services = services;
}

public override IList<object> GetMetadata(MethodInfo method, Type contractType, Type serviceType)
{
var resolvedServiceType = serviceType;
if (serviceType.IsInterface)
resolvedServiceType = services.SingleOrDefault(x => x.ServiceType == serviceType)?.ImplementationType ?? serviceType;

return base.GetMetadata(method, contractType, resolvedServiceType);
}
}
}
13 changes: 13 additions & 0 deletions examples/pb-net-grpc/Server_CS/Startup.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using ProtoBuf.Grpc.Configuration;
using ProtoBuf.Grpc.Server;
using Shared_CS;

namespace Server_CS
{
Expand All @@ -15,16 +18,26 @@ public void ConfigureServices(IServiceCollection services)
{
config.ResponseCompressionLevel = System.IO.Compression.CompressionLevel.Optimal;
});
services.TryAddSingleton(BinderConfiguration.Create(binder: new ServiceBinderWithServiceResolutionFromServiceCollection(services)));
services.AddCodeFirstGrpcReflection();

services.AddAuthentication(FakeAuthHandler.SchemeName)
.AddScheme<FakeAuthOptions, FakeAuthHandler>(FakeAuthHandler.SchemeName, options => options.AlwaysAuthenticate = true);
services.AddAuthorization();
services.AddSingleton<ICounter, MyCounter>();
}

// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
public void Configure(IApplicationBuilder app, IWebHostEnvironment _)
{
app.UseRouting();

app.UseAuthentication();
app.UseAuthorization();

app.UseEndpoints(endpoints =>
{
endpoints.MapGrpcService<ICounter>();
endpoints.MapGrpcService<MyCalculator>();
endpoints.MapGrpcService<MyTimeService>();
endpoints.MapCodeFirstGrpcReflectionService();
Expand Down
25 changes: 25 additions & 0 deletions examples/pb-net-grpc/Shared_CS/Counter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System.Runtime.Serialization;
using System.ServiceModel;
using System.Threading.Tasks;
namespace Shared_CS
{
[ServiceContract]
public interface ICounter
{
ValueTask<IncrementResult> IncrementAsync(IncrementRequest request);
}

[DataContract]
public class IncrementRequest
{
[DataMember(Order = 1)]
public int Inc { get; set; }
}

[DataContract]
public class IncrementResult
{
[DataMember(Order = 1)]
public int Result { get; set; }
}
}
76 changes: 18 additions & 58 deletions src/protobuf-net.Grpc/Configuration/ServerBinder.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
using Grpc.Core;
using ProtoBuf.Grpc.Internal;
using System;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Grpc.Core;
using ProtoBuf.Grpc.Internal;

namespace ProtoBuf.Grpc.Configuration
{
Expand Down Expand Up @@ -49,7 +49,7 @@ public int Bind(object state, Type serviceType, BinderConfiguration? binderConfi

var serviceContractSimplifiedExceptions = serviceImplSimplifiedExceptions || serviceContract.IsDefined(typeof(SimpleRpcExceptionsAttribute));
int svcOpCount = 0;
var bindCtx = new ServiceBindContext(serviceContract, serviceType, state);
var bindCtx = new ServiceBindContext(serviceContract, serviceType, state, binderConfiguration.Binder);
foreach (var op in ContractOperation.FindOperations(binderConfiguration, serviceContract, this))
{
if (ServerInvokerLookup.TryGetValue(op.MethodType, op.Context, op.Result, op.Void, out var invoker)
Expand Down Expand Up @@ -271,77 +271,37 @@ protected internal sealed class ServiceBindContext
/// The caller-provided state for this operation
/// </summary>
public object State { get; }

/// <summary>
/// The service binder to use.
/// </summary>
public ServiceBinder ServiceBinder { get; }

/// <summary>
/// The service contract interface type
/// </summary>
public Type ContractType { get; }

/// <summary>
/// The concrete service type
/// </summary>
public Type ServiceType { get; }

private InterfaceMapping? _map;
private InterfaceMapping GetMap() // lazily memoized
=> _map ??= ServiceType.GetInterfaceMap(ContractType);
internal ServiceBindContext(Type contractType, Type serviceType, object state)
internal ServiceBindContext(Type contractType, Type serviceType, object state, ServiceBinder serviceBinder)
{
State = state;
ServiceBinder = serviceBinder;
ContractType = contractType;
ServiceType = serviceType;
}

/// <summary>
/// Gets the implementing method from a method definition
/// </summary>
public MethodInfo? GetImplementation(MethodInfo serviceMethod)
{
if (ContractType != ServiceType & serviceMethod is object)
{
var map = GetMap();
var from = map.InterfaceMethods;
var to = map.TargetMethods;
int end = Math.Min(from.Length, to.Length);
for (int i = 0; i < end; i++)
{
if (from[i] == serviceMethod) return to[i];
}
}
return null;
}

/// <summary>
/// Gets the metadata associated with a specific contract method
/// <para>Gets the metadata associated with a specific contract method.</para>
/// <para>Note: Later is higher priority in the code that consumes this.</para>
/// </summary>
public List<object> GetMetadata(MethodInfo method)
{
// consider the various possible sources of distinct metadata
object[]
contractType = ContractType.GetCustomAttributes(inherit: true),
contractMethod = method.GetCustomAttributes(inherit: true),
serviceType = Array.Empty<object>(),
serviceMethod = Array.Empty<object>();
if (ContractType != ServiceType & ContractType.IsInterface & ServiceType.IsClass)
{
serviceType = ServiceType.GetCustomAttributes(inherit: true);
serviceMethod = GetImplementation(method)?.GetCustomAttributes(inherit: true)
?? Array.Empty<object>();
}

// note: later is higher priority in the code that consumes this, but
// GetAttributes() is "most derived to least derived", so: add everything
// backwards, then reverse
var metadata = new List<object>(
contractType.Length + contractMethod.Length +
serviceType.Length + serviceMethod.Length);

metadata.AddRange(serviceMethod);
metadata.AddRange(serviceType);
metadata.AddRange(contractMethod);
metadata.AddRange(contractType);
metadata.Reverse();
return metadata;
}
/// <returns>Prioritised list of metadata.</returns>
public IList<object> GetMetadata(MethodInfo method)
=> ServiceBinder.GetMetadata(method, ContractType, ServiceType);
}
}

}
Loading

0 comments on commit e7d5887

Please sign in to comment.