diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ConnectionFactory.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ConnectionFactory.cs index a2ceaab4f..aeda544f9 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ConnectionFactory.cs @@ -15,8 +15,15 @@ public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory logger { } + protected override void SetInternalUserAgent(IDictionary headers) + { + // Fix issue: https://github.com/Azure/azure-signalr/issues/198 + // .NET Framework has restriction about reserved string as the header name like "User-Agent" + headers[Constants.AsrsUserAgent] = ProductInfo.GetProductInfo(); + } + protected override void SetCustomHeaders(IDictionary headers) { return; } -} +} \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactoryBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactoryBase.cs index 228c118ca..da9868629 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactoryBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ConnectionFactoryBase.cs @@ -73,35 +73,35 @@ internal IDictionary GetRequestHeaders() { var headers = new Dictionary(StringComparer.InvariantCultureIgnoreCase); SetCustomHeaders(headers); - CheckHeaders(headers, Constants.Headers.AsrsHeaderPrefix); - CheckHeaders(headers, Constants.Headers.AsrsInternalHeaderPrefix); - SetInternalHeaders(headers); + CheckHeadersPrefix(headers, Constants.Headers.AsrsHeaderPrefix); + CheckHeadersPrefix(headers, Constants.Headers.AsrsInternalHeaderPrefix); + SetInternalUserAgent(headers); + SetServerId(headers); return headers; } - internal virtual void SetInternalHeaders(IDictionary headers) + internal virtual void SetServerId(IDictionary headers) { - // Fix issue: https://github.com/Azure/azure-signalr/issues/198 - // .NET Framework has restriction about reserved string as the header name like "User-Agent" - headers[Constants.AsrsUserAgent] = ProductInfo.GetProductInfo(); - if (!string.IsNullOrEmpty(_serverId) && !headers.ContainsKey(Constants.Headers.AsrsServerId)) { headers.Add(Constants.Headers.AsrsServerId, _serverId); } } + protected abstract void SetInternalUserAgent(IDictionary headers); + protected abstract void SetCustomHeaders(IDictionary headers); - private static void CheckHeaders(IDictionary headers, string forbidPrefix) + private static void CheckHeadersPrefix(IDictionary headers, string forbidPrefix) { - var item = headers.Where(x => x.Key.StartsWith(forbidPrefix, StringComparison.InvariantCultureIgnoreCase)); - if (item.Any()) + var item = headers.Where(x => x.Key.StartsWith(forbidPrefix, StringComparison.InvariantCultureIgnoreCase)).ToArray(); + if (item.Length > 0) { - var key = item.First().Key; + var key = item[0].Key; throw new ArgumentException($"Invalid header {key}, custom header cannot startwith '{forbidPrefix}'"); } } + private static Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, string connectionId, string target) { var baseUri = new UriBuilder(provider.GetServerEndpoint(hubName)); @@ -150,6 +150,7 @@ public void Dispose() { _inner.Dispose(); } + private sealed class GracefulLogger : ILogger { private readonly ILogger _inner; @@ -160,10 +161,12 @@ public GracefulLogger(ILogger inner) } #nullable disable + public IDisposable BeginScope(TState state) { return _inner.BeginScope(state); } + #nullable enable public bool IsEnabled(LogLevel logLevel) @@ -191,4 +194,4 @@ public void Log(LogLevel logLevel, EventId eventId, TState state, Except } } } -} +} \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR.Management/ManagementConnectionFactory.cs b/src/Microsoft.Azure.SignalR.Management/ManagementConnectionFactory.cs index 1ab243bce..00b299460 100644 --- a/src/Microsoft.Azure.SignalR.Management/ManagementConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR.Management/ManagementConnectionFactory.cs @@ -17,18 +17,13 @@ internal class ManagementConnectionFactory(IOptions conte { private readonly string? _productInfo = context.Value.ProductInfo; - internal override void SetInternalHeaders(IDictionary headers) + protected override void SetCustomHeaders(IDictionary headers) { - base.SetInternalHeaders(headers); - - if (_productInfo != null) - { - headers[Constants.AsrsUserAgent] = _productInfo; - } + return; } - protected override void SetCustomHeaders(IDictionary headers) + protected override void SetInternalUserAgent(IDictionary headers) { - return; + headers[Constants.AsrsUserAgent] = _productInfo ?? ProductInfo.GetProductInfo(); } -} +} \ No newline at end of file diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ConnectionFactory.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ConnectionFactory.cs index 7ca8dcf27..1804cbc41 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ConnectionFactory.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ConnectionFactory.cs @@ -21,8 +21,15 @@ public ConnectionFactory(IServerNameProvider nameProvider, _options = options; } + protected override void SetInternalUserAgent(IDictionary headers) + { + // Fix issue: https://github.com/Azure/azure-signalr/issues/198 + // .NET Framework has restriction about reserved string as the header name like "User-Agent" + headers[Constants.AsrsUserAgent] = ProductInfo.GetProductInfo(); + } + protected override void SetCustomHeaders(IDictionary headers) { _options.Value.CustomHeaderProvider?.Invoke(headers); } -} +} \ No newline at end of file diff --git a/test/Microsoft.Azure.SignalR.AspNet.Tests/ConnectionFactoryTests.cs b/test/Microsoft.Azure.SignalR.AspNet.Tests/ConnectionFactoryTests.cs new file mode 100644 index 000000000..3215ada2c --- /dev/null +++ b/test/Microsoft.Azure.SignalR.AspNet.Tests/ConnectionFactoryTests.cs @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Microsoft.Extensions.Logging.Abstractions; + +using Xunit; + +namespace Microsoft.Azure.SignalR.AspNet.Tests; + +#nullable enable + +public class ConnectionFactoryTests +{ + [Fact] + public void TestGetRequestHeaders() + { + var nameProvider = new DefaultServerNameProvider(); + + var loggerFactory = NullLoggerFactory.Instance; + + var factory = new ConnectionFactory(nameProvider, loggerFactory); + + var headers = factory.GetRequestHeaders(); + Assert.True(headers.TryGetValue(Constants.AsrsUserAgent, out var productInfo)); + Assert.StartsWith("Microsoft.Azure.SignalR.AspNet/", productInfo); + + Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId)); + Assert.Equal(nameProvider.GetName(), serverId); + } +} diff --git a/test/Microsoft.Azure.SignalR.Management.Tests/ManagementConnectionFactoryTests.cs b/test/Microsoft.Azure.SignalR.Management.Tests/ManagementConnectionFactoryTests.cs index 3c2904eab..7755385ca 100644 --- a/test/Microsoft.Azure.SignalR.Management.Tests/ManagementConnectionFactoryTests.cs +++ b/test/Microsoft.Azure.SignalR.Management.Tests/ManagementConnectionFactoryTests.cs @@ -12,7 +12,7 @@ namespace Microsoft.Azure.SignalR.Management.Tests; public class ManagementConnectionFactoryTests { - private const string OptionsAsrsUserAgent = $"Microsoft.Azure.SignalR.Management.123456"; + private const string OptionsAsrsUserAgent = $"Microsoft.Azure.SignalR.Foo/123456"; [Theory] [InlineData(true)] @@ -39,11 +39,11 @@ public void TestGetRequestHeaders(bool setManagementProductInfo) if (setManagementProductInfo) { - Assert.StartsWith("Microsoft.Azure.SignalR.Management", productInfo); + Assert.StartsWith("Microsoft.Azure.SignalR.Foo/", productInfo); } else { - Assert.StartsWith("Microsoft.Azure.SignalR.Common", productInfo); + Assert.StartsWith("Microsoft.Azure.SignalR.Management/", productInfo); } } } diff --git a/test/Microsoft.Azure.SignalR.Tests/ConnectionFactoryTests.cs b/test/Microsoft.Azure.SignalR.Tests/ConnectionFactoryTests.cs index 6afffebdf..edeb6329c 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ConnectionFactoryTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ConnectionFactoryTests.cs @@ -37,7 +37,7 @@ public void TestSetCustomHeaders(string? key, string? val) var headers = factory.GetRequestHeaders(); Assert.True(headers.TryGetValue(Constants.AsrsUserAgent, out var productInfo)); - Assert.StartsWith("Microsoft.Azure.SignalR.Common", productInfo); + Assert.StartsWith("Microsoft.Azure.SignalR/", productInfo); Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId)); Assert.Equal(nameProvider.GetName(), serverId); @@ -52,8 +52,10 @@ public void TestSetCustomHeaders(string? key, string? val) [Theory] [InlineData(Constants.AsrsUserAgent, "bar")] [InlineData(Constants.Headers.AsrsServerId, "bar")] - [InlineData("asrs-x", "bar")] + [InlineData("asrs-foo", "bar")] + [InlineData("ASRS-bar", "bar")] [InlineData("x-asrs-foo", "bar")] + [InlineData("x-ASRS-bar", "bar")] public void TestSetCustomHeadersThrows(string key, string val) { var nameProvider = new DefaultServerNameProvider();