Skip to content

Commit

Permalink
Analyzer to warn about using a boxing variant (space-wizards#3564)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulRitter authored Dec 20, 2022
1 parent 3393658 commit ef0c0d0
Show file tree
Hide file tree
Showing 14 changed files with 285 additions and 46 deletions.
3 changes: 3 additions & 0 deletions Robust.Analyzers/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ public static class Diagnostics
public const string IdAccess = "RA0002";
public const string IdExplicitVirtual = "RA0003";
public const string IdTaskResult = "RA0004";
public const string IdUseGenericVariant = "RA0005";
public const string IdUseGenericVariantInvalidUsage = "RA0006";
public const string IdUseGenericVariantAttributeValueError = "RA0007";

public static SuppressionDescriptor MeansImplicitAssignment =>
new SuppressionDescriptor("RADC1000", "CS0649", "Marked as implicitly assigned.");
Expand Down
238 changes: 238 additions & 0 deletions Robust.Analyzers/PreferGenericVariantAnalyzer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Operations;

namespace Robust.Analyzers;

[DiagnosticAnalyzer(LanguageNames.CSharp)]
public sealed class PreferGenericVariantAnalyzer : DiagnosticAnalyzer
{
private const string AttributeType = "Robust.Shared.Analyzers.PreferGenericVariantAttribute";

public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics => ImmutableArray.Create(
UseGenericVariantDescriptor, UseGenericVariantInvalidUsageDescriptor,
UseGenericVariantAttributeValueErrorDescriptor);

private static readonly DiagnosticDescriptor UseGenericVariantDescriptor = new(
Diagnostics.IdUseGenericVariant,
"Consider using the generic variant of this method",
"Consider using the generic variant of this method to avoid potential allocations",
"Usage",
DiagnosticSeverity.Warning,
true,
"Consider using the generic variant of this method to avoid potential allocations.");

private static readonly DiagnosticDescriptor UseGenericVariantInvalidUsageDescriptor = new(
Diagnostics.IdUseGenericVariantInvalidUsage,
"Invalid generic variant provided",
"Generic variant provided mismatches the amount of type parameters of non-generic variant",
"Usage",
DiagnosticSeverity.Error,
true,
"The non-generic variant should have at least as many type parameter at the beginning of the method as there are generic type parameters on the generic variant.");

private static readonly DiagnosticDescriptor UseGenericVariantAttributeValueErrorDescriptor = new(
Diagnostics.IdUseGenericVariantAttributeValueError,
"Failed resolving generic variant value",
"Failed resolving generic variant value: {0}",
"Usage",
DiagnosticSeverity.Error,
true,
"Consider using nameof to avoid any typos.");

public override void Initialize(AnalysisContext context)
{
context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.ReportDiagnostics | GeneratedCodeAnalysisFlags.Analyze);
context.EnableConcurrentExecution();
context.RegisterOperationAction(CheckForGenericVariant, OperationKind.Invocation);
}

private void CheckForGenericVariant(OperationAnalysisContext obj)
{
if(obj.Operation is not IInvocationOperation invocationOperation) return;

var preferGenericAttribute = obj.Compilation.GetTypeByMetadataName(AttributeType);

string genericVariant = null;
AttributeData foundAttribute = null;
foreach (var attribute in invocationOperation.TargetMethod.GetAttributes())
{
if (!SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, preferGenericAttribute))
continue;

genericVariant = attribute.ConstructorArguments[0].Value as string ?? invocationOperation.TargetMethod.Name;
foundAttribute = attribute;
break;
}

if(genericVariant == null) return;

var maxTypeParams = 0;
var typeTypeSymbol = obj.Compilation.GetTypeByMetadataName("System.Type");
foreach (var parameter in invocationOperation.TargetMethod.Parameters)
{
if(!SymbolEqualityComparer.Default.Equals(parameter.Type, typeTypeSymbol)) break;

maxTypeParams++;
}

if (maxTypeParams == 0)
{
obj.ReportDiagnostic(
Diagnostic.Create(UseGenericVariantInvalidUsageDescriptor,
foundAttribute.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
return;
}

IMethodSymbol genericVariantMethod = null;
foreach (var member in invocationOperation.TargetMethod.ContainingType.GetMembers())
{
if (member is not IMethodSymbol methodSymbol
|| methodSymbol.Name != genericVariant
|| !methodSymbol.IsGenericMethod
|| methodSymbol.TypeParameters.Length > maxTypeParams
|| methodSymbol.Parameters.Length > invocationOperation.TargetMethod.Parameters.Length - methodSymbol.TypeParameters.Length
) continue;

var typeParamCount = methodSymbol.TypeParameters.Length;
var failedParamComparison = false;
var objType = obj.Compilation.GetSpecialType(SpecialType.System_Object);
for (int i = 0; i < methodSymbol.Parameters.Length; i++)
{
if (methodSymbol.Parameters[i].Type is ITypeParameterSymbol && SymbolEqualityComparer.Default.Equals(invocationOperation.TargetMethod.Parameters[i + typeParamCount].Type, objType))
continue;

if (!SymbolEqualityComparer.Default.Equals(methodSymbol.Parameters[i].Type,
invocationOperation.TargetMethod.Parameters[i + typeParamCount].Type))
{
failedParamComparison = true;
break;
}
}

if(failedParamComparison) continue;

genericVariantMethod = methodSymbol;
}

if (genericVariantMethod == null)
{
obj.ReportDiagnostic(Diagnostic.Create(
UseGenericVariantAttributeValueErrorDescriptor,
foundAttribute.ApplicationSyntaxReference?.GetSyntax().GetLocation(),
genericVariant));
return;
}

var typeOperands = new string[genericVariantMethod.TypeParameters.Length];
for (var i = 0; i < genericVariantMethod.TypeParameters.Length; i++)
{
switch (invocationOperation.Arguments[i].Value)
{
//todo figure out if ILocalReferenceOperation, IPropertyReferenceOperation or IFieldReferenceOperation is referencing static typeof assignments
case ITypeOfOperation typeOfOperation:
typeOperands[i] = typeOfOperation.TypeOperand.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
continue;
default:
return;
}
}

obj.ReportDiagnostic(Diagnostic.Create(
UseGenericVariantDescriptor,
invocationOperation.Syntax.GetLocation(),
ImmutableDictionary.CreateRange(new Dictionary<string, string>()
{
{"typeOperands", string.Join(",", typeOperands)}
})));
}
}

[ExportCodeFixProvider(LanguageNames.CSharp)]
public class PreferGenericVariantCodeFixProvider : CodeFixProvider
{
private static string Title(string method, string[] types) => $"Use {method}<{string.Join(",", types)}>.";

public override FixAllProvider GetFixAllProvider()
{
return WellKnownFixAllProviders.BatchFixer;
}

public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var root = await context.Document.GetSyntaxRootAsync();
if(root == null) return;

foreach (var diagnostic in context.Diagnostics)
{
if (!diagnostic.Properties.TryGetValue("typeOperands", out var typeOperandsRaw)
|| typeOperandsRaw == null) continue;

var node = root.FindNode(diagnostic.Location.SourceSpan);
if (node is ArgumentSyntax argumentSyntax)
node = argumentSyntax.Expression;

if(node is not InvocationExpressionSyntax invocationExpression)
continue;

var typeOperands = typeOperandsRaw.Split(',');

context.RegisterCodeFix(
CodeAction.Create(
Title(invocationExpression.Expression.ToString(), typeOperands),
c => FixAsync(context.Document, invocationExpression, typeOperands, c),
Title(invocationExpression.Expression.ToString(), typeOperands)),
diagnostic);
}
}

private async Task<Document> FixAsync(
Document contextDocument,
InvocationExpressionSyntax invocationExpression,
string[] typeOperands,
CancellationToken cancellationToken)
{
var memberAccess = (MemberAccessExpressionSyntax)invocationExpression.Expression;

var root = (CompilationUnitSyntax) await contextDocument.GetSyntaxRootAsync(cancellationToken);

var arguments = new ArgumentSyntax[invocationExpression.ArgumentList.Arguments.Count - typeOperands.Length];
var types = new TypeSyntax[typeOperands.Length];

for (int i = 0; i < typeOperands.Length; i++)
{
types[i] = ((TypeOfExpressionSyntax)invocationExpression.ArgumentList.Arguments[i].Expression).Type;
}



Array.Copy(
invocationExpression.ArgumentList.Arguments.ToArray(),
typeOperands.Length,
arguments,
0,
arguments.Length);

memberAccess = memberAccess.WithName(SyntaxFactory.GenericName(memberAccess.Name.Identifier,
SyntaxFactory.TypeArgumentList(SyntaxFactory.SeparatedList(types))));

root = root!.ReplaceNode(invocationExpression,
invocationExpression.WithArgumentList(invocationExpression.ArgumentList.WithArguments(SyntaxFactory.SeparatedList(arguments)))
.WithExpression(memberAccess));

return contextDocument.WithSyntaxRoot(root);
}

public override ImmutableArray<string> FixableDiagnosticIds =>
ImmutableArray.Create(Diagnostics.IdUseGenericVariant);
}
5 changes: 5 additions & 0 deletions Robust.Analyzers/Robust.Analyzers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,9 @@
<Compile Include="..\Robust.Shared\Analyzers\AccessPermissions.cs" />
</ItemGroup>

