Skip to content

Commit

Permalink
fix AspNet productInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
terencefan committed Feb 26, 2025
1 parent 8984d22 commit 515bb98
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ public ConnectionFactory(IServerNameProvider nameProvider, ILoggerFactory logger

protected override void SetCustomHeaders(IDictionary<string, string> headers)
{
return;
}

protected override void SetInternalUserAgent(IDictionary<string, string> headers)
{
// Fix issue: https://github.com/Azure/azure-signalr/issues/198
// .NET Framework has restriction about reserved string as the header name like "User-Agent"
headers[Constants.AsrsUserAgent] = ProductInfo.GetProductInfo();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,35 +73,35 @@ internal IDictionary<string, string> GetRequestHeaders()
{
var headers = new Dictionary<string, string>(StringComparer.InvariantCultureIgnoreCase);
SetCustomHeaders(headers);
CheckHeaders(headers, Constants.Headers.AsrsHeaderPrefix);
CheckHeaders(headers, Constants.Headers.AsrsInternalHeaderPrefix);
SetInternalHeaders(headers);
CheckHeadersPrefix(headers, Constants.Headers.AsrsHeaderPrefix);
CheckHeadersPrefix(headers, Constants.Headers.AsrsInternalHeaderPrefix);
SetInternalUserAgent(headers);
SetServerId(headers);
return headers;
}

internal virtual void SetInternalHeaders(IDictionary<string, string> headers)
internal virtual void SetServerId(IDictionary<string, string> headers)
{
// Fix issue: https://github.com/Azure/azure-signalr/issues/198
// .NET Framework has restriction about reserved string as the header name like "User-Agent"
headers[Constants.AsrsUserAgent] = ProductInfo.GetProductInfo();

if (!string.IsNullOrEmpty(_serverId) && !headers.ContainsKey(Constants.Headers.AsrsServerId))
{
headers.Add(Constants.Headers.AsrsServerId, _serverId);
}
}

protected abstract void SetInternalUserAgent(IDictionary<string, string> headers);

protected abstract void SetCustomHeaders(IDictionary<string, string> headers);

private static void CheckHeaders(IDictionary<string, string> headers, string forbidPrefix)
private static void CheckHeadersPrefix(IDictionary<string, string> headers, string forbidPrefix)
{
var item = headers.Where(x => x.Key.StartsWith(forbidPrefix, StringComparison.InvariantCultureIgnoreCase));
if (item.Any())
var item = headers.Where(x => x.Key.StartsWith(forbidPrefix, StringComparison.InvariantCultureIgnoreCase)).ToArray();
if (item.Length > 0)
{
var key = item.First().Key;
var key = item[0].Key;
throw new ArgumentException($"Invalid header {key}, custom header cannot startwith '{forbidPrefix}'");
}
}

private static Uri GetServiceUrl(IServiceEndpointProvider provider, string hubName, string connectionId, string target)
{
var baseUri = new UriBuilder(provider.GetServerEndpoint(hubName));
Expand Down Expand Up @@ -150,6 +150,7 @@ public void Dispose()
{
_inner.Dispose();
}

private sealed class GracefulLogger : ILogger
{
private readonly ILogger _inner;
Expand All @@ -160,10 +161,12 @@ public GracefulLogger(ILogger inner)
}

#nullable disable

public IDisposable BeginScope<TState>(TState state)
{
return _inner.BeginScope(state);
}

#nullable enable

public bool IsEnabled(LogLevel logLevel)
Expand Down Expand Up @@ -191,4 +194,4 @@ public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Except
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,12 @@ internal class ManagementConnectionFactory(IOptions<ServiceManagerOptions> conte
{
private readonly string? _productInfo = context.Value.ProductInfo;

internal override void SetInternalHeaders(IDictionary<string, string> headers)
protected override void SetCustomHeaders(IDictionary<string, string> headers)
{
base.SetInternalHeaders(headers);

if (_productInfo != null)
{
headers[Constants.AsrsUserAgent] = _productInfo;
}
}

protected override void SetCustomHeaders(IDictionary<string, string> headers)
protected override void SetInternalUserAgent(IDictionary<string, string> headers)
{
return;
headers[Constants.AsrsUserAgent] = _productInfo ?? ProductInfo.GetProductInfo();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,26 @@
using System.Collections.Generic;

using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.Azure.SignalR;

#nullable enable

internal class ConnectionFactory : ConnectionFactoryBase
{
private readonly IOptions<ServiceOptions> _options;

public ConnectionFactory(IServerNameProvider nameProvider,
IOptions<ServiceOptions> options,
ILoggerFactory loggerFactory) : base(nameProvider, loggerFactory)
{
_options = options;
}

protected override void SetCustomHeaders(IDictionary<string, string> headers)
{
_options.Value.CustomHeaderProvider?.Invoke(headers);
}

protected override void SetInternalUserAgent(IDictionary<string, string> headers)
{
// Fix issue: https://github.com/Azure/azure-signalr/issues/198
// .NET Framework has restriction about reserved string as the header name like "User-Agent"
headers[Constants.AsrsUserAgent] = ProductInfo.GetProductInfo();
}
}
6 changes: 0 additions & 6 deletions src/Microsoft.Azure.SignalR/ServiceOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ public int ConnectionCount
/// </summary>
public string? ConnectionString { get; set; }

/// <summary>
/// Gets or sets the func to generate custom headers.
/// The headers key should not startwith "asrs-" or "x-asrs", as those are Azure SignalR Service SDK preserved headers and cannot be overwritten.
/// </summary>
public Action<IDictionary<string, string>>? CustomHeaderProvider { get; set; }

/// <summary>
/// Gets or sets the func to set diagnostic client filter from <see cref="HttpContext" />.
/// The clients will be regarded as diagnostic client only if the function returns true.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.Extensions.Logging.Abstractions;

using Xunit;

namespace Microsoft.Azure.SignalR.AspNet.Tests;

#nullable enable

public class ConnectionFactoryTests
{
[Fact]
public void TestGetRequestHeaders()
{
var nameProvider = new DefaultServerNameProvider();

var loggerFactory = NullLoggerFactory.Instance;

var factory = new ConnectionFactory(nameProvider, loggerFactory);

var headers = factory.GetRequestHeaders();
Assert.True(headers.TryGetValue(Constants.AsrsUserAgent, out var productInfo));
Assert.StartsWith("Microsoft.Azure.SignalR.AspNet/", productInfo);

Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId));
Assert.Equal(nameProvider.GetName(), serverId);
}

[Fact]
public void TestGetServerIdInRequestHeaders()
{
var nameProvider1 = new DefaultServerNameProvider();
var nameProvider2 = new DefaultServerNameProvider();

var name1 = nameProvider1.GetName();
var name2 = nameProvider2.GetName();
Assert.NotEqual(name1, name2);

static string GetServerId(IServerNameProvider nameProvider)
{
var connectionFactory = new ConnectionFactory(nameProvider, NullLoggerFactory.Instance);
var headers = connectionFactory.GetRequestHeaders();
Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId));
return serverId;
}

var serverId1 = GetServerId(nameProvider1);
var serverId2 = GetServerId(nameProvider2);
Assert.NotEqual(serverId1, serverId2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Azure.SignalR.Management.Tests;

public class ManagementConnectionFactoryTests
{
private const string OptionsAsrsUserAgent = $"Microsoft.Azure.SignalR.Management.123456";
private const string OptionsAsrsUserAgent = $"Microsoft.Azure.SignalR.Foo/123456";

[Theory]
[InlineData(true)]
Expand All @@ -39,11 +39,35 @@ public void TestGetRequestHeaders(bool setManagementProductInfo)

if (setManagementProductInfo)
{
Assert.StartsWith("Microsoft.Azure.SignalR.Management", productInfo);
Assert.StartsWith("Microsoft.Azure.SignalR.Foo/", productInfo);
}
else
{
Assert.StartsWith("Microsoft.Azure.SignalR.Common", productInfo);
Assert.StartsWith("Microsoft.Azure.SignalR.Management/", productInfo);
}
}

[Fact]
public void TestGetServerIdInRequestHeaders()
{
var nameProvider1 = new DefaultServerNameProvider();
var nameProvider2 = new DefaultServerNameProvider();

var name1 = nameProvider1.GetName();
var name2 = nameProvider2.GetName();
Assert.NotEqual(name1, name2);

static string GetServerId(IServerNameProvider nameProvider)
{
var options = Options.Create(new ServiceManagerOptions());
var connectionFactory = new ManagementConnectionFactory(options, nameProvider, NullLoggerFactory.Instance);
var headers = connectionFactory.GetRequestHeaders();
Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId));
return serverId;
}

var serverId1 = GetServerId(nameProvider1);
var serverId2 = GetServerId(nameProvider2);
Assert.NotEqual(serverId1, serverId2);
}
}
64 changes: 20 additions & 44 deletions test/Microsoft.Azure.SignalR.Tests/ConnectionFactoryTests.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;

using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;

using Xunit;

Expand All @@ -14,62 +11,41 @@ namespace Microsoft.Azure.SignalR.Tests;

public class ConnectionFactoryTests
{
[Theory]
[InlineData(null, null)]
[InlineData("foo", "bar")]
public void TestSetCustomHeaders(string? key, string? val)
[Fact]
public void TestGetRequestHeaders()
{
var nameProvider = new DefaultServerNameProvider();

var loggerFactory = NullLoggerFactory.Instance;

var options = Options.Create(new ServiceOptions());

if (key != null && val != null)
{
options.Value.CustomHeaderProvider = headers =>
{
headers.Add(key, val);
};
}

var factory = new ConnectionFactory(nameProvider, options, loggerFactory);
var factory = new ConnectionFactory(nameProvider, loggerFactory);

var headers = factory.GetRequestHeaders();
Assert.True(headers.TryGetValue(Constants.AsrsUserAgent, out var productInfo));
Assert.StartsWith("Microsoft.Azure.SignalR.Common", productInfo);
Assert.StartsWith("Microsoft.Azure.SignalR/", productInfo);

Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId));
Assert.Equal(nameProvider.GetName(), serverId);

if (key != null && val != null)
{
Assert.True(headers.TryGetValue(key, out var actualVal));
Assert.Equal(val, actualVal);
}
}

[Theory]
[InlineData(Constants.AsrsUserAgent, "bar")]
[InlineData(Constants.Headers.AsrsServerId, "bar")]
[InlineData("asrs-x", "bar")]
[InlineData("x-asrs-foo", "bar")]
public void TestSetCustomHeadersThrows(string key, string val)
[Fact]
public void TestGetServerIdInRequestHeaders()
{
var nameProvider = new DefaultServerNameProvider();
var nameProvider1 = new DefaultServerNameProvider();
var nameProvider2 = new DefaultServerNameProvider();

var loggerFactory = NullLoggerFactory.Instance;
var name1 = nameProvider1.GetName();
var name2 = nameProvider2.GetName();
Assert.NotEqual(name1, name2);

var options = Options.Create(new ServiceOptions());

options.Value.CustomHeaderProvider = headers =>
static string GetServerId(IServerNameProvider nameProvider)
{
headers.Add(key, val);
};

var factory = new ConnectionFactory(nameProvider, options, loggerFactory);
var connectionFactory = new ConnectionFactory(nameProvider, NullLoggerFactory.Instance);
var headers = connectionFactory.GetRequestHeaders();
Assert.True(headers.TryGetValue(Constants.Headers.AsrsServerId, out var serverId));
return serverId;
}

var exception = Assert.Throws<ArgumentException>(factory.GetRequestHeaders);
Assert.Contains(key, exception.Message!);
var serverId1 = GetServerId(nameProvider1);
var serverId2 = GetServerId(nameProvider2);
Assert.NotEqual(serverId1, serverId2);
}
}

0 comments on commit 515bb98

Please sign in to comment.