Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special case List<ClaimsIdentity> in SelectPrimaryIdentity #111799

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,31 @@ protected ClaimsPrincipal(SerializationInfo info, StreamingContext context)
{
ArgumentNullException.ThrowIfNull(identities);

foreach (ClaimsIdentity identity in identities)
// If the identities value is exactly a List<ClaimsIdentity>, special case it so that
// the enumerator allocation can be skipped. Doing this for List<ClaimsIdentity> is the 99%
// case because it is normally used on the _identities value, which is a List<ClaimsIdentity>.
if (identities.GetType() == typeof(List<ClaimsIdentity>))
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
{
if (identity != null)
List<ClaimsIdentity> identitiesList = (identities as List<ClaimsIdentity>)!;

for (int i = 0; i < identitiesList.Count; i++)
{
ClaimsIdentity identity = identitiesList[i];

if (identity != null)
{
return identity;
}
}
}
else
{
foreach (ClaimsIdentity identity in identities)
{
return identity;
if (identity != null)
{
return identity;
}
}
}

Expand Down
67 changes: 67 additions & 0 deletions src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
Expand Down Expand Up @@ -242,6 +243,72 @@ public void Current_FallsBackToThread_UnauthenticatedPrincipalPolicy()
}).Dispose();
}

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void PrimaryIdentitySelector_Default()
{
RemoteExecutor.Invoke(static () =>
{
ClaimsIdentity identity0 = null;
ClaimsIdentity identity1 = new([new Claim("type", "value")]);
ClaimsIdentity identity2 = new([new Claim("type", "value")]);
IEnumerable<ClaimsIdentity> identities = [identity0, identity1, identity2];
Func<IEnumerable<ClaimsIdentity>, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector;

Assert.Same(identity1, selector(identities));
Assert.Null(selector([]));
AssertExtensions.Throws<ArgumentNullException>("identities", () => selector(null));
}).Dispose();
}

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void PrimaryIdentitySelector_DefaultDoesNotSpecialCaseInterfaceList()
{
RemoteExecutor.Invoke(static () =>
{
ClaimsIdentity identity0 = null;
ClaimsIdentity identity1 = new([new Claim("type", "value")]);
ClaimsIdentity identity2 = new([new Claim("type", "value")]);
ClaimsIdentityList identities = [identity0, identity1, identity2];
Func<IEnumerable<ClaimsIdentity>, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector;

Assert.Same(identity1, selector(identities));
Assert.True(identities.EnumeratedAtLeastOnce, nameof(identities.EnumeratedAtLeastOnce));
Assert.Null(selector(new ClaimsIdentityList()));
}).Dispose();
}

private sealed class ClaimsIdentityList : IList<ClaimsIdentity>
{
private readonly List<ClaimsIdentity> _claimsIdentities = [];

public bool EnumeratedAtLeastOnce { get; set; }

public ClaimsIdentity this[int index]
{
get => _claimsIdentities[index];
set => _claimsIdentities[index] = value;
}

public int Count => _claimsIdentities.Count;
public bool IsReadOnly => ((ICollection<ClaimsIdentity>)_claimsIdentities).IsReadOnly;
public void Add(ClaimsIdentity item) => _claimsIdentities.Add(item);
public void Clear() => _claimsIdentities.Clear();
public bool Contains(ClaimsIdentity item) => _claimsIdentities.Contains(item);
public void CopyTo(ClaimsIdentity[] array, int arrayIndex) => _claimsIdentities.CopyTo(array, arrayIndex);
public int IndexOf(ClaimsIdentity item) => _claimsIdentities.IndexOf(item);
public void Insert(int index, ClaimsIdentity item) => _claimsIdentities.Insert(index, item);
public bool Remove(ClaimsIdentity item) => _claimsIdentities.Remove(item);
public void RemoveAt(int index) => _claimsIdentities.RemoveAt(index);

public IEnumerator<ClaimsIdentity> GetEnumerator()
{
EnumeratedAtLeastOnce = true;
return _claimsIdentities.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<ClaimsIdentity>)this).GetEnumerator();
}

private class NonClaimsPrincipal : IPrincipal
{
public IIdentity Identity { get; set; }
Expand Down
Loading