From 9f77f08a63e6bf28efa41a6d12fd0375b0ee3a95 Mon Sep 17 00:00:00 2001 From: Blaise Taylor Date: Mon, 13 May 2024 05:52:23 -0400 Subject: [PATCH] Always create instance methods using the declaring type. Fixes Issue #179. (#180) --- .../XpressionMapperVisitor.cs | 16 +++- ...dUseDeclaringTypeForInstanceMethodCalls.cs | 83 +++++++++++++++++++ 2 files changed, 96 insertions(+), 3 deletions(-) create mode 100644 tests/AutoMapper.Extensions.ExpressionMapping.UnitTests/ShouldUseDeclaringTypeForInstanceMethodCalls.cs diff --git a/src/AutoMapper.Extensions.ExpressionMapping/XpressionMapperVisitor.cs b/src/AutoMapper.Extensions.ExpressionMapping/XpressionMapperVisitor.cs index 88ad595..0a92b62 100644 --- a/src/AutoMapper.Extensions.ExpressionMapping/XpressionMapperVisitor.cs +++ b/src/AutoMapper.Extensions.ExpressionMapping/XpressionMapperVisitor.cs @@ -569,9 +569,19 @@ protected override Expression VisitMethodCall(MethodCallExpression node) : GetInstanceExpression(this.Visit(node.Object)); MethodCallExpression GetInstanceExpression(Expression instance) - => node.Method.IsGenericMethod - ? Expression.Call(instance, node.Method.Name, typeArgsForNewMethod.ToArray(), listOfArgumentsForNewMethod.ToArray()) - : Expression.Call(instance, instance.Type.GetMethod(node.Method.Name, listOfArgumentsForNewMethod.Select(a => a.Type).ToArray()), listOfArgumentsForNewMethod.ToArray()); + { + return node.Method.IsGenericMethod + ? Expression.Call(instance, node.Method.Name, typeArgsForNewMethod.ToArray(), listOfArgumentsForNewMethod.ToArray()) + : Expression.Call(instance, GetMethodInfoForNonGeneric(), listOfArgumentsForNewMethod.ToArray()); + + MethodInfo GetMethodInfoForNonGeneric() + { + MethodInfo methodInfo = instance.Type.GetMethod(node.Method.Name, listOfArgumentsForNewMethod.Select(a => a.Type).ToArray()); + if (methodInfo.DeclaringType != instance.Type) + methodInfo = methodInfo.DeclaringType.GetMethod(node.Method.Name, listOfArgumentsForNewMethod.Select(a => a.Type).ToArray()); + return methodInfo; + } + } MethodCallExpression GetStaticExpression() => node.Method.IsGenericMethod diff --git a/tests/AutoMapper.Extensions.ExpressionMapping.UnitTests/ShouldUseDeclaringTypeForInstanceMethodCalls.cs b/tests/AutoMapper.Extensions.ExpressionMapping.UnitTests/ShouldUseDeclaringTypeForInstanceMethodCalls.cs new file mode 100644 index 0000000..6d6ecaf --- /dev/null +++ b/tests/AutoMapper.Extensions.ExpressionMapping.UnitTests/ShouldUseDeclaringTypeForInstanceMethodCalls.cs @@ -0,0 +1,83 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using Xunit; + +namespace AutoMapper.Extensions.ExpressionMapping.UnitTests +{ + public class ShouldUseDeclaringTypeForInstanceMethodCalls + { + [Fact] + public void MethodInfoShouldRetainDeclaringTypeInMappedExpression() + { + //Arrange + var config = new MapperConfiguration + ( + cfg => + { + cfg.CreateMap(); + cfg.CreateMap(); + } + ); + config.AssertConfigurationIsValid(); + var mapper = config.CreateMapper(); + Expression> filter = e => e.SimpleEnum.HasFlag(SimpleEnum.Value3); + EntityModel entityModel1 = new() { SimpleEnum = SimpleEnumModel.Value3 }; + EntityModel entityModel2 = new() { SimpleEnum = SimpleEnumModel.Value2 }; + + //act + Expression> mappedFilter = mapper.MapExpression>>(filter); + + //assert + Assert.Equal(typeof(Enum), HasFlagVisitor.GetasFlagReflectedType(mappedFilter)); + Assert.Single(new List { entityModel1 }.AsQueryable().Where(mappedFilter)); + Assert.Empty(new List { entityModel2 }.AsQueryable().Where(mappedFilter)); + } + + public enum SimpleEnum + { + Value1, + Value2, + Value3 + } + + public record Entity + { + public int Id { get; init; } + public SimpleEnum SimpleEnum { get; init; } + } + + public enum SimpleEnumModel + { + Value1, + Value2, + Value3 + } + + public record EntityModel + { + public int Id { get; init; } + public SimpleEnumModel SimpleEnum { get; init; } + } + + public class HasFlagVisitor : ExpressionVisitor + { + public static Type GetasFlagReflectedType(Expression expression) + { + HasFlagVisitor hasFlagVisitor = new(); + hasFlagVisitor.Visit(expression); + return hasFlagVisitor.HasFlagReflectedType; + } + protected override Expression VisitMethodCall(MethodCallExpression node) + { + if (node.Method.Name == "HasFlag") + HasFlagReflectedType = node.Method.ReflectedType; + + return base.VisitMethodCall(node); + } + + public Type HasFlagReflectedType { get; private set; } + } + } +}