<ItemGroup>
<!-- Needed for PreferGenericVariantAnalyzer. -->
<Compile Include="..\Robust.Shared\Analyzers\PreferGenericVariantAttribute.cs" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public override ValidationNode ValidateRsi(ISerializationManager serializationMa
return new ErrorNode(node, "Sprite specifier has missing/invalid state node");
}

var path = serializationManager.ValidateNode(typeof(ResourcePath),
var path = serializationManager.ValidateNode<ResourcePath>(
new ValueDataNode($"{SharedSpriteComponent.TextureRoot / valuePathNode.Value}"), context);

if (path is ErrorNode) return path;
Expand All @@ -43,8 +43,9 @@ public override ValidationNode ValidateRsi(ISerializationManager serializationMa
// the state exists. So lets just check if the state .png exists, without properly validating the RSI's
// meta.json

var statePath = serializationManager.ValidateNode(typeof(ResourcePath),
new ValueDataNode($"{SharedSpriteComponent.TextureRoot / valuePathNode.Value / valueStateNode.Value}.png"), context);
var statePath = serializationManager.ValidateNode<ResourcePath>(
new ValueDataNode($"{SharedSpriteComponent.TextureRoot / valuePathNode.Value / valueStateNode.Value}.png"),
context);

if (statePath is ErrorNode) return statePath;

Expand Down
18 changes: 18 additions & 0 deletions Robust.Shared/Analyzers/PreferGenericVariantAttribute.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System;

#if NETSTANDARD2_0
namespace Robust.Shared.Analyzers.Implementation;
#else
namespace Robust.Shared.Analyzers;
#endif

[AttributeUsage(AttributeTargets.Method)]
public sealed class PreferGenericVariantAttribute : Attribute
{
public readonly string GenericVariant;

public PreferGenericVariantAttribute(string genericVariant = null!)
{
GenericVariant = genericVariant;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ public ValidationNode Validate(
continue;
}

var keyValidated = serialization.ValidateNode(typeof(string), key, context);
var keyValidated = serialization.ValidateNode<string>(key, context);

ValidationNode valNode;
if (IsNull(val))
Expand Down
17 changes: 2 additions & 15 deletions Robust.Shared/Serialization/Manager/ISerializationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public interface ISerializationManager
/// A node with whether or not <see cref="node"/> is valid and which of its fields
/// are invalid, if any.
/// </returns>
[PreferGenericVariant]
ValidationNode ValidateNode(Type type, DataNode node, ISerializationContext? context = null);

/// <summary>
Expand Down Expand Up @@ -224,23 +225,9 @@ T Read<T, TNode, TReader>(
/// A serialized datanode created from the given <see cref="value"/>
/// of type <see cref="type"/>.
/// </returns>
[PreferGenericVariant]
DataNode WriteValue(Type type, object? value, bool alwaysWrite = false, ISerializationContext? context = null, bool notNullableOverride = false);

/// <summary>
/// Serializes a value into a node.
/// </summary>
/// <param name="value">The value to serialize.</param>
/// <param name="alwaysWrite">
/// Whether or not to always write the given values into the resulting node,
/// even if they are the default.
/// </param>
/// <param name="context">The context to use, if any.</param>
/// <param name="notNullableOverride">Set true if a reference Type should not allow null. Not necessary for value types.</param>
/// <returns>
/// A serialized datanode created from the given <see cref="value"/>.
/// </returns>
DataNode WriteValue(object? value, bool alwaysWrite = false, ISerializationContext? context = null, bool notNullableOverride = false);

#endregion

#region Copy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,6 @@ public DataNode WriteValue<T, TWriter>(T value, bool alwaysWrite = false, ISeria
return WriteValue(GetOrCreateCustomTypeSerializer<TWriter>(), value, alwaysWrite, context, notNullableOverride);
}

public DataNode WriteValue(object? value, bool alwaysWrite = false,
ISerializationContext? context = null, bool notNullableOverride = false)
{
if (value == null)
{
if (notNullableOverride) throw new NullNotAllowedException();
return NullNode();
}

return WriteValue(value.GetType(), value, alwaysWrite, context);
}

public DataNode WriteValue(Type type, object? value, bool alwaysWrite = false, ISerializationContext? context = null, bool notNullableOverride = false)
{
if (value == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ private ValidationNode Validate(ISerializationManager serializationManager, Mapp
{
if (key is not ValueDataNode value)
{
mapping.Add(new ErrorNode(key, $"Cannot cast node {key} to ValueDataNode."), serializationManager.ValidateNode(typeof(TValue), val, context));
mapping.Add(new ErrorNode(key, $"Cannot cast node {key} to ValueDataNode."), serializationManager.ValidateNode<TValue>(val, context));
continue;
}

mapping.Add(PrototypeSerializer.Validate(serializationManager, value, dependencies, context), serializationManager.ValidateNode(typeof(TValue), val, context));
mapping.Add(PrototypeSerializer.Validate(serializationManager, value, dependencies, context), serializationManager.ValidateNode<TValue>(val, context));
}

return new ValidatedMappingNode(mapping);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ private ValidationNode Validate(ISerializationManager serializationManager, Mapp
{
if (val is not ValueDataNode value)
{
mapping.Add(new ErrorNode(val, $"Cannot cast node {val} to ValueDataNode."), serializationManager.ValidateNode(typeof(TValue), key, context));
mapping.Add(new ErrorNode(val, $"Cannot cast node {val} to ValueDataNode."), serializationManager.ValidateNode<TValue>(key, context));
continue;
}

mapping.Add(PrototypeSerializer.Validate(serializationManager, value, dependencies, context), serializationManager.ValidateNode(typeof(TValue), key, context));
mapping.Add(PrototypeSerializer.Validate(serializationManager, value, dependencies, context), serializationManager.ValidateNode<TValue>(key, context));
}

return new ValidatedMappingNode(mapping);
Expand Down
Loading

0 comments on commit ef0c0d0

Please sign in to comment.