From d90ecd4a4ccff83875949f07f9353e2b3738fd2f Mon Sep 17 00:00:00 2001 From: Marc Gravell Date: Mon, 14 Oct 2024 12:31:48 +0100 Subject: [PATCH] Implement Stream api support (#341) * client-to-server spike logic; not tested * server bindings (and binding tests, no integration test yet) * most basic of basic integration tests * optimize TryFastParse; needs tests * verify behaviour in all expected scenarios * for compat: don't demand the trailer Signed-off-by: Marc Gravell * nit * marshaller validation * release notes * fix StreamRewriteBasicTest (timing brittleness) --------- Signed-off-by: Marc Gravell --- Directory.Packages.props | 7 +- docs/releasenotes.md | 6 + examples/pb-net/JustProtos/SomeType.cs | 4 +- .../SchemaGenerator.cs | 2 +- .../Configuration/BinderConfiguration.cs | 2 +- .../Configuration/ClientFactory.cs | 13 +- .../GoogleProtobufMarshallerFactory.cs | 12 +- .../ProtoBufMarshallerFactory.cs | 2 + .../Configuration/ServerBinder.cs | 4 +- src/protobuf-net.Grpc/Internal/BytesValue.cs | 354 ++++++++++ .../Internal/ContractOperation.cs | 90 ++- src/protobuf-net.Grpc/Internal/Empty.cs | 4 +- .../Internal/MarshallerCache.cs | 14 +- .../Internal/MetadataContext.cs | 15 +- .../Internal/ProxyEmitter.cs | 34 +- .../Internal/Reshape.ByteStream.cs | 264 ++++++++ src/protobuf-net.Grpc/Internal/Reshape.cs | 16 +- .../Internal/ServerInvokerLookup.cs | 36 +- .../protobuf-net.Grpc.csproj | 5 + .../FileDescriptorSetFactoryTests.cs | 4 +- .../ReflectionServiceTests.cs | 7 +- .../ClientProxyTests.cs | 2 +- .../StreamTests.cs | 171 ++++- .../BytesValueMarshallerTests.cs | 129 ++++ .../ContractOperationTests.cs | 628 +++++++++++++++++- tests/protobuf-net.Grpc.Test/IAllOptions.cs | 19 +- tests/protobuf-net.Grpc.Test/TestBindings.cs | 21 +- .../protobuf-net.Grpc.Test.csproj | 2 +- version.json | 2 +- 29 files changed, 1749 insertions(+), 120 deletions(-) create mode 100644 src/protobuf-net.Grpc/Internal/BytesValue.cs create mode 100644 src/protobuf-net.Grpc/Internal/Reshape.ByteStream.cs create mode 100644 tests/protobuf-net.Grpc.Test/BytesValueMarshallerTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index a9647bfa..6d776606 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -21,10 +21,13 @@ + + - - + + + diff --git a/docs/releasenotes.md b/docs/releasenotes.md index 5fdc8287..eb1ba07b 100644 --- a/docs/releasenotes.md +++ b/docs/releasenotes.md @@ -2,7 +2,13 @@ ## unreleased +## 1.2.0 + +- support `[Value]Task` as a return value, rewriting via [`stream BytesValue`](https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/wrappers.proto) - first + step in [#340](https://github.com/protobuf-net/protobuf-net.Grpc/issues/340) +- update library references and TFMs - improve handling of `IAsyncDisposable` +- improve error message when binding methods ([#331](https://github.com/protobuf-net/protobuf-net.Grpc/pull/331) via BasConijn) ## 1.1.1 diff --git a/examples/pb-net/JustProtos/SomeType.cs b/examples/pb-net/JustProtos/SomeType.cs index 718651a0..f7c7ace3 100644 --- a/examples/pb-net/JustProtos/SomeType.cs +++ b/examples/pb-net/JustProtos/SomeType.cs @@ -12,11 +12,11 @@ static void Foo() { // the point here being: these types *exist*, despite // not appearing as local .cs files - Type[] types = { + Type[] types = [ typeof(DescriptorProto), typeof(TimeResult), typeof(MultiplyRequest), - }; + ]; _ = types; } } diff --git a/src/protobuf-net.Grpc.Reflection/SchemaGenerator.cs b/src/protobuf-net.Grpc.Reflection/SchemaGenerator.cs index e5918e73..f566d58c 100644 --- a/src/protobuf-net.Grpc.Reflection/SchemaGenerator.cs +++ b/src/protobuf-net.Grpc.Reflection/SchemaGenerator.cs @@ -41,7 +41,7 @@ public string GetSchema() /// this method need to remain for backward compatibility for client which will get this updated version, without recompilation. /// Thus, this method mustn't be deleted. public string GetSchema(Type contractType) - => GetSchema(new [] {contractType}); + => GetSchema([contractType]); /// /// Get the .proto schema associated with multiple service contracts diff --git a/src/protobuf-net.Grpc/Configuration/BinderConfiguration.cs b/src/protobuf-net.Grpc/Configuration/BinderConfiguration.cs index 6b31c431..e31146f4 100644 --- a/src/protobuf-net.Grpc/Configuration/BinderConfiguration.cs +++ b/src/protobuf-net.Grpc/Configuration/BinderConfiguration.cs @@ -12,7 +12,7 @@ namespace ProtoBuf.Grpc.Configuration public sealed class BinderConfiguration { // this *must* stay above Default - .cctor order is file order - static readonly MarshallerFactory[] s_defaultFactories = new MarshallerFactory[] { ProtoBufMarshallerFactory.Default, ProtoBufMarshallerFactory.GoogleProtobuf }; + static readonly MarshallerFactory[] s_defaultFactories = [ProtoBufMarshallerFactory.Default, ProtoBufMarshallerFactory.GoogleProtobuf]; /// /// Use the default MarshallerFactory and ServiceBinder diff --git a/src/protobuf-net.Grpc/Configuration/ClientFactory.cs b/src/protobuf-net.Grpc/Configuration/ClientFactory.cs index 6a4b1540..15fba49d 100644 --- a/src/protobuf-net.Grpc/Configuration/ClientFactory.cs +++ b/src/protobuf-net.Grpc/Configuration/ClientFactory.cs @@ -46,14 +46,9 @@ public virtual GrpcClient CreateClient(CallInvoker channel, Type contractType) => new GrpcClient(channel, contractType, BinderConfiguration); - private sealed class ConfiguredClientFactory : ClientFactory + private sealed class ConfiguredClientFactory(BinderConfiguration? binderConfiguration) : ClientFactory { - protected override BinderConfiguration BinderConfiguration { get; } - - public ConfiguredClientFactory(BinderConfiguration? binderConfiguration) - { - BinderConfiguration = binderConfiguration ?? BinderConfiguration.Default; - } + protected override BinderConfiguration BinderConfiguration { get; } = binderConfiguration ?? BinderConfiguration.Default; private readonly ConcurrentDictionary _proxyCache = new ConcurrentDictionary(); @@ -61,7 +56,7 @@ public ConfiguredClientFactory(BinderConfiguration? binderConfiguration) private TService SlowCreateClient(CallInvoker channel) where TService : class { - var factory = ProxyEmitter.CreateFactory(BinderConfiguration); + var factory = ProxyEmitter.CreateFactory(BinderConfiguration, null); var key = typeof(TService); if (!_proxyCache.TryAdd(key, factory)) factory = (Func)_proxyCache[key]; @@ -78,7 +73,7 @@ public override TService CreateClient(CallInvoker channel) internal static class DefaultProxyCache where TService : class { - internal static readonly Func Create = ProxyEmitter.CreateFactory(BinderConfiguration.Default); + internal static readonly Func Create = ProxyEmitter.CreateFactory(BinderConfiguration.Default, null); } private sealed class DefaultClientFactory : ClientFactory diff --git a/src/protobuf-net.Grpc/Configuration/GoogleProtobufMarshallerFactory.cs b/src/protobuf-net.Grpc/Configuration/GoogleProtobufMarshallerFactory.cs index 53207aeb..aaafc110 100644 --- a/src/protobuf-net.Grpc/Configuration/GoogleProtobufMarshallerFactory.cs +++ b/src/protobuf-net.Grpc/Configuration/GoogleProtobufMarshallerFactory.cs @@ -64,16 +64,16 @@ protected internal override Marshaller CreateMarshaller() parser.ParseFrom(context.PayloadAsReadOnlySequence() */ var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context"); - var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(ReadOnlySequence) })!; + var parseFrom = parser.PropertyType.GetMethod("ParseFrom", [typeof(ReadOnlySequence)])!; Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType), parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsReadOnlySequence), Type.EmptyTypes)); deserializer = Expression.Lambda>(body, context).Compile(); var message = Expression.Parameter(typeof(T), "message"); context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context"); - var setPayloadLength = typeof(global::Grpc.Core.SerializationContext).GetMethod(nameof(global::Grpc.Core.SerializationContext.SetPayloadLength), new Type[] { typeof(int) })!; + var setPayloadLength = typeof(global::Grpc.Core.SerializationContext).GetMethod(nameof(global::Grpc.Core.SerializationContext.SetPayloadLength), [typeof(int)])!; var calculateSize = iMessage.GetMethod("CalculateSize", Type.EmptyTypes)!; - var writeTo = me.GetMethod("WriteTo", new Type[] { iMessage, typeof(IBufferWriter) })!; + var writeTo = me.GetMethod("WriteTo", [iMessage, typeof(IBufferWriter)])!; body = Expression.Block( Expression.Call(context, setPayloadLength, Expression.Call(message, calculateSize)), Expression.Call(writeTo, message, Expression.Call(context, "GetBufferWriter", Type.EmptyTypes)), @@ -92,16 +92,16 @@ protected internal override Marshaller CreateMarshaller() */ var context = Expression.Parameter(typeof(global::Grpc.Core.DeserializationContext), "context"); - var parseFrom = parser.PropertyType.GetMethod("ParseFrom", new Type[] { typeof(byte[]) })!; + var parseFrom = parser.PropertyType.GetMethod("ParseFrom", [typeof(byte[])])!; Expression body = Expression.Call(Expression.Constant(parser.GetValue(null), parser.PropertyType), parseFrom, Expression.Call(context, nameof(DeserializationContext.PayloadAsNewBuffer), Type.EmptyTypes)); deserializer = Expression.Lambda>(body, context).Compile(); var message = Expression.Parameter(typeof(T), "message"); context = Expression.Parameter(typeof(global::Grpc.Core.SerializationContext), "context"); - var toByteArray = me.GetMethod("ToByteArray", new Type[] { iMessage })!; + var toByteArray = me.GetMethod("ToByteArray", [iMessage])!; var complete = typeof(global::Grpc.Core.SerializationContext).GetMethod( - nameof(global::Grpc.Core.SerializationContext.Complete), new Type[] { typeof(byte[]) })!; + nameof(global::Grpc.Core.SerializationContext.Complete), [typeof(byte[])])!; body = Expression.Call(context, complete, Expression.Call(toByteArray, message)); serializer = Expression.Lambda>(body, message, context).Compile(); } diff --git a/src/protobuf-net.Grpc/Configuration/ProtoBufMarshallerFactory.cs b/src/protobuf-net.Grpc/Configuration/ProtoBufMarshallerFactory.cs index 7595151d..84f6d947 100644 --- a/src/protobuf-net.Grpc/Configuration/ProtoBufMarshallerFactory.cs +++ b/src/protobuf-net.Grpc/Configuration/ProtoBufMarshallerFactory.cs @@ -55,7 +55,9 @@ public enum Options // note: these are the same *object*, but pre-checked for optional API support, for efficiency // (the minimum .NET object size means that the extra fields don't cost anything) private readonly IMeasuredProtoOutput>? _measuredWriterModel; +#pragma warning disable CA1859 // change type of field for performance - but actually this is a speculative test private readonly IProtoInput>? _squenceReaderModel; +#pragma warning restore CA1859 /// /// Create a new factory using a specific protobuf-net model diff --git a/src/protobuf-net.Grpc/Configuration/ServerBinder.cs b/src/protobuf-net.Grpc/Configuration/ServerBinder.cs index 6642a2f9..22edf7eb 100644 --- a/src/protobuf-net.Grpc/Configuration/ServerBinder.cs +++ b/src/protobuf-net.Grpc/Configuration/ServerBinder.cs @@ -37,7 +37,7 @@ public int Bind(object state, Type serviceType, BinderConfiguration? binderConfi { int totalCount = 0; object?[]? argsBuffer = null; - Type[] typesBuffer = Array.Empty(); + Type[] typesBuffer = []; binderConfiguration ??= BinderConfiguration.Default; var potentialServiceContracts = typeof(IGrpcService).IsAssignableFrom(serviceType) ? new HashSet {serviceType} @@ -92,7 +92,7 @@ bool AddMethod(string? serviceName, Type @in, Type @out, string on, MethodInfo m { if (typesBuffer.Length == 0) { - typesBuffer = new Type[] {serviceType, typeof(void), typeof(void)}; + typesBuffer = [serviceType, typeof(void), typeof(void)]; } typesBuffer[1] = @in; diff --git a/src/protobuf-net.Grpc/Internal/BytesValue.cs b/src/protobuf-net.Grpc/Internal/BytesValue.cs new file mode 100644 index 00000000..1039e505 --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/BytesValue.cs @@ -0,0 +1,354 @@ +using Grpc.Core; +using ProtoBuf.Meta; +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.ComponentModel; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using System.Threading; + +namespace ProtoBuf.Grpc.Internal; + + +/// +/// Represents a single BytesValue chunk (as per wrappers.proto) +/// +[ProtoContract(Name = ".google.protobuf.BytesValue")] +[Obsolete(Reshape.WarningMessage, false)] +[Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] +public sealed class BytesValue(byte[] oversized, int length, bool pooled) +{ + /// + /// Indicates the maximum length supported for individual chunks when using API rewriting. + /// + public const int MaxLength = 0x1FFFFF; // 21 bits of length prefix; 2,097,151 bytes + // (note we will still *read* buffers larger than that, because of non-"us" endpoints, but we'll never send them) + + +#if DEBUG + private static int _fastPassMiss = 0; + internal static int FastPassMiss => Volatile.Read(ref _fastPassMiss); +#endif + + [Flags] + enum Flags : byte + { + None = 0, + Pooled = 1 << 0, + Recycled = 1 << 1, + } + private Flags _flags = pooled ? Flags.Pooled : Flags.None; + private byte[] _oversized = oversized; + private int _length = length; + + private BytesValue() : this([], 0, false) { } // for deserialization + + internal bool IsPooled => (_flags & Flags.Pooled) != 0; + + internal bool IsRecycled => (_flags & Flags.Recycled) != 0; + + /// + /// Gets or sets the value as a right-sized array + /// + [ProtoMember(1)] + public byte[] RightSized // for deserializer only + { + get + { + ThrowIfRecycled(); + if (_oversized.Length != _length) + { + Array.Resize(ref _oversized, _length); + _flags &= ~Flags.Pooled; + } + return _oversized; + } + set + { + value ??= []; + _length = value.Length; + _oversized = value; + } + } + + /// + /// Recycles this instance, releasing the buffer (if pooled), and resetting the length to zero. + /// + public void Recycle() + { + var flags = _flags; + _flags = Flags.Recycled; + var tmp = _oversized; + _length = 0; + _oversized = []; + + if ((flags & Flags.Pooled) != 0) + { + ArrayPool.Shared.Return(tmp); + } + } + + private void ThrowIfRecycled() + { + if ((_flags & Flags.Recycled) != 0) + { + Throw(); + } + static void Throw() => throw new InvalidOperationException("This " + nameof(BytesValue) + " instance has been recycled"); + } + + /// + /// Indicates whether this value is empty (zero bytes) + /// + public bool IsEmpty => _length == 0; + + /// + /// Gets the size (in bytes) of this value + /// + public int Length => _length; + + /// + /// Gets the payload as an + /// + public ArraySegment ArraySegment + { + get + { + ThrowIfRecycled(); + return new(_oversized, 0, _length); + } + } + + /// + /// Gets the payload as a + /// + public ReadOnlySpan Span + { + get + { + ThrowIfRecycled(); + return new(_oversized, 0, _length); + } + } + + /// + /// Gets the payload as a + /// + public ReadOnlyMemory Memory + { + get + { + ThrowIfRecycled(); + return new(_oversized, 0, _length); + } + } + + + /// + /// Gets the gRPC marshaller for this type. + /// + public static Marshaller Marshaller { get; } = new(Serialize, Deserialize); + + private static BytesValue Deserialize(DeserializationContext context) + { + try + { + var payload = context.PayloadAsReadOnlySequence(); + var totalLen = payload.Length; + BytesValue? result; + + if (payload.First.Length >= 4) + { + // enough bytes in the first segment + result = TryFastParse(payload.First.Span, payload); + } + else + { + // copy up-to 4 bytes into a buffer, handling multi-segment concerns + Span buffer = stackalloc byte[4]; + payload.Slice(0, (int)Math.Min(totalLen, 4)).CopyTo(buffer); + result = TryFastParse(buffer, payload); + } + + return result ?? SlowParse(payload); + } + catch (Exception ex) + { + Debug.WriteLine(ex.Message); + throw; + } + } + + private static BytesValue SlowParse(in ReadOnlySequence payload) + { + IProtoInput model = RuntimeTypeModel.Default; + var len = payload.Length; + // use protobuf-net v3 API if available + if (model is IProtoInput> v3) + { + return v3.Deserialize(payload); + } + + // use protobuf-net v2 API + MemoryStream ms; + if (payload.IsSingleSegment && MemoryMarshal.TryGetArray(payload.First, out var segment)) + { + ms = new MemoryStream(segment.Array!, segment.Offset, segment.Count, writable: false, publiclyVisible: true); + } + else + { + ms = new MemoryStream(); + ms.SetLength(len); + if (ms.TryGetBuffer(out var buffer) && buffer.Count >= len) + { + payload.CopyTo(buffer.AsSpan()); + } + else + { +#if !(NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER) + byte[] leased = []; +#endif + foreach (var chunk in payload) + { +#if NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER + ms.Write(chunk.Span); +#else + if (MemoryMarshal.TryGetArray(chunk, out segment)) + { + ms.Write(segment.Array!, segment.Offset, segment.Count); + } + else + { + if (leased.Length < segment.Count) + { + ArrayPool.Shared.Return(leased); + leased = ArrayPool.Shared.Rent(segment.Count); + } + segment.AsSpan().CopyTo(leased); + ms.Write(leased, 0, segment.Count); + } +#endif + } +#if !(NETSTANDARD2_1_OR_GREATER || NET5_0_OR_GREATER) + if (leased.Length != 0) + { + ArrayPool.Shared.Return(leased); + } +#endif + Debug.Assert(ms.Position == len, "should have written all bytes"); + ms.Position = 0; + } + } + Debug.Assert(ms.Position == 0 && ms.Length == len, "full payload should be ready to read"); + return model.Deserialize(ms); + } + + internal static BytesValue? TryFastParse(ReadOnlySpan start, in ReadOnlySequence payload) + { + // note: optimized for little-endian CPUs, but safe anywhere (big-endian has an extra reverse) + int raw = BinaryPrimitives.ReadInt32LittleEndian(start); + int byteLen, headerLen; + switch (raw & 0x808080FF) // test the entire first byte, and the MSBs of the rest + { + // one-byte length, with anything after (0A00*, backwards) + case 0x0000000A: + case 0x8000000A: + case 0x0080000A: + case 0x8080000A: + headerLen = 2; + byteLen = (raw & 0x7F00) >> 8; + break; + // two-byte length, with anything after (0A8000*, backwards) + case 0x0000800A: + case 0x8000800A: + headerLen = 3; + byteLen = ((raw & 0x7F00) >> 8) | ((raw & 0x7F0000) >> 9); + break; + // three-byte length (0A808000, backwards) + case 0x0080800A: + headerLen = 4; + byteLen = ((raw & 0x7F00) >> 8) | ((raw & 0x7F0000) >> 9) | ((raw & 0x7F000000) >> 10); + break; + default: + return null; // not optimized + } + if (headerLen + byteLen != payload.Length) + { +#if DEBUG + Interlocked.Increment(ref _fastPassMiss); +#endif + return null; // not the exact payload (other fields?) + } + +#if DEBUG + // double-check our math using the less efficient library functions + var arr = start.Slice(0, 4).ToArray(); + Debug.Assert(start[0] == 0x0A, "field 1, string"); + Debug.Assert(Serializer.TryReadLengthPrefix(arr, 1, 3, PrefixStyle.Base128, out int checkLen) + && checkLen == byteLen, $"length mismatch; {byteLen} vs {checkLen}"); +#endif + + var leased = ArrayPool.Shared.Rent(byteLen); + payload.Slice(headerLen).CopyTo(leased); + return new(leased, byteLen, pooled: true); + } + + private static void Serialize(BytesValue value, global::Grpc.Core.SerializationContext context) + { + int byteLen = value.Length, headerLen; + if (byteLen <= 0x7F) // 7 bit + { + headerLen = 2; + } + else if (byteLen <= 0x3FFF) // 14 bit + { + headerLen = 3; + } + else if (byteLen <= 0x1FFFFF) // 21 bit + { + headerLen = 4; + } + else + { + throw new NotSupportedException("We don't expect to write messages this large!"); + } + int totalLength = headerLen + byteLen; + context.SetPayloadLength(totalLength); + var writer = context.GetBufferWriter(); + var buffer = writer.GetSpan(totalLength); + // we'll assume that we get space for at least the header bytes, but we can *hope* for the entire thing + + buffer[0] = 0x0A; // field 1, string + switch (headerLen) + { + case 2: + buffer[1] = (byte)byteLen; + break; + case 3: + buffer[1] = (byte)(byteLen | 0x80); + buffer[2] = (byte)(byteLen >> 7); + break; + case 4: + buffer[1] = (byte)(byteLen | 0x80); + buffer[2] = (byte)((byteLen >> 7) | 0x80); + buffer[3] = (byte)(byteLen >> 14); + break; + } + if (buffer.Length >= totalLength) + { + // write everything in one go + value.Span.CopyTo(buffer.Slice(headerLen)); + writer.Advance(totalLength); + } + else + { + // commit the header, then write the body + writer.Advance(headerLen); + writer.Write(value.Span); + } + value.Recycle(); + context.Complete(); + } +} \ No newline at end of file diff --git a/src/protobuf-net.Grpc/Internal/ContractOperation.cs b/src/protobuf-net.Grpc/Internal/ContractOperation.cs index a943f378..a208a660 100644 --- a/src/protobuf-net.Grpc/Internal/ContractOperation.cs +++ b/src/protobuf-net.Grpc/Internal/ContractOperation.cs @@ -6,43 +6,33 @@ using System.Threading.Tasks; using System.Linq; using System.Threading; +using System.IO; namespace ProtoBuf.Grpc.Internal { - internal readonly struct ContractOperation + internal readonly struct ContractOperation(string name, Type from, Type to, MethodInfo method, + MethodType methodType, ContextKind contextKind, ResultKind arg, ResultKind resultKind, VoidKind @void) { - public string Name { get; } - public Type From { get; } - public Type To { get; } - public MethodInfo Method { get; } - public MethodType MethodType { get; } - public ContextKind Context { get; } - public ResultKind Arg { get; } - public ResultKind Result { get; } - public VoidKind Void { get; } + public string Name { get; } = name; + public Type From { get; } = from; + public Type To { get; } = to; + public MethodInfo Method { get; } = method; + public MethodType MethodType { get; } = methodType; + public ContextKind Context { get; } = contextKind; + public ResultKind Arg { get; } = arg; + public ResultKind Result { get; } = resultKind; + public VoidKind Void { get; } = @void; public bool VoidRequest => (Void & VoidKind.Request) != 0; public bool VoidResponse => (Void & VoidKind.Response) != 0; public override string ToString() => $"{Name}: {From.Name}=>{To.Name}, {MethodType}, {Result}, {Context}, {Void}"; - public ContractOperation(string name, Type from, Type to, MethodInfo method, - MethodType methodType, ContextKind contextKind, ResultKind arg, ResultKind resultKind, VoidKind @void) - { - Name = name; - From = from; - To = to; - Method = method; - MethodType = methodType; - Context = contextKind; - Arg = arg; - Result = resultKind; - Void = @void; - } - internal enum TypeCategory { None, Void, + Data, + UntypedTask, UntypedValueTask, TypedTask, @@ -59,7 +49,10 @@ internal enum TypeCategory AsyncClientStreamingCall, AsyncDuplexStreamingCall, AsyncServerStreamingCall, - Data, + Stream, + TaskStream, + ValueTaskStream, + Invalid, } @@ -181,6 +174,21 @@ internal enum TypeCategory { (TypeCategory.IObservable, TypeCategory.None, TypeCategory.None, TypeCategory.IObservable), (ContextKind.NoContext, MethodType.DuplexStreaming, ResultKind.Observable,ResultKind.Observable, VoidKind.None, 0, RET) }, { (TypeCategory.IObservable, TypeCategory.CallContext, TypeCategory.None, TypeCategory.IObservable), (ContextKind.CallContext, MethodType.DuplexStreaming, ResultKind.Observable, ResultKind.Observable, VoidKind.None, 0, RET) }, { (TypeCategory.IObservable, TypeCategory.CancellationToken, TypeCategory.None, TypeCategory.IObservable), (ContextKind.CancellationToken, MethodType.DuplexStreaming, ResultKind.Observable, ResultKind.Observable, VoidKind.None, 0, RET) }, + + // server streaming via Stream, with/without arg + {(TypeCategory.None, TypeCategory.None, TypeCategory.None, TypeCategory.TaskStream), (ContextKind.NoContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.Request, VOID, RET) }, + {(TypeCategory.CallContext, TypeCategory.None, TypeCategory.None, TypeCategory.TaskStream), (ContextKind.CallContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.Request, VOID, RET) }, + {(TypeCategory.CancellationToken, TypeCategory.None, TypeCategory.None, TypeCategory.TaskStream), (ContextKind.CancellationToken, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.Request, VOID, RET) }, + {(TypeCategory.Data, TypeCategory.None, TypeCategory.None, TypeCategory.TaskStream), (ContextKind.NoContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None, 0, RET) }, + {(TypeCategory.Data, TypeCategory.CallContext, TypeCategory.None, TypeCategory.TaskStream), (ContextKind.CallContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None, 0, RET) }, + {(TypeCategory.Data, TypeCategory.CancellationToken, TypeCategory.None, TypeCategory.TaskStream), (ContextKind.CancellationToken, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None, 0, RET) }, + + {(TypeCategory.None, TypeCategory.None, TypeCategory.None, TypeCategory.ValueTaskStream), (ContextKind.NoContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.Request, VOID, RET) }, + {(TypeCategory.CallContext, TypeCategory.None, TypeCategory.None, TypeCategory.ValueTaskStream), (ContextKind.CallContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.Request, VOID, RET) }, + {(TypeCategory.CancellationToken, TypeCategory.None, TypeCategory.None, TypeCategory.ValueTaskStream), (ContextKind.CancellationToken, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.Request, VOID, RET) }, + {(TypeCategory.Data, TypeCategory.None, TypeCategory.None, TypeCategory.ValueTaskStream), (ContextKind.NoContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None, 0, RET) }, + {(TypeCategory.Data, TypeCategory.CallContext, TypeCategory.None, TypeCategory.ValueTaskStream), (ContextKind.CallContext, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None, 0, RET) }, + {(TypeCategory.Data, TypeCategory.CancellationToken, TypeCategory.None, TypeCategory.ValueTaskStream), (ContextKind.CancellationToken, MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None, 0, RET) }, }; internal static int SignatureCount => s_signaturePatterns.Count; @@ -196,6 +204,9 @@ static TypeCategory GetCategory(MarshallerCache marshallerCache, Type type, IBin if (type == typeof(CallOptions)) return TypeCategory.CallOptions; if (type == typeof(CallContext)) return TypeCategory.CallContext; if (type == typeof(CancellationToken)) return TypeCategory.CancellationToken; + if (type == typeof(Stream)) return TypeCategory.Stream; + if (type == typeof(Task)) return TypeCategory.TaskStream; + if (type == typeof(ValueTask)) return TypeCategory.ValueTaskStream; if (type.IsGenericType) { @@ -229,6 +240,7 @@ private static (TypeCategory Arg0, TypeCategory Arg1, TypeCategory Arg2, TypeCat signature.Ret = GetCategory(marshallerCache, returnType, bindContext); return signature; } + [System.Diagnostics.CodeAnalysis.SuppressMessage("Style", "IDE0066:Convert switch statement to expression", Justification = "Clarity")] internal static bool TryIdentifySignature(MethodInfo method, BinderConfiguration binderConfig, out ContractOperation operation, IBindContext? bindContext) { operation = default; @@ -244,7 +256,10 @@ internal static bool TryIdentifySignature(MethodInfo method, BinderConfiguration var signature = GetSignature(binderConfig.MarshallerCache, args, method.ReturnType, bindContext); - if (!s_signaturePatterns.TryGetValue(signature, out var config)) return false; + if (!s_signaturePatterns.TryGetValue(signature, out var config)) + { + return false; + } (Type type, TypeCategory category) GetTypeByIndex(int index) { @@ -271,6 +286,12 @@ static Type GetDataType((Type type, TypeCategory category) key, bool req) case TypeCategory.UntypedValueTask: #pragma warning disable CS0618 // Empty return typeof(Empty); +#pragma warning restore CS0618 + case TypeCategory.TaskStream: + case TypeCategory.ValueTaskStream: + case TypeCategory.Stream: +#pragma warning disable CS0618 // BytesValue + return typeof(BytesValue); #pragma warning restore CS0618 case TypeCategory.TypedTask: case TypeCategory.TypedValueTask: @@ -341,7 +362,7 @@ where parameters.Length > 1 && parameters[0].ParameterType == typeof(CallContext).MakeByRefType() select method).ToDictionary(x => x.Name); - static readonly Dictionary<(MethodType, ResultKind, ResultKind, VoidKind), string> _clientResponseMap = new Dictionary<(MethodType, ResultKind, ResultKind, VoidKind), string> + static readonly Dictionary<(MethodType Method, ResultKind Arg, ResultKind Result, VoidKind Void), string> _clientResponseMap = new Dictionary<(MethodType, ResultKind, ResultKind, VoidKind), string> { {(MethodType.DuplexStreaming, ResultKind.AsyncEnumerable, ResultKind.AsyncEnumerable, VoidKind.None), nameof(Reshape.DuplexAsync) }, {(MethodType.DuplexStreaming, ResultKind.Observable, ResultKind.Observable, VoidKind.None), nameof(Reshape.DuplexObservable) }, @@ -361,6 +382,9 @@ where parameters.Length > 1 {(MethodType.Unary, ResultKind.Sync, ResultKind.ValueTask, VoidKind.Response), nameof(Reshape.UnaryValueTaskAsyncVoid) }, {(MethodType.Unary, ResultKind.Sync, ResultKind.Sync, VoidKind.None), nameof(Reshape.UnarySync) }, {(MethodType.Unary, ResultKind.Sync, ResultKind.Sync, VoidKind.Response), nameof(Reshape.UnarySyncVoid) }, + + {(MethodType.ServerStreaming, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None), nameof(Reshape.ServerByteStreamingTaskAsync) }, + {(MethodType.ServerStreaming, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None), nameof(Reshape.ServerByteStreamingValueTaskAsync) }, }; #pragma warning restore CS0618 @@ -411,10 +435,11 @@ internal static ISet ExpandInterfaces(Type type) internal static ISet ExpandWithInterfacesMarkedAsSubService(ServiceBinder serviceBinder, Type serviceContract) { - var set = new HashSet(); - - // first add the service contract by itself - set.Add(serviceContract); + var set = new HashSet + { + // first add the service contract by itself + serviceContract + }; // now add all inherited interfaces which are marked as sub-services foreach (var t in serviceContract.GetInterfaces()) @@ -462,6 +487,9 @@ internal enum ResultKind AsyncEnumerable, Grpc, Observable, + Stream, + TaskStream, + ValueTaskStream, } [Flags] diff --git a/src/protobuf-net.Grpc/Internal/Empty.cs b/src/protobuf-net.Grpc/Internal/Empty.cs index 7f3549ee..d405bc4f 100644 --- a/src/protobuf-net.Grpc/Internal/Empty.cs +++ b/src/protobuf-net.Grpc/Internal/Empty.cs @@ -33,9 +33,9 @@ private Empty() { } /// Compares two instances for equality /// public override int GetHashCode() => 42; - bool IEquatable.Equals(Empty? other) => other is object; + bool IEquatable.Equals(Empty? other) => other is not null; internal static readonly Marshaller Marshaller - = new Marshaller((Empty _)=> Array.Empty(), (byte[] _) => Instance); + = new Marshaller((Empty _) => [], (byte[] _) => Instance); } } diff --git a/src/protobuf-net.Grpc/Internal/MarshallerCache.cs b/src/protobuf-net.Grpc/Internal/MarshallerCache.cs index f4d4d79e..c4babdc6 100644 --- a/src/protobuf-net.Grpc/Internal/MarshallerCache.cs +++ b/src/protobuf-net.Grpc/Internal/MarshallerCache.cs @@ -7,18 +7,17 @@ namespace ProtoBuf.Grpc.Internal { - internal sealed class MarshallerCache + internal sealed class MarshallerCache(MarshallerFactory[] factories) { - private readonly MarshallerFactory[] _factories; - public MarshallerCache(MarshallerFactory[] factories) - => _factories = factories ?? throw new ArgumentNullException(nameof(factories)); + private readonly MarshallerFactory[] _factories = factories ?? throw new ArgumentNullException(nameof(factories)); + internal bool CanSerializeType(Type type) { if (_marshallers.TryGetValue(type, out var obj)) return obj != null; return SlowImpl(this, type); static bool SlowImpl(MarshallerCache obj, Type type) - => _createAndAdd.MakeGenericMethod(type).Invoke(obj, Array.Empty()) != null; + => _createAndAdd.MakeGenericMethod(type).Invoke(obj, []) != null; } static readonly MethodInfo _createAndAdd = typeof(MarshallerCache).GetMethod( nameof(CreateAndAdd), BindingFlags.Instance | BindingFlags.NonPublic)!; @@ -26,8 +25,9 @@ static bool SlowImpl(MarshallerCache obj, Type type) private readonly ConcurrentDictionary _marshallers = new ConcurrentDictionary { -#pragma warning disable CS0618 // Empty - [typeof(Empty)] = Empty.Marshaller +#pragma warning disable CS0618 // Empty, BytesValue + [typeof(Empty)] = Empty.Marshaller, + [typeof(BytesValue)] = BytesValue.Marshaller, #pragma warning restore CS0618 }; diff --git a/src/protobuf-net.Grpc/Internal/MetadataContext.cs b/src/protobuf-net.Grpc/Internal/MetadataContext.cs index 9fadda8f..dc89b12f 100644 --- a/src/protobuf-net.Grpc/Internal/MetadataContext.cs +++ b/src/protobuf-net.Grpc/Internal/MetadataContext.cs @@ -73,7 +73,20 @@ internal void SetTrailers(RpcException fault) } } - internal void SetTrailers(T? call, Func getStatus, Func getMetadata) + internal void SetTrailers(AsyncUnaryCall? call) + => SetTrailers(call, static c => c.GetStatus(), static c => c.GetTrailers()); + + internal void SetTrailers(AsyncClientStreamingCall? call) + => SetTrailers(call, static c => c.GetStatus(), static c => c.GetTrailers()); + + internal void SetTrailers(AsyncServerStreamingCall? call) + => SetTrailers(call, static c => c.GetStatus(), static c => c.GetTrailers()); + + internal void SetTrailers(AsyncDuplexStreamingCall? call) + => SetTrailers(call, static c => c.GetStatus(), static c => c.GetTrailers()); + + + private void SetTrailers(T? call, Func getStatus, Func getMetadata) where T : class { if (call is null) return; diff --git a/src/protobuf-net.Grpc/Internal/ProxyEmitter.cs b/src/protobuf-net.Grpc/Internal/ProxyEmitter.cs index a5218864..3ec577ec 100644 --- a/src/protobuf-net.Grpc/Internal/ProxyEmitter.cs +++ b/src/protobuf-net.Grpc/Internal/ProxyEmitter.cs @@ -109,7 +109,7 @@ private static void Ldarg(ILGenerator il, ushort index) static int _typeIndex; private static readonly MethodInfo s_marshallerCacheGenericMethodDef = typeof(MarshallerCache).GetMethod(nameof(MarshallerCache.GetMarshaller), BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)!; - internal static Func CreateFactory(BinderConfiguration binderConfig) + internal static Func CreateFactory(BinderConfiguration binderConfig, Action? log) where TService : class { // front-load reflection discovery @@ -119,10 +119,13 @@ internal static Func CreateFactory(BinderConfig if (binderConfig == BinderConfiguration.Default) // only use ProxyAttribute for default binder { var proxy = (typeof(TService).GetCustomAttribute(typeof(ProxyAttribute)) as ProxyAttribute)?.Type; - if (proxy is object) return CreateViaActivator(proxy); + if (proxy is not null) + { + return CreateViaActivator(proxy); + } } - return EmitFactory(binderConfig); + return EmitFactory(binderConfig, log); } [MethodImpl(MethodImplOptions.NoInlining)] internal static Func CreateViaActivator(Type type) @@ -131,11 +134,11 @@ internal static Func CreateViaActivator(Type ty type, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance, null, - new object[] { channel }, + [channel], null)!; } [MethodImpl(MethodImplOptions.NoInlining)] - internal static Func EmitFactory(BinderConfiguration binderConfig) + internal static Func EmitFactory(BinderConfiguration binderConfig, Action? log) { Type baseType = GrpcClientFactory.ClientBaseType; @@ -172,12 +175,12 @@ internal static Func EmitFactory(BinderConfigur var ops = ContractOperation.FindOperations(binderConfig, typeof(TService), null); int marshallerIndex = 0; - Dictionary marshallers = new Dictionary(); + Dictionary marshallers = []; FieldBuilder Marshaller(Type forType) { if (marshallers.TryGetValue(forType, out var val)) return val.Field; - var instance = s_marshallerCacheGenericMethodDef.MakeGenericMethod(forType).Invoke(binderConfig.MarshallerCache, Array.Empty())!; + var instance = s_marshallerCacheGenericMethodDef.MakeGenericMethod(forType).Invoke(binderConfig.MarshallerCache, [])!; var name = "_m" + marshallerIndex++; var field = type.DefineField(name, typeof(Marshaller<>).MakeGenericType(forType), FieldAttributes.Static | FieldAttributes.Private); // **not** readonly, we need to set it afterwards! marshallers.Add(forType, (field, name, instance)); @@ -229,7 +232,10 @@ FieldBuilder Marshaller(Type forType) il.Emit(OpCodes.Ret); } else + { + log?.Invoke($"Unclear method: {iType.Name}.{iMethod.Name}"); il.ThrowException(typeof(NotSupportedException)); + } continue; } @@ -238,12 +244,13 @@ FieldBuilder Marshaller(Type forType) { if (!binderConfig.Binder.TryFindInheritedService(iType, contractExpandInterfaces, out serviceName)) { + log?.Invoke($"Inherited method is not a service: {iType.Name}.{iMethod.Name}"); il.ThrowException(typeof(NotSupportedException)); continue; } } - Type[] fromTo = new Type[] { op.From, op.To }; + Type[] fromTo = [op.From, op.To]; // private static Method s_{i} var field = type.DefineField("s_op_" + fieldIndex++, typeof(Method<,>).MakeGenericType(fromTo), FieldAttributes.Static | FieldAttributes.Private); @@ -261,8 +268,9 @@ FieldBuilder Marshaller(Type forType) switch (op.Context) { case ContextKind.CallOptions: - // we only support this for signatures that match the exat google pattern, but: + // we only support this for signatures that match the exact google pattern, but: // defer for now + log?.Invoke($"Call options not supported: {iType.Name}.{iMethod.Name}"); il.ThrowException(typeof(NotImplementedException)); break; case ContextKind.NoContext: @@ -274,6 +282,7 @@ FieldBuilder Marshaller(Type forType) if (method == null) { // unexpected, but... + log?.Invoke($"No client helper: {iType.Name}.{iMethod.Name}"); il.ThrowException(typeof(NotSupportedException)); } else @@ -316,6 +325,7 @@ FieldBuilder Marshaller(Type forType) break; case ContextKind.ServerCallContext: // server call? we're writing a client! default: // who knows! + log?.Invoke($"Unexpected context kind: {iType.Name}.{iMethod.Name}"); il.ThrowException(typeof(NotSupportedException)); break; } @@ -334,12 +344,12 @@ FieldBuilder Marshaller(Type forType) { finalType.GetField(name, BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Public)!.SetValue(null, instance); } - finalType.GetMethod(InitMethodName, BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Public)!.Invoke(null, Array.Empty()); + finalType.GetMethod(InitMethodName, BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Public)!.Invoke(null, []); // return the factory var p = Expression.Parameter(typeof(CallInvoker), "channel"); return Expression.Lambda>( - Expression.New(finalType.GetConstructor(new[] { typeof(CallInvoker) })!, p), p).Compile(); + Expression.New(finalType.GetConstructor([typeof(CallInvoker)])!, p), p).Compile(); ConstructorBuilder? WritePassThruCtor(MethodAttributes accessibility) { @@ -364,7 +374,7 @@ internal static readonly FieldInfo s_Empty_InstaneTask= typeof(Empty).GetField(nameof(Empty.InstanceTask))!; #pragma warning restore CS0618 - internal static readonly MethodInfo s_CallContext_FromCancellationToken = typeof(CallContext).GetMethod("op_Implicit", BindingFlags.Public | BindingFlags.Static, null, new[] { typeof(CancellationToken) }, null)!; + internal static readonly MethodInfo s_CallContext_FromCancellationToken = typeof(CallContext).GetMethod("op_Implicit", BindingFlags.Public | BindingFlags.Static, null, [typeof(CancellationToken)], null)!; internal const string FactoryName = "Create"; } diff --git a/src/protobuf-net.Grpc/Internal/Reshape.ByteStream.cs b/src/protobuf-net.Grpc/Internal/Reshape.ByteStream.cs new file mode 100644 index 00000000..36ea33fd --- /dev/null +++ b/src/protobuf-net.Grpc/Internal/Reshape.ByteStream.cs @@ -0,0 +1,264 @@ +using Grpc.Core; +using System; +using System.Buffers; +using System.ComponentModel; +using System.Diagnostics; +using System.Globalization; +using System.IO; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace ProtoBuf.Grpc.Internal; + +partial class Reshape +{ + /// + /// Performs an operation that returns data from the server as a . + /// + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ValueTask ServerByteStreamingValueTaskAsync( + in CallContext context, + CallInvoker invoker, Method method, TRequest request, string? host = null) + where TRequest : class + => new(ServerByteStreamingTaskAsync(in context, invoker, method, request, host)); + + /// + /// Performs an operation that returns data from the server as a . + /// + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Task ServerByteStreamingTaskAsync( + in CallContext context, + CallInvoker invoker, Method method, TRequest request, string? host = null) + where TRequest : class + { + + context.CallOptions.CancellationToken.ThrowIfCancellationRequested(); + return ReadByteValueSequenceAsStream(invoker.AsyncServerStreamingCall(Assert(method), host, context.CallOptions, request), context.Prepare(), context.CancellationToken); + + async static Task ReadByteValueSequenceAsStream(AsyncServerStreamingCall call, MetadataContext? metadata, CancellationToken cancellationToken) + { + const bool DemandTrailer = false; // don't *demand* the trailer indicating total length, but enforce it if we find it (we always send it, currently) + try + { + // wait for headers, even if not available; that means we're in a state to start spoofing the stream + if (metadata is not null) + { + await metadata.SetHeadersAsync(call.ResponseHeadersAsync); + } + else + { + // even if we aren't capturing headers, we want to wait for them to be available, + // + await call.ResponseHeadersAsync.ConfigureAwait(false); + } + var firstRead = call.ResponseStream.MoveNext(CancellationToken.None); + if (firstRead.IsCompleted) + { + // probably an error in the call; fetch eagerly, so we can fail *before* + // providing a stream that needs reading to expose the fault; we'll + // touch the .Result, which is fine - we know it is completed, and this is + // Task, not ValueTask, so it is repeatable; if it throws: server fault; + // if it returns false, empty stream + if (!firstRead.GetAwaiter().GetResult()) + { + // empty stream, which could be valid zero-length, or could be + // a server fault; we'll find out + metadata?.SetTrailers(call); + call.Dispose(); + return Stream.Null; + } + + // if we get here, the first read was success - that just means + // the server is fast; we'll just let the main path await the already-completed + // first read, like normal + } + + // so if we got this far, we think the server is happy - start spinning up infrastructure to be the stream + Pipe pipe = new(); + _ = Task.Run(() => ReadByteValueSequenceToPipeWriter(call, firstRead, pipe.Writer, metadata, DemandTrailer, cancellationToken), CancellationToken.None); + return pipe.Reader.AsStream(leaveOpen: false); + } + catch (RpcException fault) + { + metadata?.SetTrailers(fault); + call.Dispose(); // note not using; only in case of fault! + throw; + } + } + + async static Task ReadByteValueSequenceToPipeWriter(AsyncServerStreamingCall call, Task pendingRead, PipeWriter destination, MetadataContext? metadata, bool demandTrailer, CancellationToken cancellationToken) + { + Exception? fault = null; + try + { + var source = call.ResponseStream; + long actualLength = 0; + bool clientTerminated = false; + while (await pendingRead.ConfigureAwait(false)) // note that the context's cancellation is already baked in + { + var chunk = source.Current; + var result = await destination.WriteAsync(chunk.Memory, cancellationToken).ConfigureAwait(false); + actualLength += chunk.Length; + + if (result.IsCanceled) + { + cancellationToken.ThrowIfCancellationRequested(); + FallbackThrowCanceled(); + } + + if (result.IsCompleted) + { + // reader has shut down; stop copying (we'll tell the server by disposing the call) + clientTerminated = true; + demandTrailer = false; + break; + } + + pendingRead = source.MoveNext(CancellationToken.None); + } + string? lenTrailer; + try + { + lenTrailer = call.GetTrailers().GetString(TrailerStreamLength); + } + catch (InvalidOperationException) when (clientTerminated) + { + // we didn't let the stream get to the end; the trailers simply might not be there + lenTrailer = null; + metadata = null; + } + + if (string.IsNullOrWhiteSpace(lenTrailer)) + { + if (demandTrailer) throw new InvalidOperationException($"Missing trailer: '{TrailerStreamLength}'"); + } + else if (!long.TryParse(lenTrailer, NumberStyles.Integer, CultureInfo.InvariantCulture, out var expectedLength) + || expectedLength != actualLength) + { + throw new InvalidOperationException($"Invalid trailer or length mismatch: '{TrailerStreamLength}'"); + } + metadata?.SetTrailers(call); + } + catch (Exception ex) + { + fault = ex; + if (fault is RpcException rpcFault) + { + metadata?.SetTrailers(rpcFault); + } + throw; + } + finally + { + try + { + // signal that no more data will be written, or at least try! + await destination.CompleteAsync(fault).ConfigureAwait(false); + } + catch (Exception ex) + { + Debug.WriteLine(ex.Message); + } + + try + { + call.Dispose(); + } + catch (Exception ex) + { + Debug.WriteLine(ex.Message); + } + } + } + } + + /// + /// Consumes an asynchronous enumerable sequence and writes it to a server stream-writer + /// + [Obsolete(WarningMessage, false)] + [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] + public static async Task WriteStream(Task source, IAsyncStreamWriter writer, ServerCallContext context, bool writeTrailer) + { + try + { +#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER + await // IDisposable is up-level +#endif + using var stream = await source; + + // read from the stream and write to writer + int size = 512; // start modest and increase + long totalLength = 0; +#if DEBUG + int debugChunk = 0; +#endif + + while (true) + { + byte[] leased = ArrayPool.Shared.Rent(size); + + var maxRead = Math.Min(leased.Length, BytesValue.MaxLength); + var bytes = await stream.ReadAsync(leased, 0, maxRead, context.CancellationToken).ConfigureAwait(false); + if (bytes <= 0) // EOF + { + ArrayPool.Shared.Return(leased); + break; + } + if (bytes == maxRead) + { + // allow more next time + size = Math.Min(size * 2, BytesValue.MaxLength); + } + else + { + // allow less next time, down to whatever we read + size = Math.Max(bytes, 128); + } + + var chunk = new BytesValue(leased, bytes, pooled: true); +#if DEBUG + context.ResponseTrailers.Add($"pbn_chunk{debugChunk}", bytes); +#endif + totalLength += bytes; + await writer.WriteAsync(chunk).ConfigureAwait(false); + } + + if (writeTrailer) + { + context.ResponseTrailers.Add(TrailerStreamLength, totalLength.ToString(CultureInfo.InvariantCulture)); + } + } + catch (Exception ex) + { + Debug.WriteLine(ex.Message); + throw; + } + } + + // more idiomatic labels like content-length are reserved, and are not transmitted/received + internal const string TrailerStreamLength = "stream-length"; + + static void FallbackThrowCanceled() => throw new OperationCanceledException(); + + static Method Assert(IMethod method) + { + var typed = method as Method; + if (typed is null) + { + ThrowMethodFail(typeof(TRequest), typeof(TResponse)); + } + return typed!; + } + +#pragma warning disable IDE0079 // (unnecessary suppression) +#pragma warning disable CA2208 // usage of literal "method" + static void ThrowMethodFail(Type request, Type response) => throw new ArgumentException($"Method was expected to take '{request.Name}' and return '{response.Name}'", "method"); +#pragma warning restore CA2208 +#pragma warning restore IDE0079 +} diff --git a/src/protobuf-net.Grpc/Internal/Reshape.cs b/src/protobuf-net.Grpc/Internal/Reshape.cs index 3f333d63..c1a2b1fa 100644 --- a/src/protobuf-net.Grpc/Internal/Reshape.cs +++ b/src/protobuf-net.Grpc/Internal/Reshape.cs @@ -16,7 +16,7 @@ namespace ProtoBuf.Grpc.Internal /// [Obsolete(WarningMessage, false)] [Browsable(false), EditorBrowsable(EditorBrowsableState.Never)] - public static class Reshape + public static partial class Reshape { internal const string WarningMessage = "This API is intended for use by runtime-generated code; all types and methods can be changed without notice - it is only guaranteed to work with the internally generated code"; @@ -558,7 +558,7 @@ private static async Task UnaryTaskAsyncImpl( metadata?.SetTrailers(fault); throw; } - metadata?.SetTrailers(call, c => c.GetStatus(), c => c.GetTrailers()); + metadata?.SetTrailers(call); return value; } @@ -624,7 +624,7 @@ private static async IAsyncEnumerable ServerStreamingAsyncImpl c.GetStatus(), c => c.GetTrailers()); + metadata?.SetTrailers(call); } } @@ -655,7 +655,7 @@ protected override ValueTask OnBeforeAsync() protected override ValueTask OnAfterAsync() { - _metadata?.SetTrailers(_call, c => c.GetStatus(), c => c.GetTrailers()); + _metadata?.SetTrailers(_call); return default; } @@ -724,7 +724,7 @@ private static async Task ClientStreamingTaskAsyncImpl c.GetStatus(), c => c.GetTrailers()); + metadata?.SetTrailers(call); return result; } catch (RpcException fault) @@ -789,7 +789,7 @@ private static async Task ClientStreamingObservableTaskAsyncImpl c.GetStatus(), c => c.GetTrailers()); + metadata?.SetTrailers(call); return result; } catch (RpcException fault) @@ -884,7 +884,7 @@ private static async IAsyncEnumerable DuplexAsyncImpl c.GetStatus(), c => c.GetTrailers()); + metadata?.SetTrailers(call); } } @@ -918,7 +918,7 @@ protected override ValueTask OnBeforeAsync() protected override async ValueTask OnAfterAsync() { await _sendAll; - _metadata?.SetTrailers(_call, c => c.GetStatus(), c => c.GetTrailers()); + _metadata?.SetTrailers(_call); } protected override void OnFault(RpcException fault) diff --git a/src/protobuf-net.Grpc/Internal/ServerInvokerLookup.cs b/src/protobuf-net.Grpc/Internal/ServerInvokerLookup.cs index 8363c3e1..3f3fd2b8 100644 --- a/src/protobuf-net.Grpc/Internal/ServerInvokerLookup.cs +++ b/src/protobuf-net.Grpc/Internal/ServerInvokerLookup.cs @@ -35,10 +35,10 @@ static Expression ToTaskT(Expression expression) if (type.GetGenericTypeDefinition() == typeof(ValueTask<>)) return Expression.Call(expression, nameof(ValueTask.AsTask), null); } - return Expression.Call(typeof(Task), nameof(Task.FromResult), new Type[] { expression.Type }, expression); + return Expression.Call(typeof(Task), nameof(Task.FromResult), [expression.Type], expression); } - internal static readonly ConstructorInfo s_CallContext_FromServerContext = typeof(CallContext).GetConstructor(new[] { typeof(object), typeof(ServerCallContext) })!; + internal static readonly ConstructorInfo s_CallContext_FromServerContext = typeof(CallContext).GetConstructor([typeof(object), typeof(ServerCallContext)])!; internal static readonly PropertyInfo s_ServerContext_CancellationToken = typeof(ServerCallContext).GetProperty(nameof(ServerCallContext.CancellationToken))!; static Expression ToCallContext(Expression server, Expression context) => Expression.New(s_CallContext_FromServerContext, server, context); @@ -48,22 +48,30 @@ static Expression ToTaskT(Expression expression) static Expression AsAsyncEnumerable(Expression value, Expression context) => Expression.Call(typeof(Reshape), nameof(Reshape.AsAsyncEnumerable), typeArguments: value.Type.GetGenericArguments(), - arguments: new Expression[] { value, Expression.Property(context, nameof(ServerCallContext.CancellationToken)) }); + arguments: [value, Expression.Property(context, nameof(ServerCallContext.CancellationToken))]); static Expression AsObservable(Expression value, Expression context) => Expression.Call(typeof(Reshape), nameof(Reshape.AsObservable), typeArguments: value.Type.GetGenericArguments(), - arguments: new Expression[] { value }); + arguments: [value]); static Expression WriteTo(Expression value, Expression writer, Expression context) => Expression.Call(typeof(Reshape), nameof(Reshape.WriteTo), typeArguments: value.Type.GetGenericArguments(), - arguments: new Expression[] { value, writer, Expression.Property(context, nameof(ServerCallContext.CancellationToken)) }); + arguments: [value, writer, Expression.Property(context, nameof(ServerCallContext.CancellationToken))]); static Expression WriteObservableTo(Expression value, Expression writer, Expression context) => Expression.Call(typeof(Reshape), nameof(Reshape.WriteObservableTo), typeArguments: value.Type.GetGenericArguments(), - arguments: new Expression[] { value, writer }); + arguments: [value, writer]); + + static Expression WriteStream(Expression value, Expression writer, Expression context, bool writeTrailer = true) + => Expression.Call(typeof(Reshape), nameof(Reshape.WriteStream), + typeArguments: null, + arguments: [ToTaskT(value), writer, context, ConstantBoolean(writeTrailer)]); + + private static Expression ConstantBoolean(bool value) => value ? True : False; + private static Expression True = Expression.Constant(true, typeof(bool)), False = Expression.Constant(false, typeof(bool)); internal static bool TryGetValue(MethodType MethodType, ContextKind Context, ResultKind Arg, ResultKind Result, VoidKind Void, out Func? invoker) => _invokers.TryGetValue((MethodType, Context, Arg, Result, Void), out invoker); @@ -198,6 +206,22 @@ internal static bool TryGetValue(MethodType MethodType, ContextKind Context, Res {(MethodType.DuplexStreaming, ContextKind.NoContext, ResultKind.Observable, ResultKind.Observable, VoidKind.None), (method, args) => WriteObservableTo(Expression.Call(args[0], method, AsObservable(args[1], args[3])), args[2], args[3]) }, {(MethodType.DuplexStreaming, ContextKind.CallContext, ResultKind.Observable,ResultKind.Observable, VoidKind.None), (method, args) => WriteObservableTo(Expression.Call(args[0], method, AsObservable(args[1], args[3]), ToCallContext(args[0], args[3])), args[2], args[3]) }, {(MethodType.DuplexStreaming, ContextKind.CancellationToken, ResultKind.Observable, ResultKind.Observable, VoidKind.None), (method, args) => WriteObservableTo(Expression.Call(args[0], method, AsObservable(args[1], args[3]), ToCancellationToken(args[3])), args[2], args[3]) }, + + {(MethodType.ServerStreaming, ContextKind.NoContext, ResultKind.Sync, ResultKind.TaskStream, VoidKind.Request), (method, args) => WriteStream(Expression.Call(args[0], method), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CallContext, ResultKind.Sync, ResultKind.TaskStream, VoidKind.Request), (method, args) => WriteStream(Expression.Call(args[0], method, ToCallContext(args[0], args[3])), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CancellationToken, ResultKind.Sync, ResultKind.TaskStream, VoidKind.Request), (method, args) => WriteStream(Expression.Call(args[0], method, ToCancellationToken(args[3])), args[2], args[3])}, + + {(MethodType.ServerStreaming, ContextKind.NoContext, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None), (method, args) => WriteStream(Expression.Call(args[0], method, args[1]), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CallContext, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None), (method, args) => WriteStream(Expression.Call(args[0], method, args[1], ToCallContext(args[0], args[3])), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CancellationToken, ResultKind.Sync, ResultKind.TaskStream, VoidKind.None), (method, args) => WriteStream(Expression.Call(args[0], method, args[1], ToCancellationToken(args[3])), args[2], args[3])}, + + {(MethodType.ServerStreaming, ContextKind.NoContext, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.Request), (method, args) => WriteStream(Expression.Call(args[0], method), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CallContext, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.Request), (method, args) => WriteStream(Expression.Call(args[0], method, ToCallContext(args[0], args[3])), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CancellationToken, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.Request), (method, args) => WriteStream(Expression.Call(args[0], method, ToCancellationToken(args[3])), args[2], args[3])}, + + {(MethodType.ServerStreaming, ContextKind.NoContext, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None), (method, args) => WriteStream(Expression.Call(args[0], method, args[1]), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CallContext, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None), (method, args) => WriteStream(Expression.Call(args[0], method, args[1], ToCallContext(args[0], args[3])), args[2], args[3])}, + {(MethodType.ServerStreaming, ContextKind.CancellationToken, ResultKind.Sync, ResultKind.ValueTaskStream, VoidKind.None), (method, args) => WriteStream(Expression.Call(args[0], method, args[1], ToCancellationToken(args[3])), args[2], args[3])}, }; } } diff --git a/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj b/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj index 56f0867d..1f128e24 100644 --- a/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj +++ b/src/protobuf-net.Grpc/protobuf-net.Grpc.csproj @@ -7,11 +7,16 @@ + + + + + diff --git a/tests/protobuf-net.Grpc.Reflection.Test/FileDescriptorSetFactoryTests.cs b/tests/protobuf-net.Grpc.Reflection.Test/FileDescriptorSetFactoryTests.cs index 5859ee22..da8250ac 100644 --- a/tests/protobuf-net.Grpc.Reflection.Test/FileDescriptorSetFactoryTests.cs +++ b/tests/protobuf-net.Grpc.Reflection.Test/FileDescriptorSetFactoryTests.cs @@ -18,9 +18,9 @@ public void SimpleService() var fileDescriptorSet = FileDescriptorSetFactory.Create(new[] { typeof(GreeterService) }); Assert.Empty(fileDescriptorSet.GetErrors()); - Assert.Equal(new[] { "GreeterService" }, + Assert.Equal(["GreeterService"], fileDescriptorSet.Files.SelectMany(x => x.Services).Select(x => x.Name).ToArray()); - Assert.Equal(new[] { "HelloReply", "HelloRequest" }, + Assert.Equal(["HelloReply", "HelloRequest"], fileDescriptorSet.Files.SelectMany(x => x.MessageTypes).Select(x => x.Name).ToArray()); } } diff --git a/tests/protobuf-net.Grpc.Reflection.Test/ReflectionServiceTests.cs b/tests/protobuf-net.Grpc.Reflection.Test/ReflectionServiceTests.cs index 9ee08fb7..86a5ed71 100644 --- a/tests/protobuf-net.Grpc.Reflection.Test/ReflectionServiceTests.cs +++ b/tests/protobuf-net.Grpc.Reflection.Test/ReflectionServiceTests.cs @@ -42,7 +42,7 @@ public async Task ShouldIncludeDependenciesInCorrectOrder(Type service, string s { // Use reflection. var addImportMethod = AddImportMethod.Value; - addImportMethod.Invoke(fileDescriptor, new object?[] {dependency, true, default}); + addImportMethod.Invoke(fileDescriptor, [dependency, true, default]); } fileDescriptorSet.Files.Add(fileDescriptor); @@ -90,8 +90,7 @@ async IAsyncEnumerable GetRequest() ".ReflectionTest.BclMessage", } }, - new object[] - { + [ typeof(ReflectionTest.Service.Nested), ".ReflectionTest.Service.Nested", new[] @@ -105,7 +104,7 @@ async IAsyncEnumerable GetRequest() ".ReflectionTest.Service.Three", ".ReflectionTest.Service.Two", } - }, + ], }; } } diff --git a/tests/protobuf-net.Grpc.Test.Integration/ClientProxyTests.cs b/tests/protobuf-net.Grpc.Test.Integration/ClientProxyTests.cs index 8aae3869..0838dbfa 100644 --- a/tests/protobuf-net.Grpc.Test.Integration/ClientProxyTests.cs +++ b/tests/protobuf-net.Grpc.Test.Integration/ClientProxyTests.cs @@ -145,6 +145,6 @@ await Assert.ThrowsAsync(async () => }); } #endif - } + } } \ No newline at end of file diff --git a/tests/protobuf-net.Grpc.Test.Integration/StreamTests.cs b/tests/protobuf-net.Grpc.Test.Integration/StreamTests.cs index 0cdaaec9..9316dccc 100644 --- a/tests/protobuf-net.Grpc.Test.Integration/StreamTests.cs +++ b/tests/protobuf-net.Grpc.Test.Integration/StreamTests.cs @@ -8,10 +8,13 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO; +using System.IO.Pipelines; using System.Linq; using System.Reactive.Linq; using System.Reactive.Subjects; using System.Runtime.CompilerServices; +using System.Text; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -74,6 +77,12 @@ public interface IStreamAPI ValueTask TakeFive(CancellationToken cancellationToken = default); } + [Service] + public interface IStreamRewrite + { + ValueTask MagicStream(Foo foo, CallContext ctx = default); + } + public enum Scenario { RunToCompletion, @@ -88,7 +97,7 @@ public enum Scenario FaultSuccessGoodProducer, // observes cancellation } - class StreamServer : IStreamAPI + class StreamServer : IStreamAPI, IStreamRewrite { readonly StreamTestsFixture _fixture; internal StreamServer(StreamTestsFixture fixture) @@ -370,6 +379,67 @@ private async IAsyncEnumerable Producer(CallContext ctx) await Task.Delay(10, ctx.CancellationToken); } } + + async ValueTask IStreamRewrite.MagicStream(Foo foo, CallContext ctx) + { + var scenario = GetScenario(ctx); + if (scenario is Scenario.FaultBeforeHeaders) + { + throw new RpcException(new Status(StatusCode.PermissionDenied, nameof(Scenario.FaultBeforeHeaders))); + } + + var headers = new Metadata() + { + {"resp-header", scenario.ToString() }, + }; + + await ctx.ServerCallContext!.WriteResponseHeadersAsync(headers); + + switch (scenario) + { + case Scenario.YieldNothing: + return Stream.Null; + case Scenario.RunToCompletion: + return new MemoryStream(Encoding.UTF8.GetBytes("hello, world")); + case Scenario.FaultBeforeTrailers: + case Scenario.FaultBeforeYield: + case Scenario.FaultAfterYield: + var pipe = new Pipe(); + _ = Task.Run(async () => + { + Exception? fault = null; + await Task.Delay(50); + try + { + if (scenario == Scenario.FaultBeforeYield) + { + throw new RpcException(new Status(StatusCode.PermissionDenied, nameof(Scenario.FaultBeforeYield))); + } + await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes("hello, "), CancellationToken.None); + if (scenario == Scenario.FaultAfterYield) + { + throw new RpcException(new Status(StatusCode.PermissionDenied, nameof(Scenario.FaultAfterYield))); + } + await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes("world"), CancellationToken.None); + if (scenario == Scenario.FaultBeforeTrailers) + { + throw new RpcException(new Status(StatusCode.PermissionDenied, nameof(Scenario.FaultBeforeTrailers))); + } + } + catch (Exception ex) + { + fault = ex; + } + finally + { + await pipe.Writer.CompleteAsync(fault); + } + }, CancellationToken.None); + return pipe.Reader.AsStream(); + default: + throw new ArgumentOutOfRangeException(nameof(scenario)); + } + } } [ProtoContract] @@ -383,10 +453,10 @@ public class Foo public class NativeStreamTests : StreamTests { public NativeStreamTests(StreamTestsFixture fixture, ITestOutputHelper log) : base(fixture, log) { } - protected override IAsyncDisposable CreateClient(out IStreamAPI client) + protected override IAsyncDisposable CreateClient(out TService client) { var channel = new Channel("localhost", Port, ChannelCredentials.Insecure); - client = channel.CreateGrpcService(); + client = channel.CreateGrpcService(); return new DisposableChannel(channel); } sealed class DisposableChannel : IAsyncDisposable @@ -403,10 +473,10 @@ public class ManagedStreamTests : StreamTests { public override bool IsManagedClient => true; public ManagedStreamTests(StreamTestsFixture fixture, ITestOutputHelper log) : base(fixture, log) { } - protected override IAsyncDisposable CreateClient(out IStreamAPI client) + protected override IAsyncDisposable CreateClient(out TService client) { var http = global::Grpc.Net.Client.GrpcChannel.ForAddress($"http://localhost:{Port}"); - client = http.CreateGrpcService(); + client = http.CreateGrpcService(); return new DisposableChannel(http); } sealed class DisposableChannel : IAsyncDisposable @@ -474,7 +544,9 @@ public void Dispose() GC.SuppressFinalize(this); } - protected abstract IAsyncDisposable CreateClient(out IStreamAPI client); + protected IAsyncDisposable CreateClient(out IStreamAPI client) => CreateClient(out client); + + protected abstract IAsyncDisposable CreateClient(out TService client) where TService : class; const int DEFAULT_SIZE = 20; @@ -746,6 +818,93 @@ void CheckForCancellation(string when) } } + [Theory] + [InlineData(Scenario.RunToCompletion)] + [InlineData(Scenario.YieldNothing)] + [InlineData(Scenario.FaultBeforeHeaders)] + [InlineData(Scenario.FaultBeforeYield)] + [InlineData(Scenario.FaultAfterYield)] + [InlineData(Scenario.FaultBeforeTrailers)] + + [InlineData(Scenario.RunToCompletion, CallContextFlags.CaptureMetadata)] + [InlineData(Scenario.YieldNothing, CallContextFlags.CaptureMetadata)] + [InlineData(Scenario.FaultBeforeHeaders, CallContextFlags.CaptureMetadata)] + [InlineData(Scenario.FaultBeforeYield, CallContextFlags.CaptureMetadata)] + [InlineData(Scenario.FaultAfterYield, CallContextFlags.CaptureMetadata)] + [InlineData(Scenario.FaultBeforeTrailers, CallContextFlags.CaptureMetadata)] + public async Task StreamRewriteBasicTest(Scenario scenario, CallContextFlags flags = CallContextFlags.None) + { + // note that depending on timing, FaultBeforeYield may be exposed as *either* a failed Stream + // fetch, *or* an unreadable stream + + await using var svc = CreateClient(out var client); + bool withMetadata = (flags & CallContextFlags.CaptureMetadata) != 0; + var ctx = new CallContext(new CallOptions(headers: new Metadata { { nameof(Scenario), scenario.ToString() } }), flags); + + try + { + using var stream = await client.MagicStream(new Foo { Bar = 1 }, ctx); + if (withMetadata) + { + WriteMetadata("header", await ctx.ResponseHeadersAsync()); + } + + using var sr = new StreamReader(stream); + + switch (scenario) + { + case Scenario.RunToCompletion: + string s = await sr.ReadToEndAsync(); + Assert.Equal("hello, world", s); + break; + case Scenario.YieldNothing: + s = await sr.ReadToEndAsync(); + Assert.Equal("", s); + break; + default: + var ex = await Assert.ThrowsAsync(sr.ReadToEndAsync); + Assert.Equal(StatusCode.PermissionDenied, ex.StatusCode); + Assert.Equal(scenario.ToString(), ex.Status.Detail); + break; + } + } + catch (RpcException ex) when (scenario is Scenario.FaultBeforeHeaders or Scenario.FaultBeforeYield) + { + if (withMetadata) + { + WriteMetadata("header", await ctx.ResponseHeadersAsync()); + } + Assert.Equal(StatusCode.PermissionDenied, ex.StatusCode); + Assert.Equal(scenario.ToString(), ex.Status.Detail); + } + if (withMetadata) + { + WriteMetadata("trailer", ctx.ResponseTrailers()); + } + } + + private void WriteMetadata(string label, Metadata? value) + { + if (value is null) + { + _fixture.Output?.WriteLine($"(no {label} metadata)"); + } + else + { + foreach (var pair in value) + { + if (pair.IsBinary) + { + _fixture.Output?.WriteLine($"{label} {pair.Key}={BitConverter.ToString(pair.ValueBytes)}"); + } + else + { + _fixture.Output?.WriteLine($"{label} {pair.Key}='{pair.Value}'"); + } + } + } + } + private static IObservable ForObservableImpl(StreamTestsFixture fixture, int count, int from, int millisecondsDelay) { void Log(string message) diff --git a/tests/protobuf-net.Grpc.Test/BytesValueMarshallerTests.cs b/tests/protobuf-net.Grpc.Test/BytesValueMarshallerTests.cs new file mode 100644 index 00000000..9e911817 --- /dev/null +++ b/tests/protobuf-net.Grpc.Test/BytesValueMarshallerTests.cs @@ -0,0 +1,129 @@ +using Grpc.Core; +using ProtoBuf.Grpc.Internal; +using System; +using System.Buffers; +using System.IO; +using Xunit; + +namespace protobuf_net.Grpc.Test; + +#pragma warning disable CS0618 // all marked obsolete! + +public class BytesValueMarshallerTests +{ + [Fact] + public void ProveMaxLength() + { + Assert.Equal(0b1111111_1111111_1111111, BytesValue.MaxLength); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(8)] + [InlineData(9)] + [InlineData(16)] + [InlineData(17)] + [InlineData(24)] + [InlineData(25)] + [InlineData(32)] + [InlineData(64)] + // "varint" is a 7-bit scheme; easiest way to see + // ranges is via 0b notation with 7-bit groups + [InlineData(0b0000000_0000000_1111111)] + [InlineData(0b0000000_0000001_0000000)] + [InlineData(0b0000000_0000001_1111111)] + [InlineData(0b0000000_1111111_1111111)] + [InlineData(0b0000001_0000000_0000000)] + [InlineData(0b1111011_0000000_0000000)] + [InlineData(0b1111011_0000000_1010101)] + [InlineData(0b1111011_1110101_1010101)] + [InlineData(0b1111111_1111111_1111111)] + + public void TestFastParseAndFormat(int length) + { + var source = new byte[length]; + new Random().NextBytes(source); + var ser = new TestSerializationContext(); + BytesValue.Marshaller.ContextualSerializer(new BytesValue(source, source.Length, false), ser); + byte[] chunk = ser.ToArray(); + +#if DEBUG + var missCount = BytesValue.FastPassMiss; +#endif + + // check via our custom deserializer + var result = BytesValue.Marshaller.ContextualDeserializer(new TestDeserializationContext(chunk)); + Assert.NotNull(result); + Assert.True(result.Span.SequenceEqual(source)); + Assert.True(result.IsPooled); + Assert.False(result.IsRecycled); + result.Recycle(); + Assert.False(result.IsPooled); + Assert.True(result.IsRecycled); + Assert.True(result.IsEmpty); + +#if DEBUG + Assert.Equal(missCount, BytesValue.FastPassMiss); // expect no new misses +#endif + + // and double-check via protobuf-net directly + result = ProtoBuf.Serializer.Deserialize(new MemoryStream(chunk)); + Assert.NotNull(result); + Assert.True(result.Span.SequenceEqual(source)); + Assert.False(result.IsPooled); + Assert.False(result.IsRecycled); + result.Recycle(); + Assert.False(result.IsPooled); + Assert.True(result.IsRecycled); + Assert.True(result.IsEmpty); + + } + + class TestSerializationContext : SerializationContext + { + public byte[] ToArray() => _payload; + private byte[] _payload = []; + private TestBufferWriter? _writer; + public override IBufferWriter GetBufferWriter() => _writer ?? new(_payload); + + public override void SetPayloadLength(int payloadLength) + { + Array.Resize(ref _payload, payloadLength); + _writer = null; + } + + public override void Complete() { } + public override void Complete(byte[] payload) => _payload = payload; + } + + class TestDeserializationContext(byte[] chunk) : DeserializationContext + { + public override int PayloadLength => chunk.Length; + public override byte[] PayloadAsNewBuffer() + { + var arr = new byte[chunk.Length]; + Buffer.BlockCopy(chunk, 0, arr, 0, chunk.Length); + return arr; + } + public override ReadOnlySequence PayloadAsReadOnlySequence() + => new(chunk); + } + + class TestBufferWriter(byte[] payload) : IBufferWriter + { + private byte[] _bytes = payload; + private int _committed = 0; + + public void Advance(int count) + => _committed += count; + + public byte[] AsArray() => _bytes; + + public Memory GetMemory(int sizeHint = 0) + => new(_bytes, _committed, _bytes.Length - _committed); + + public Span GetSpan(int sizeHint = 0) + => new(_bytes, _committed, _bytes.Length - _committed); + } +} diff --git a/tests/protobuf-net.Grpc.Test/ContractOperationTests.cs b/tests/protobuf-net.Grpc.Test/ContractOperationTests.cs index 27733e1c..daddd732 100644 --- a/tests/protobuf-net.Grpc.Test/ContractOperationTests.cs +++ b/tests/protobuf-net.Grpc.Test/ContractOperationTests.cs @@ -5,9 +5,11 @@ using System; using System.Collections.Generic; using System.Collections.ObjectModel; +using System.IO; using System.Linq; using System.Reflection; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; using Xunit.Abstractions; @@ -52,13 +54,13 @@ public void SublclassInterfaces() [Fact] public void GeneralPurposeSignatureCount() { - Assert.Equal(78, ContractOperation.GeneralPurposeSignatureCount()); + Assert.Equal(90, ContractOperation.GeneralPurposeSignatureCount()); } [Fact] public void ServerSignatureCount() { - Assert.Equal(78, ServerInvokerLookup.GeneralPurposeSignatureCount()); + Assert.Equal(90, ServerInvokerLookup.GeneralPurposeSignatureCount()); } [Fact] @@ -255,6 +257,20 @@ public interface IOtherMiddle : IInner [InlineData(nameof(IAllOptions.Shared_ValueTaskClientStreaming_Context_ValVoid_Observable), typeof(HelloRequest), typeof(Empty), MethodType.ClientStreaming, (int)ContextKind.CallContext, (int)ResultKind.ValueTask, (int)VoidKind.Response, (int)ResultKind.Observable)] [InlineData(nameof(IAllOptions.Shared_ValueTaskClientStreaming_NoContext_ValVoid_Observable), typeof(HelloRequest), typeof(Empty), MethodType.ClientStreaming, (int)ContextKind.NoContext, (int)ResultKind.ValueTask, (int)VoidKind.Response, (int)ResultKind.Observable)] [InlineData(nameof(IAllOptions.Shared_ValueTaskClientStreaming_CancellationToken_ValVoid_Observable), typeof(HelloRequest), typeof(Empty), MethodType.ClientStreaming, (int)ContextKind.CancellationToken, (int)ResultKind.ValueTask, (int)VoidKind.Response, (int)ResultKind.Observable)] + + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_T_Stream_NoContext), typeof(Empty), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.NoContext, (int)ResultKind.TaskStream, (int)VoidKind.Request, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_T_Stream_CancellationToken), typeof(Empty), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CancellationToken, (int)ResultKind.TaskStream, (int)VoidKind.Request, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_T_Stream_Context), typeof(Empty), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CallContext, (int)ResultKind.TaskStream, (int)VoidKind.Request, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_T_Stream_Arg_NoContext), typeof(HelloRequest), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.NoContext, (int)ResultKind.TaskStream, (int)VoidKind.None, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_T_Stream_Arg_CancellationToken), typeof(HelloRequest), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CancellationToken, (int)ResultKind.TaskStream, (int)VoidKind.None, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_T_Stream_Arg_Context), typeof(HelloRequest), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CallContext, (int)ResultKind.TaskStream, (int)VoidKind.None, (int)ResultKind.Sync)] + + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_VT_Stream_NoContext), typeof(Empty), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.NoContext, (int)ResultKind.ValueTaskStream, (int)VoidKind.Request, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_VT_Stream_CancellationToken), typeof(Empty), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CancellationToken, (int)ResultKind.ValueTaskStream, (int)VoidKind.Request, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_VT_Stream_Context), typeof(Empty), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CallContext, (int)ResultKind.ValueTaskStream, (int)VoidKind.Request, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_VT_Stream_Arg_NoContext), typeof(HelloRequest), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.NoContext, (int)ResultKind.ValueTaskStream, (int)VoidKind.None, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_VT_Stream_Arg_CancellationToken), typeof(HelloRequest), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CancellationToken, (int)ResultKind.ValueTaskStream, (int)VoidKind.None, (int)ResultKind.Sync)] + [InlineData(nameof(IAllOptions.Shared_ServerStreaming_VT_Stream_Arg_Context), typeof(HelloRequest), typeof(BytesValue), MethodType.ServerStreaming, (int)ContextKind.CallContext, (int)ResultKind.ValueTaskStream, (int)VoidKind.None, (int)ResultKind.Sync)] public void CheckMethodIdentification(string name, Type from, Type to, MethodType methodType, int context, int result, int @void, int arg) { var method = typeof(IAllOptions).GetMethod(name); @@ -276,6 +292,116 @@ public void CheckMethodIdentification(string name, Type from, Type to, MethodTyp Assert.Equal((VoidKind)@void, operation.Void); } + [Fact] + public void BindServer() + { + var expected = typeof(IAllOptions).GetMethods().Select(m => m.Name).Where(s => !s.StartsWith("Client_") && !s.StartsWith("Server_")).ToArray(); + Array.Sort(expected); + var server = new TestBinder(_output); + var obj = new MyServer(); + int count = server.Bind(this, typeof(MyServer), null, obj); + _output.WriteLine($"Bound: {count} methods"); + Assert.Equal(expected, server.Collect()); + } + + internal sealed class TestBinder(ITestOutputHelper log) : ServerBinder + { + private readonly List _methods = []; + protected override bool TryBind(ServiceBindContext bindContext, Method method, MethodStub stub) + { + try + { + switch (method.Type) + { + case MethodType.Unary: + stub.CreateDelegate>(); + break; + case MethodType.ClientStreaming: + stub.CreateDelegate>(); + break; + case MethodType.ServerStreaming: + stub.CreateDelegate>(); + break; + case MethodType.DuplexStreaming: + stub.CreateDelegate>(); + break; + default: + return false; + } + } + catch (Exception ex) + { + log.WriteLine($"Failed to bind {stub.Method.Name}: {ex.Message}"); + return false; + } + _methods.Add(stub.Method.Name); + return true; + } + + public string[] Collect() + { + _methods.Sort(); + var arr = _methods.ToArray(); + _methods.Clear(); // reset + return arr; + } + } + + [Fact] + public void EmitClientProxy() + { + int errorCount = 0; + HashSet permitted = [ + "Call options not supported: IAllOptions.Client_BlockingUnary", + "Call options not supported: IAllOptions.Client_AsyncUnary", + "Call options not supported: IAllOptions.Client_ClientStreaming", + "Call options not supported: IAllOptions.Client_Duplex", + "Call options not supported: IAllOptions.Client_ServerStreaming", + ]; + var factory = ProxyEmitter.CreateFactory(BinderConfiguration.Default, s => + { + if (!s.Contains("IAllOptions.Server_") && !permitted.Contains(s)) + { + errorCount++; + } + _output.WriteLine(s); + }); + + Assert.NotNull(factory(NullCallInvoker.Instance)); + Assert.Equal(0, errorCount); + } + + sealed class NullCallInvoker : CallInvoker + { + private NullCallInvoker() { } + public static CallInvoker Instance { get; } = new NullCallInvoker(); + + public override AsyncClientStreamingCall AsyncClientStreamingCall(Method method, string? host, CallOptions options) + { + throw new NotSupportedException(); + } + + public override AsyncDuplexStreamingCall AsyncDuplexStreamingCall(Method method, string? host, CallOptions options) + { + throw new NotSupportedException(); + } + + public override AsyncServerStreamingCall AsyncServerStreamingCall(Method method, string? host, CallOptions options, TRequest request) + { + throw new NotSupportedException(); + } + + public override AsyncUnaryCall AsyncUnaryCall(Method method, string? host, CallOptions options, TRequest request) + { + throw new NotSupportedException(); + } + + public override TResponse BlockingUnaryCall(Method method, string? host, CallOptions options, TRequest request) + { + throw new NotSupportedException(); + } + } + [Fact] public void WriteAllMethodSignatures() { @@ -345,4 +471,502 @@ interface ID : IE, IF { } interface IE { } interface IF { } } + + class MyServer : IAllOptions + { + public AsyncUnaryCall Client_AsyncUnary(HelloRequest request, CallOptions options) + { + throw new NotSupportedException(); + } + + public HelloReply Client_BlockingUnary(HelloRequest request, CallOptions options) + { + throw new NotSupportedException(); + } + + public AsyncClientStreamingCall Client_ClientStreaming(CallOptions options) + { + throw new NotSupportedException(); + } + + public AsyncDuplexStreamingCall Client_Duplex(CallOptions options) + { + throw new NotSupportedException(); + } + + public AsyncServerStreamingCall Client_ServerStreaming(HelloRequest request, CallOptions options) + { + throw new NotSupportedException(); + } + + public Task Server_ClientStreaming(IAsyncStreamReader request, ServerCallContext context) + { + throw new NotSupportedException(); + } + + public Task Server_Duplex(IAsyncStreamReader request, IServerStreamWriter response, ServerCallContext context) + { + throw new NotSupportedException(); + } + + public Task Server_ServerStreaming(HelloRequest request, IServerStreamWriter response, ServerCallContext context) + { + throw new NotSupportedException(); + } + + public Task Server_Unary(HelloRequest request, ServerCallContext context) + { + throw new NotSupportedException(); + } + + public HelloReply Shared_BlockingUnary_CancellationToken(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public void Shared_BlockingUnary_CancellationToken_ValVoid(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public HelloReply Shared_BlockingUnary_CancellationToken_VoidVal(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public void Shared_BlockingUnary_CancellationToken_VoidVoid(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public HelloReply Shared_BlockingUnary_Context(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public void Shared_BlockingUnary_Context_ValVoid(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public HelloReply Shared_BlockingUnary_Context_VoidVal(CallContext context) + { + throw new NotSupportedException(); + } + + public void Shared_BlockingUnary_Context_VoidVoid(CallContext context) + { + throw new NotSupportedException(); + } + + public HelloReply Shared_BlockingUnary_NoContext(HelloRequest request) + { + throw new NotSupportedException(); + } + + public void Shared_BlockingUnary_NoContext_ValVoid(HelloRequest request) + { + throw new NotSupportedException(); + } + + public HelloReply Shared_BlockingUnary_NoContext_VoidVal() + { + throw new NotSupportedException(); + } + + public void Shared_BlockingUnary_NoContext_VoidVoid() + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_Duplex_CancellationToken(IAsyncEnumerable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public IObservable Shared_Duplex_CancellationToken_Observable(IObservable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_Duplex_Context(IAsyncEnumerable request, CallContext context) + { + throw new NotSupportedException(); + } + + public IObservable Shared_Duplex_Context_Observable(IObservable request, CallContext context) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_Duplex_NoContext(IAsyncEnumerable request) + { + throw new NotSupportedException(); + } + + public IObservable Shared_Duplex_NoContext_Observable(IObservable request) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_ServerStreaming_CancellationToken(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public IObservable Shared_ServerStreaming_CancellationToken_Observable(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_ServerStreaming_CancellationToken_VoidVal(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public IObservable Shared_ServerStreaming_CancellationToken_VoidVal_Observable(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_ServerStreaming_Context(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public IObservable Shared_ServerStreaming_Context_Observable(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_ServerStreaming_Context_VoidVal(CallContext context) + { + throw new NotSupportedException(); + } + + public IObservable Shared_ServerStreaming_Context_VoidVal_Observable(CallContext context) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_ServerStreaming_NoContext(HelloRequest request) + { + throw new NotSupportedException(); + } + + public IObservable Shared_ServerStreaming_NoContext_Observable(HelloRequest request) + { + throw new NotSupportedException(); + } + + public IAsyncEnumerable Shared_ServerStreaming_NoContext_VoidVal() + { + throw new NotSupportedException(); + } + + public IObservable Shared_ServerStreaming_NoContext_VoidVal_Observable() + { + throw new NotSupportedException(); + } + + public Task Shared_ServerStreaming_T_Stream_Arg_CancellationToken(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_ServerStreaming_T_Stream_Arg_Context(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_ServerStreaming_T_Stream_Arg_NoContext(HelloRequest request) + { + throw new NotSupportedException(); + } + + public Task Shared_ServerStreaming_T_Stream_CancellationToken(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_ServerStreaming_T_Stream_Context(CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_ServerStreaming_T_Stream_NoContext() + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ServerStreaming_VT_Stream_Arg_CancellationToken(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ServerStreaming_VT_Stream_Arg_Context(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ServerStreaming_VT_Stream_Arg_NoContext(HelloRequest request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ServerStreaming_VT_Stream_CancellationToken(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ServerStreaming_VT_Stream_Context(CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ServerStreaming_VT_Stream_NoContext() + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_CancellationToken(IAsyncEnumerable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_CancellationToken_Observable(IObservable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_CancellationToken_ValVoid(IAsyncEnumerable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_CancellationToken_ValVoid_Observable(IObservable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_Context(IAsyncEnumerable request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_Context_Observable(IObservable request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_Context_ValVoid(IAsyncEnumerable request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_Context_ValVoid_Observable(IObservable request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_NoContext(IAsyncEnumerable request) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_NoContext_Observable(IObservable request) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_NoContext_ValVoid(IAsyncEnumerable request) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskClientStreaming_NoContext_ValVoid_Observable(IObservable request) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_CancellationToken(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_CancellationToken_ValVoid(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_CancellationToken_VoidVal(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_CancellationToken_VoidVoid(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_Context(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_Context_ValVoid(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_Context_VoidVal(CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_Context_VoidVoid(CallContext context) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_NoContext(HelloRequest request) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_NoContext_ValVoid(HelloRequest request) + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_NoContext_VoidVal() + { + throw new NotSupportedException(); + } + + public Task Shared_TaskUnary_NoContext_VoidVoid() + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_CancellationToken(IAsyncEnumerable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_CancellationToken_Observable(IObservable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_CancellationToken_ValVoid(IAsyncEnumerable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_CancellationToken_ValVoid_Observable(IObservable request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_Context(IAsyncEnumerable request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_Context_Observable(IObservable request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_Context_ValVoid(IAsyncEnumerable request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_Context_ValVoid_Observable(IObservable request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_NoContext(IAsyncEnumerable request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_NoContext_Observable(IObservable request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_NoContext_ValVoid(IAsyncEnumerable request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskClientStreaming_NoContext_ValVoid_Observable(IObservable request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_CancellationToken(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_CancellationToken_ValVoid(HelloRequest request, CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_CancellationToken_VoidVal(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_CancellationToken_VoidVoid(CancellationToken cancellationToken) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_Context(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_Context_ValVoid(HelloRequest request, CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_Context_VoidVal(CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_Context_VoidVoid(CallContext context) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_NoContext(HelloRequest request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_NoContext_ValVoid(HelloRequest request) + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_NoContext_VoidVal() + { + throw new NotSupportedException(); + } + + public ValueTask Shared_ValueTaskUnary_NoContext_VoidVoid() + { + throw new NotSupportedException(); + } + } } diff --git a/tests/protobuf-net.Grpc.Test/IAllOptions.cs b/tests/protobuf-net.Grpc.Test/IAllOptions.cs index 6308e423..793bb272 100644 --- a/tests/protobuf-net.Grpc.Test/IAllOptions.cs +++ b/tests/protobuf-net.Grpc.Test/IAllOptions.cs @@ -3,6 +3,7 @@ using ProtoBuf.Grpc; using System; using System.Collections.Generic; +using System.IO; using System.ServiceModel; using System.Threading; using System.Threading.Tasks; @@ -24,7 +25,7 @@ public class HelloReply [ServiceContract] - interface IAllOptions + public interface IAllOptions { // google client patterns HelloReply Client_BlockingUnary(HelloRequest request, CallOptions options); @@ -145,5 +146,21 @@ interface IAllOptions IObservable Shared_Duplex_NoContext_Observable(IObservable request); IObservable Shared_Duplex_Context_Observable(IObservable request, CallContext context); IObservable Shared_Duplex_CancellationToken_Observable(IObservable request, CancellationToken cancellationToken); + + // server-streaming via Stream + Task Shared_ServerStreaming_T_Stream_NoContext(); + Task Shared_ServerStreaming_T_Stream_Context(CallContext context); + Task Shared_ServerStreaming_T_Stream_CancellationToken(CancellationToken cancellationToken); + Task Shared_ServerStreaming_T_Stream_Arg_NoContext(HelloRequest request); + Task Shared_ServerStreaming_T_Stream_Arg_Context(HelloRequest request, CallContext context); + Task Shared_ServerStreaming_T_Stream_Arg_CancellationToken(HelloRequest request, CancellationToken cancellationToken); + + ValueTask Shared_ServerStreaming_VT_Stream_NoContext(); + ValueTask Shared_ServerStreaming_VT_Stream_Context(CallContext context); + ValueTask Shared_ServerStreaming_VT_Stream_CancellationToken(CancellationToken cancellationToken); + ValueTask Shared_ServerStreaming_VT_Stream_Arg_NoContext(HelloRequest request); + ValueTask Shared_ServerStreaming_VT_Stream_Arg_Context(HelloRequest request, CallContext context); + ValueTask Shared_ServerStreaming_VT_Stream_Arg_CancellationToken(HelloRequest request, CancellationToken cancellationToken); + } } diff --git a/tests/protobuf-net.Grpc.Test/TestBindings.cs b/tests/protobuf-net.Grpc.Test/TestBindings.cs index 67ac4658..0e5e6c6a 100644 --- a/tests/protobuf-net.Grpc.Test/TestBindings.cs +++ b/tests/protobuf-net.Grpc.Test/TestBindings.cs @@ -9,36 +9,33 @@ namespace protobuf_net.Grpc.Test { class TestServerBinder : ServerBinder // just tracks what methods are observed { - public HashSet Methods { get; } = new HashSet(); - public List Warnings { get; } = new List(); - public List Errors { get; } = new List(); + public HashSet Methods { get; } = []; + public List Warnings { get; } = []; + public List Errors { get; } = []; protected override bool TryBind(ServiceBindContext bindContext, Method method, MethodStub stub) { Methods.Add(method.FullName); return true; } protected internal override void OnWarn(string message, object?[]? args = null) - => Warnings.Add(string.Format(message, args ?? Array.Empty())); + => Warnings.Add(string.Format(message, args ?? [])); protected internal override void OnError(string message, object?[]? args = null) - => Errors.Add(string.Format(message, args ?? Array.Empty())); + => Errors.Add(string.Format(message, args ?? [])); } - class TestChannel : ChannelBase + class TestChannel(string target) : ChannelBase(target) { - public TestChannel(string target) : base(target) { } public override CallInvoker CreateCallInvoker() => new TestInvoker(this); - public HashSet Calls { get; } = new(); + public HashSet Calls { get; } = []; private void Call(IMethod method) => Calls.Add(Target + ":" + method.FullName); - class TestInvoker : CallInvoker + class TestInvoker(TestChannel channel) : CallInvoker { - public TestChannel Channel { get; } - public TestInvoker(TestChannel channel) - => Channel = channel; + public TestChannel Channel { get; } = channel; public override TResponse BlockingUnaryCall(Method method, string? host, CallOptions options, TRequest request) { diff --git a/tests/protobuf-net.Grpc.Test/protobuf-net.Grpc.Test.csproj b/tests/protobuf-net.Grpc.Test/protobuf-net.Grpc.Test.csproj index e8f5d437..b178cf04 100644 --- a/tests/protobuf-net.Grpc.Test/protobuf-net.Grpc.Test.csproj +++ b/tests/protobuf-net.Grpc.Test/protobuf-net.Grpc.Test.csproj @@ -10,7 +10,7 @@ $(DefineConstants);CLIENT_FACTORY - + diff --git a/version.json b/version.json index 820c8f87..0374b448 100644 --- a/version.json +++ b/version.json @@ -1,6 +1,6 @@ { "$schema": "https://raw.githubusercontent.com/AArnott/Nerdbank.GitVersioning/master/src/NerdBank.GitVersioning/version.schema.json", - "version": "1.1", + "version": "1.2", "assemblyVersion": "1.0", "nugetPackageVersion": { "semVer": 2