diff --git a/src/Vogen/Rules/DoNotCompareWithPrimitivesInEfCoreAnalyzer.cs b/src/Vogen/Rules/DoNotCompareWithPrimitivesInEfCoreAnalyzer.cs index da85af29d0..dc8ab0c096 100644 --- a/src/Vogen/Rules/DoNotCompareWithPrimitivesInEfCoreAnalyzer.cs +++ b/src/Vogen/Rules/DoNotCompareWithPrimitivesInEfCoreAnalyzer.cs @@ -36,6 +36,7 @@ public override void Initialize(AnalysisContext context) context.EnableConcurrentExecution(); context.RegisterSyntaxNodeAction(AnalyzeInvocation, SyntaxKind.InvocationExpression); + context.RegisterSyntaxNodeAction(AnalyzeQieryExpression, SyntaxKind.QueryExpression); } private static void AnalyzeInvocation(SyntaxNodeAnalysisContext context) @@ -68,6 +69,32 @@ private static void AnalyzeInvocation(SyntaxNodeAnalysisContext context) } } + private static void AnalyzeQieryExpression(SyntaxNodeAnalysisContext context) + { + var queryExpr = (QueryExpressionSyntax) context.Node; + var whereClauses = queryExpr.Body.DescendantNodes().OfType(); + var fromClause = queryExpr.FromClause; + + if (!IsAMemberOfDbSet(context, fromClause)) return; + + foreach (var eachArgument in whereClauses) + { + foreach (BinaryExpressionSyntax eachBinaryExpression in eachArgument.DescendantNodes().OfType()) + { + ITypeSymbol? left = context.SemanticModel.GetTypeInfo(eachBinaryExpression.Left).Type; + ITypeSymbol? right = context.SemanticModel.GetTypeInfo(eachBinaryExpression.Right).Type; + + if (left is null || right is null) continue; + + // Check if left is ValueObject and right is integer + if (IsValueObject(left) && right.SpecialType == SpecialType.System_Int32) + { + context.ReportDiagnostic(DiagnosticsCatalogue.BuildDiagnostic(_rule, left.Name, eachBinaryExpression.GetLocation())); + } + } + } + } + private static bool IsAMemberOfDbSet(SyntaxNodeAnalysisContext context, MemberAccessExpressionSyntax memberAccessExpr) { var symbolInfo = context.SemanticModel.GetSymbolInfo(memberAccessExpr.Expression); @@ -80,6 +107,18 @@ private static bool IsAMemberOfDbSet(SyntaxNodeAnalysisContext context, MemberAc return InheritsFrom(ps.Type, dbSetType); } + private static bool IsAMemberOfDbSet(SyntaxNodeAnalysisContext context, FromClauseSyntax fromClauseSyntax) + { + var symbolInfo = context.SemanticModel.GetSymbolInfo(fromClauseSyntax.Expression); + if (symbolInfo.Symbol is not IPropertySymbol ps) return false; + + var dbSetType = context.SemanticModel.Compilation.GetTypeByMetadataName("Microsoft.EntityFrameworkCore.DbSet`1"); + + if (dbSetType is null) return false; + + return InheritsFrom(ps.Type, dbSetType); + } + private static bool IsValueObject(ITypeSymbol type) => type is INamedTypeSymbol symbol && VoFilter.IsTarget(symbol); diff --git a/tests/AnalyzerTests/DoNotCompareWithPrimitivesInEfCoreAnalyzerTests.cs b/tests/AnalyzerTests/DoNotCompareWithPrimitivesInEfCoreAnalyzerTests.cs index c3f4bcade6..37d1cd5287 100644 --- a/tests/AnalyzerTests/DoNotCompareWithPrimitivesInEfCoreAnalyzerTests.cs +++ b/tests/AnalyzerTests/DoNotCompareWithPrimitivesInEfCoreAnalyzerTests.cs @@ -80,154 +80,333 @@ public async Task NoDiagnosticsForEmptyCode() await VerifyCS.VerifyAnalyzerAsync(test); } - - [Fact] - public async Task Triggers_when_found_in_IQueryableOfDbSet() + public class NonQuerySyntax { - var source = _source + """ + [Fact] + public async Task Triggers_when_found_in_IQueryableOfDbSet() + { + var source = _source + """ - public static class Test - { - public static void FilterItems() + public static class Test { - using var ctx = new DbContext(); - - var entities = ctx.Entities.Where(e => {|#0:e.Age == 50|}); + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entities = ctx.Entities.Where(e => {|#0:e.Age == 50|}); + } } - } - """; - var sources = await CombineUserAndGeneratedSource(source); + """; + var sources = await CombineUserAndGeneratedSource(source); - await Run( - sources, - WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); - } + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } - [Fact] - public async Task Triggers_when_found_in_complex_IQueryableOfDbSet() - { - var source = _source + """ + [Fact] + public async Task Triggers_when_found_using_query_syntax() + { + var source = _source + """ - public static class Test - { - public static void FilterItems() + public static class Test { - using var ctx = new DbContext(); - - var entities = ctx.Entities.Where(e => e != null && {|#0:e.Age == 50|}); + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entities = from e in ctx.Entities + where {|#0:e.Age == 50|} + select e; + } } - } - """; - var sources = await CombineUserAndGeneratedSource(source); + """; + var sources = await CombineUserAndGeneratedSource(source); - await Run( - sources, - WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); - } - - [Fact] - public async Task Triggers_when_found_in_complex_IQueryableOfDbSet_Single() - { - var source = _source + """ + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } - public static class Test - { - public static void FilterItems() + [Fact] + public async Task Triggers_when_found_in_complex_IQueryableOfDbSet() + { + var source = _source + """ + + public static class Test { - using var ctx = new DbContext(); - - var entity = ctx.Entities.Single(e => e != null && {|#0:e.Age == 50|}); + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entities = ctx.Entities.Where(e => e != null && {|#0:e.Age == 50|}); + } } - } - """; - var sources = await CombineUserAndGeneratedSource(source); + """; + var sources = await CombineUserAndGeneratedSource(source); - await Run( - sources, - WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); - } + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } - [Fact] - public async Task Not_triggered_when_found_in_non_IQueryableOfDbSet() - { - var source = _source + """ + [Fact] + public async Task Triggers_when_found_in_complex_IQueryableOfDbSet_Single() + { + var source = _source + """ - public static class Test - { - public static void FilterItems() + public static class Test { - var employees = new[] - { - new EmployeeEntity {Name = Name.From("Fred"), Age = Age.From(50) }, - new EmployeeEntity {Name = Name.From("Barney"), Age = Age.From(42) } - }; - - var matching = employees.Where(e => e.Age == 50); + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entity = ctx.Entities.Single(e => e != null && {|#0:e.Age == 50|}); + } } - } - """; - var sources = await CombineUserAndGeneratedSource(source); - - await Run(sources, Enumerable.Empty()); - } + """; + var sources = await CombineUserAndGeneratedSource(source); + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } - private static IEnumerable WithDiagnostics(string code, - DiagnosticSeverity severity, - string arguments, - params int[] locations) - { - foreach (var location in locations) + [Fact] + public async Task Not_triggered_when_found_in_non_IQueryableOfDbSet() { - yield return VerifyCS.Diagnostic(code).WithSeverity(severity).WithLocation(location) - .WithArguments(arguments); + var source = _source + """ + + public static class Test + { + public static void FilterItems() + { + var employees = new[] + { + new EmployeeEntity {Name = Name.From("Fred"), Age = Age.From(50) }, + new EmployeeEntity {Name = Name.From("Barney"), Age = Age.From(42) } + }; + + var matching = employees.Where(e => e.Age == 50); + } + } + """; + var sources = await CombineUserAndGeneratedSource(source); + + await Run(sources, Enumerable.Empty()); } - } - private async Task Run(string source, IEnumerable expected) => await Run([source], expected); - private async Task Run(string[] sources, IEnumerable expected) - { - var test = new VerifyCS.Test + private static IEnumerable WithDiagnostics(string code, + DiagnosticSeverity severity, + string arguments, + params int[] locations) { - CompilerDiagnostics = CompilerDiagnostics.Errors, - ReferenceAssemblies = References.Net80WithEfCoreAndOurs.Value, - }; + foreach (var location in locations) + { + yield return VerifyCS.Diagnostic(code).WithSeverity(severity).WithLocation(location) + .WithArguments(arguments); + } + } + + private async Task Run(string source, IEnumerable expected) => await Run([source], expected); - foreach (var eachSource in sources) + private async Task Run(string[] sources, IEnumerable expected) { - test.TestState.Sources.Add(eachSource); + var test = new VerifyCS.Test + { + CompilerDiagnostics = CompilerDiagnostics.Errors, + ReferenceAssemblies = References.Net80WithEfCoreAndOurs.Value, + }; + + foreach (var eachSource in sources) + { + test.TestState.Sources.Add(eachSource); + } + + test.ExpectedDiagnostics.AddRange(expected); + + await test.RunAsync(); } - test.ExpectedDiagnostics.AddRange(expected); + private static async Task CombineUserAndGeneratedSource(string userSource) + { + PortableExecutableReference peReference = MetadataReference.CreateFromFile(typeof(ValueObjectAttribute).Assembly.Location); + + var strippedSource = _placeholderPattern.Replace(userSource, string.Empty).Replace("|}", string.Empty); + + NuGetPackage[] packages = [new("Microsoft.EntityFrameworkCore", "8.0.10", string.Empty)]; - await test.RunAsync(); + (ImmutableArray Diagnostics, SyntaxTree[] GeneratedSources) output = await new ProjectBuilder() + .WithUserSource(strippedSource) + //.WithNugetPackages(packages) + .WithTargetFramework(TargetFramework.Net8_0) + .GetGeneratedOutput(ignoreInitialCompilationErrors: true, peReference); + + if (output.Diagnostics.Length > 0) + { + throw new AssertFailedException( + $""" + Expected user source to be error and generated code to be free from errors: + User source: {userSource} + Errors: {string.Join(",", output.Diagnostics.Select(d => d.ToString()))} + """); + } + + return [userSource, ..output.GeneratedSources.Select(o => o.ToString())]; + } } - private static async Task CombineUserAndGeneratedSource(string userSource) + public class QuerySyntax { - PortableExecutableReference peReference = MetadataReference.CreateFromFile(typeof(ValueObjectAttribute).Assembly.Location); + [Fact] + public async Task Triggers_when_found_in_IQueryableOfDbSet() + { + var source = _source + """ + + public static class Test + { + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entities = from e in ctx.Entities where {|#0:e.Age == 50|} select e; + } + } + """; + var sources = await CombineUserAndGeneratedSource(source); + + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } + + [Fact] + public async Task Triggers_when_found_in_complex_IQueryableOfDbSet() + { + var source = _source + """ + + public static class Test + { + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entities = from e in ctx.Entities where e != null && {|#0:e.Age == 50|} select e; + } + } + """; + var sources = await CombineUserAndGeneratedSource(source); + + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } + + [Fact] + public async Task Triggers_when_found_in_complex_IQueryableOfDbSet_Single() + { + var source = _source + """ + + public static class Test + { + public static void FilterItems() + { + using var ctx = new DbContext(); + + var entity = (from e in ctx.Entities where e != null && {|#0:e.Age == 50|} select e).Single(); + } + } + """; + var sources = await CombineUserAndGeneratedSource(source); + + await Run( + sources, + WithDiagnostics("VOG034", DiagnosticSeverity.Error, "Age", 0)); + } + + [Fact] + public async Task Not_triggered_when_found_in_non_IQueryableOfDbSet() + { + var source = _source + """ + + public static class Test + { + public static void FilterItems() + { + var employees = new[] + { + new EmployeeEntity {Name = Name.From("Fred"), Age = Age.From(50) }, + new EmployeeEntity {Name = Name.From("Barney"), Age = Age.From(42) } + }; + + var matching = from e in employees where e.Age == 50 select e; + } + } + """; + var sources = await CombineUserAndGeneratedSource(source); + + await Run(sources, Enumerable.Empty()); + } - var strippedSource = _placeholderPattern.Replace(userSource, string.Empty).Replace("|}", string.Empty); - NuGetPackage[] packages = [new("Microsoft.EntityFrameworkCore", "8.0.10", string.Empty)]; + private static IEnumerable WithDiagnostics(string code, + DiagnosticSeverity severity, + string arguments, + params int[] locations) + { + foreach (var location in locations) + { + yield return VerifyCS.Diagnostic(code).WithSeverity(severity).WithLocation(location) + .WithArguments(arguments); + } + } - (ImmutableArray Diagnostics, SyntaxTree[] GeneratedSources) output = await new ProjectBuilder() - .WithUserSource(strippedSource) - //.WithNugetPackages(packages) - .WithTargetFramework(TargetFramework.Net8_0) - .GetGeneratedOutput(ignoreInitialCompilationErrors: true, peReference); + private async Task Run(string source, IEnumerable expected) => await Run([source], expected); - if (output.Diagnostics.Length > 0) + private async Task Run(string[] sources, IEnumerable expected) { - throw new AssertFailedException( - $""" - Expected user source to be error and generated code to be free from errors: - User source: {userSource} - Errors: {string.Join(",", output.Diagnostics.Select(d => d.ToString()))} - """); + var test = new VerifyCS.Test + { + CompilerDiagnostics = CompilerDiagnostics.Errors, + ReferenceAssemblies = References.Net80WithEfCoreAndOurs.Value, + }; + + foreach (var eachSource in sources) + { + test.TestState.Sources.Add(eachSource); + } + + test.ExpectedDiagnostics.AddRange(expected); + + await test.RunAsync(); } - return [userSource, ..output.GeneratedSources.Select(o => o.ToString())]; + private static async Task CombineUserAndGeneratedSource(string userSource) + { + PortableExecutableReference peReference = MetadataReference.CreateFromFile(typeof(ValueObjectAttribute).Assembly.Location); + + var strippedSource = _placeholderPattern.Replace(userSource, string.Empty).Replace("|}", string.Empty); + + NuGetPackage[] packages = [new("Microsoft.EntityFrameworkCore", "8.0.10", string.Empty)]; + + (ImmutableArray Diagnostics, SyntaxTree[] GeneratedSources) output = await new ProjectBuilder() + .WithUserSource(strippedSource) + //.WithNugetPackages(packages) + .WithTargetFramework(TargetFramework.Net8_0) + .GetGeneratedOutput(ignoreInitialCompilationErrors: true, peReference); + + if (output.Diagnostics.Length > 0) + { + throw new AssertFailedException( + $""" + Expected user source to be error and generated code to be free from errors: + User source: {userSource} + Errors: {string.Join(",", output.Diagnostics.Select(d => d.ToString()))} + """); + } + + return [userSource, ..output.GeneratedSources.Select(o => o.ToString())]; + } } } \ No newline at end of file diff --git a/tests/Testbench/EfCoreTest/EfCoreScenario.cs b/tests/Testbench/EfCoreTest/EfCoreScenario.cs index 0f7be14eae..701745369b 100644 --- a/tests/Testbench/EfCoreTest/EfCoreScenario.cs +++ b/tests/Testbench/EfCoreTest/EfCoreScenario.cs @@ -17,7 +17,7 @@ public static void Run() AddAndSaveItems(amount: 10); PrintItems(); - + FilterItems(); return; @@ -56,8 +56,11 @@ static void FilterItems() Console.WriteLine("FILTERING ITEMS..."); using var ctx = new MyContext(); - var entities = ctx.Entities.GroupBy(e => e.Id).Where(x => x.Key == 1); + int age = 50; + var entities = from e in ctx.Entities where e != null && e.Age == age select e; //Console.WriteLine(string.Join(Environment.NewLine, entities.Select(e => $"ID: {e.Id.Value}, Name: {e.Name}, Age: {e.TheAge}"))); + } } -} \ No newline at end of file + +}