diff --git a/src/Equatable.SourceGenerator/EquatableGenerator.cs b/src/Equatable.SourceGenerator/EquatableGenerator.cs index ae4e355..b6641fb 100644 --- a/src/Equatable.SourceGenerator/EquatableGenerator.cs +++ b/src/Equatable.SourceGenerator/EquatableGenerator.cs @@ -1,3 +1,4 @@ +using System.Reflection; using System.Xml.Linq; using Equatable.SourceGenerator.Models; @@ -200,11 +201,15 @@ private static (ComparerTypes? comparerType, string? comparerName, string? compa private static (ComparerTypes? comparerType, string? comparerName, string? comparerInstance) GetStringComparer(AttributeData? attribute) { - var argument = attribute?.ConstructorArguments.FirstOrDefault(); - if (argument == null || !argument.HasValue) + if (attribute == null || attribute.ConstructorArguments.Length != 1) + return (ComparerTypes.Default, null, null); + + var argument = attribute.ConstructorArguments[0]; + + if (argument.Value is not int value) return (ComparerTypes.String, "CurrentCulture", null); - var comparerName = argument?.Value switch + var comparerName = value switch { 0 => "CurrentCulture", 1 => "CurrentCultureIgnoreCase", @@ -220,30 +225,19 @@ private static (ComparerTypes? comparerType, string? comparerName, string? compa private static (ComparerTypes? comparerType, string? comparerName, string? comparerInstance) GetEqualityComparer(AttributeData? attribute) { - if (attribute == null) + if (attribute == null || attribute.ConstructorArguments.Length != 2) return (ComparerTypes.Default, null, null); - // attribute constructor - var comparerType = attribute.ConstructorArguments.FirstOrDefault(); - if (comparerType.Value is INamedTypeSymbol typeSymbol) - { - return (ComparerTypes.Custom, typeSymbol.ToDisplayString(), null); - } - - // generic attribute - var attributeClass = attribute.AttributeClass; - if (attributeClass is { IsGenericType: true } - && attributeClass.TypeArguments.Length == attributeClass.TypeParameters.Length - && attributeClass.TypeArguments.Length == 1) - { - var typeArgument = attributeClass.TypeArguments[0]; - var comparerName = typeArgument.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var comparerArgument = attribute.ConstructorArguments[0]; + if (comparerArgument.Value is not INamedTypeSymbol typeSymbol) + return (ComparerTypes.Default, null, null); // invalid syntax found - return (ComparerTypes.Custom, comparerName, null); - } + var comparerName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var instanceArgument = attribute.ConstructorArguments[1]; + var comparerInstance = instanceArgument.Value as string; - return (ComparerTypes.Default, null, null); + return (ComparerTypes.Custom, comparerName, comparerInstance); } diff --git a/src/Equatable.SourceGenerator/EquatableWriter.cs b/src/Equatable.SourceGenerator/EquatableWriter.cs index de0af12..babdf11 100644 --- a/src/Equatable.SourceGenerator/EquatableWriter.cs +++ b/src/Equatable.SourceGenerator/EquatableWriter.cs @@ -124,6 +124,13 @@ private static void GenerateEquatable(IndentedStringBuilder codeBuilder, Equatab break; case ComparerTypes.Reference: + codeBuilder + .Append(" global::System.Object.ReferenceEquals(") + .Append(entityProperty.PropertyName) + .Append(", other.") + .Append(entityProperty.PropertyName) + .Append(")"); + break; case ComparerTypes.Sequence: codeBuilder @@ -395,6 +402,11 @@ private static void GenerateHashCode(IndentedStringBuilder codeBuilder, Equatabl .AppendLine(");"); break; case ComparerTypes.Reference: + codeBuilder + .Append("hashCode = (hashCode * -1521134295) + ") + .Append("global::System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(") + .Append(entityProperty.PropertyName) + .AppendLine("!);"); break; case ComparerTypes.Sequence: codeBuilder diff --git a/test/Equatable.Entities/Audit.cs b/test/Equatable.Entities/Audit.cs index 3bc41fc..8487c3f 100644 --- a/test/Equatable.Entities/Audit.cs +++ b/test/Equatable.Entities/Audit.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using Equatable.Attributes; @@ -13,4 +12,7 @@ public partial class Audit : ModelBase public int? TaskId { get; set; } public string? Content { get; set; } public string? UserName { get; set; } + + [ReferenceEquality] + public object? Lock { get; set; } } diff --git a/test/Equatable.Entities/CustomLength.cs b/test/Equatable.Entities/CustomLength.cs new file mode 100644 index 0000000..ea6c2f1 --- /dev/null +++ b/test/Equatable.Entities/CustomLength.cs @@ -0,0 +1,38 @@ +using System.Collections.Generic; + +using Equatable.Attributes; + +namespace Equatable.Entities; + +[Equatable] +public partial class CustomLength +{ + public int Id { get; set; } + + public string Name { get; set; } = null!; + + [EqualityComparer(typeof(LengthComparerDefault))] + public string? Key { get; set; } + + [EqualityComparer(typeof(LengthComparerInstance), nameof(LengthComparerInstance.Instance))] + public string? Value { get; set; } + +} + +public static class LengthComparerDefault +{ + public static readonly LengthEqualityComparer Default = new(); +} + +public static class LengthComparerInstance +{ + public static readonly LengthEqualityComparer Instance = new(); +} + +public class LengthEqualityComparer : IEqualityComparer +{ + public bool Equals(string? x, string? y) => x?.Length == y?.Length; + + public int GetHashCode(string? obj) => obj?.Length.GetHashCode() ?? 0; +} + diff --git a/test/Equatable.Entities/Equatable.Entities.csproj b/test/Equatable.Entities/Equatable.Entities.csproj index c0ca6da..f47488e 100644 --- a/test/Equatable.Entities/Equatable.Entities.csproj +++ b/test/Equatable.Entities/Equatable.Entities.csproj @@ -8,6 +8,10 @@ true + + + + diff --git a/test/Equatable.Entities/Nested.cs b/test/Equatable.Entities/Nested.cs index 3809a73..90b3d70 100644 --- a/test/Equatable.Entities/Nested.cs +++ b/test/Equatable.Entities/Nested.cs @@ -2,13 +2,13 @@ namespace Equatable.Entities; -public class Nested +public partial class Nested { //[Equatable] public partial class Animal { public int Id { get; set; } - public string Name { get; set; } - public string Type { get; set; } + public string? Name { get; set; } + public string? Type { get; set; } } } diff --git a/test/Equatable.Entities/UserImport.cs b/test/Equatable.Entities/UserImport.cs index 72ba702..2495092 100644 --- a/test/Equatable.Entities/UserImport.cs +++ b/test/Equatable.Entities/UserImport.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Generic; +using System.Text.Json.Serialization; + using Equatable.Attributes; namespace Equatable.Entities; @@ -10,6 +12,7 @@ public partial class UserImport [StringEquality(StringComparison.OrdinalIgnoreCase)] public string EmailAddress { get; set; } = null!; + [JsonPropertyName("name")] public string? DisplayName { get; set; } public string? FirstName { get; set; } @@ -20,6 +23,7 @@ public partial class UserImport public DateTimeOffset? LastLogin { get; set; } + [JsonIgnore] [IgnoreEquality] public string FullName => $"{FirstName} {LastName}"; diff --git a/test/Equatable.Generator.Tests/Entities/AuditTest.cs b/test/Equatable.Generator.Tests/Entities/AuditTest.cs new file mode 100644 index 0000000..58645ad --- /dev/null +++ b/test/Equatable.Generator.Tests/Entities/AuditTest.cs @@ -0,0 +1,122 @@ +using Equatable.Entities; + +namespace Equatable.Generator.Tests.Entities; + +public class AuditTest +{ + [Fact] + public void EqualAuditTrue() + { + var lockObject = new object(); + + var left = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = lockObject + }; + + var right = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = lockObject + }; + + var isEqual = left.Equals(right); + isEqual.Should().BeTrue(); + + // check operator == + isEqual = left == right; + isEqual.Should().BeTrue(); + } + + [Fact] + public void NotEqualAudit() + { + var left = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = new object() + }; + + var right = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = new object() + }; + + var isEqual = left.Equals(right); + isEqual.Should().BeFalse(); + + // check operator != + isEqual = left != right; + isEqual.Should().BeTrue(); + + } + + [Fact] + public void HashCodeAuditTrue() + { + var lockObject = new object(); + var left = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = lockObject + }; + + var right = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = lockObject + }; + + var leftCode = left.GetHashCode(); + var rightCode = right.GetHashCode(); + + leftCode.Should().Be(rightCode); + } + + [Fact] + public void HashCodeAuditNotEqual() + { + var left = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = new object() + }; + + var right = new Audit + { + Id = 1, + Date = new DateTime(2024, 9, 1), + UserId = 1, + TaskId = 2, + Lock = new object() + }; + + var leftCode = left.GetHashCode(); + var rightCode = right.GetHashCode(); + + leftCode.Should().NotBe(rightCode); + } +} diff --git a/test/Equatable.Generator.Tests/Entities/CustomLengthTest.cs b/test/Equatable.Generator.Tests/Entities/CustomLengthTest.cs new file mode 100644 index 0000000..ff46537 --- /dev/null +++ b/test/Equatable.Generator.Tests/Entities/CustomLengthTest.cs @@ -0,0 +1,111 @@ +using Equatable.Entities; + +namespace Equatable.Generator.Tests.Entities; + +public class CustomLengthTest +{ + [Fact] + public void EqualCustomLengthTrue() + { + var left = new CustomLength + { + Id = 1, + Name = "custom", + Key = "aaa", + Value = "zzz" + }; + + var right = new CustomLength + { + Id = 1, + Name = "custom", + Key = "bbb", + Value = "ccc" + }; + + var isEqual = left.Equals(right); + isEqual.Should().BeTrue(); + + // check operator == + isEqual = left == right; + isEqual.Should().BeTrue(); + } + + [Fact] + public void NotEqualCustomLength() + { + var left = new CustomLength + { + Id = 1, + Name = "custom", + Key = "xyzf", + Value = "abc" + }; + + var right = new CustomLength + { + Id = 1, + Name = "custom", + Key = "xyz", + Value = "abc" + }; + + var isEqual = left.Equals(right); + isEqual.Should().BeFalse(); + + // check operator != + isEqual = left != right; + isEqual.Should().BeTrue(); + + } + + [Fact] + public void HashCodeCustomLengthTrue() + { + var left = new CustomLength + { + Id = 1, + Name = "custom", + Key = "fff", + Value = "sss" + }; + + var right = new CustomLength + { + Id = 1, + Name = "custom", + Key = "rrr", + Value = "zzz" + }; + + var leftCode = left.GetHashCode(); + var rightCode = right.GetHashCode(); + + leftCode.Should().Be(rightCode); + } + + [Fact] + public void HashCodeCustomLengthNotEqual() + { + var left = new CustomLength + { + Id = 1, + Name = "custom", + Key = "xyzf", + Value = "abc" + }; + + var right = new CustomLength + { + Id = 1, + Name = "custom", + Key = "xyz", + Value = "abc" + }; + + var leftCode = left.GetHashCode(); + var rightCode = right.GetHashCode(); + + leftCode.Should().NotBe(rightCode); + } +} diff --git a/test/Equatable.Generator.Tests/EquatableGeneratorTest.cs b/test/Equatable.Generator.Tests/EquatableGeneratorTest.cs index 0b841a3..889b4a9 100644 --- a/test/Equatable.Generator.Tests/EquatableGeneratorTest.cs +++ b/test/Equatable.Generator.Tests/EquatableGeneratorTest.cs @@ -66,7 +66,6 @@ public Task GeneratePriorityBaseEquatable() var source = @" using System; using System.Collections.Generic; - using Equatable.Attributes; namespace Equatable.Entities; @@ -75,15 +74,10 @@ namespace Equatable.Entities; public abstract partial class ModelBase { public int Id { get; set; } - public DateTimeOffset Created { get; set; } - public string? CreatedBy { get; set; } - public DateTimeOffset Updated { get; set; } - public string? UpdatedBy { get; set; } - public long RowVersion { get; set; } } @@ -113,7 +107,6 @@ public Task GeneratePriorityBase() var source = @" using System; using System.Collections.Generic; - using Equatable.Attributes; namespace Equatable.Entities; @@ -121,15 +114,10 @@ namespace Equatable.Entities; public abstract partial class ModelBase : IEquatable { public int Id { get; set; } - public DateTimeOffset Created { get; set; } - public string? CreatedBy { get; set; } - public DateTimeOffset Updated { get; set; } - public string? UpdatedBy { get; set; } - public long RowVersion { get; set; } public override bool Equals(object? obj) @@ -295,6 +283,81 @@ public StatusReadOnly(int id, string name, string? description, int displayOrder .ScrubLinesContaining("GeneratedCodeAttribute"); } + [Fact] + public Task GenerateCustomComparer() + { + var source = @" +using System.Collections.Generic; +using Equatable.Attributes; + +namespace Equatable.Entities; + +[Equatable] +public partial class CustomComparer +{ + public int Id { get; set; } + + public string Name { get; set; } = null!; + + [EqualityComparer(typeof(LengthEqualityComparer))] + public string? Key { get; set; } +} + +public class LengthEqualityComparer : IEqualityComparer +{ + public static readonly LengthEqualityComparer Default = new(); + + public bool Equals(string? x, string? y) => x?.Length == y?.Length; + + public int GetHashCode(string? obj) => obj?.Length.GetHashCode() ?? 0; +} +"; + + var (diagnostics, output) = GetGeneratedOutput(source); + + diagnostics.Should().BeEmpty(); + + return Verifier + .Verify(output) + .UseDirectory("Snapshots") + .ScrubLinesContaining("GeneratedCodeAttribute"); + } + + [Fact] + public Task GenerateReferenceComparer() + { + var source = @" +using System; +using Equatable.Attributes; + +namespace Equatable.Entities; + +[Equatable] +public partial class Audit +{ + public int Id { get; set; } + public DateTime Date { get; set; } + public int? UserId { get; set; } + public int? TaskId { get; set; } + public string? Content { get; set; } + public string? UserName { get; set; } + + [ReferenceEquality] + public object? Lock { get; set; } +} +"; + + var (diagnostics, output) = GetGeneratedOutput(source); + + diagnostics.Should().BeEmpty(); + + return Verifier + .Verify(output) + .UseDirectory("Snapshots") + .ScrubLinesContaining("GeneratedCodeAttribute"); + } + + private static (ImmutableArray Diagnostics, string Output) GetGeneratedOutput(string source) where T : IIncrementalGenerator, new() { diff --git a/test/Equatable.Generator.Tests/Snapshots/EquatableGeneratorTest.GenerateCustomComparer.verified.txt b/test/Equatable.Generator.Tests/Snapshots/EquatableGeneratorTest.GenerateCustomComparer.verified.txt new file mode 100644 index 0000000..a3edaec --- /dev/null +++ b/test/Equatable.Generator.Tests/Snapshots/EquatableGeneratorTest.GenerateCustomComparer.verified.txt @@ -0,0 +1,47 @@ +// +#nullable enable + +namespace Equatable.Entities +{ + partial class CustomComparer : global::System.IEquatable + { + /// + public bool Equals(CustomComparer? other) + { + return other is not null + && global::System.Collections.Generic.EqualityComparer.Default.Equals(Id, other.Id) + && global::System.Collections.Generic.EqualityComparer.Default.Equals(Name, other.Name) + && global::Equatable.Entities.LengthEqualityComparer.Default.Equals(Key, other.Key); + + } + + /// + public override bool Equals(object? obj) + { + return Equals(obj as CustomComparer); + } + + /// + public static bool operator ==(CustomComparer? left, CustomComparer? right) + { + return global::System.Collections.Generic.EqualityComparer.Default.Equals(left, right); + } + + /// + public static bool operator !=(CustomComparer? left, CustomComparer? right) + { + return !(left == right); + } + + /// + public override int GetHashCode(){ + int hashCode = -749929470; + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(Id!); + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(Name!); + hashCode = (hashCode * -1521134295) + global::Equatable.Entities.LengthEqualityComparer.Default.GetHashCode(Key!); + return hashCode; + + } + + } +} diff --git a/test/Equatable.Generator.Tests/Snapshots/EquatableGeneratorTest.GenerateReferenceComparer.verified.txt b/test/Equatable.Generator.Tests/Snapshots/EquatableGeneratorTest.GenerateReferenceComparer.verified.txt new file mode 100644 index 0000000..beaea4e --- /dev/null +++ b/test/Equatable.Generator.Tests/Snapshots/EquatableGeneratorTest.GenerateReferenceComparer.verified.txt @@ -0,0 +1,55 @@ +// +#nullable enable + +namespace Equatable.Entities +{ + partial class Audit : global::System.IEquatable + { + /// + public bool Equals(Audit? other) + { + return other is not null + && global::System.Collections.Generic.EqualityComparer.Default.Equals(Id, other.Id) + && global::System.Collections.Generic.EqualityComparer.Default.Equals(Date, other.Date) + && global::System.Collections.Generic.EqualityComparer.Default.Equals(UserId, other.UserId) + && global::System.Collections.Generic.EqualityComparer.Default.Equals(TaskId, other.TaskId) + && global::System.Collections.Generic.EqualityComparer.Default.Equals(Content, other.Content) + && global::System.Collections.Generic.EqualityComparer.Default.Equals(UserName, other.UserName) + && global::System.Object.ReferenceEquals(Lock, other.Lock); + + } + + /// + public override bool Equals(object? obj) + { + return Equals(obj as Audit); + } + + /// + public static bool operator ==(Audit? left, Audit? right) + { + return global::System.Collections.Generic.EqualityComparer.Default.Equals(left, right); + } + + /// + public static bool operator !=(Audit? left, Audit? right) + { + return !(left == right); + } + + /// + public override int GetHashCode(){ + int hashCode = 374357566; + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(Id!); + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(Date!); + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(UserId!); + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(TaskId!); + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(Content!); + hashCode = (hashCode * -1521134295) + global::System.Collections.Generic.EqualityComparer.Default.GetHashCode(UserName!); + hashCode = (hashCode * -1521134295) + global::System.Runtime.CompilerServices.RuntimeHelpers.GetHashCode(Lock!); + return hashCode; + + } + + } +}