Skip to content

Commit

Permalink
Fix to #35393 - GroupJoin in EF Core 9 Returns Null for Joined Entities
Browse files Browse the repository at this point in the history
The problem was that in EF9 we moved some optimizations from the SQL nullability processor (SqlNullabilityProcessor) to SqlExpressionFactory, so that they are applied earlier. One of those optimizations:

```
!(true == a) -> false == a
!(false == a) -> true == a
```

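is not safe when `a` is nullable: with C# semantics, a null `a` makes `!(true == a)` true, whereas the rewritten `false == a` is false. The fix is to constrain this optimization in SqlExpressionFactory so that it only applies to an argument we know is not null (a bool constant, a non-nullable column, or a non-nullable parameter), and to perform the comprehensive version back in OptimizeNotExpression, once everything that can be converted to IS NULL / IS NOT NULL checks has already been converted.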

Fixes #35393
  • Loading branch information
maumar committed Jan 3, 2025
1 parent aed6e81 commit 4abc7f6
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 36 deletions.
22 changes: 19 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -660,20 +660,36 @@ private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpressi
SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary
=> AndAlso(Not(binary.Left), Not(binary.Right)),

// use equality where possible
// use equality where possible - we can only do this when we know a is not null
// at this point we are limited to constants, parameters and columns
// more comprehensive optimization is done during null expansion
// !(a == true) -> a == false
// !(a == false) -> a == true
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary
SqlBinaryExpression
{
OperatorType: ExpressionType.Equal,
Right: SqlConstantExpression { Value: bool },
Left: SqlConstantExpression { Value: bool }
or SqlParameterExpression { IsNullable: false }
or ColumnExpression { IsNullable: false } } binary
=> Equal(binary.Left, Not(binary.Right)),

// !(true == a) -> false == a
// !(false == a) -> true == a
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary
SqlBinaryExpression
{
OperatorType: ExpressionType.Equal,
Left: SqlConstantExpression { Value: bool },
Right: SqlConstantExpression { Value: bool }
or SqlParameterExpression { IsNullable: false }
or ColumnExpression { IsNullable: false }
} binary
=> Equal(Not(binary.Left), binary.Right),

// !(a == b) -> a != b
SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual(
sqlBinaryOperand.Left, sqlBinaryOperand.Right),

// !(a != b) -> a == b
SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal(
sqlBinaryOperand.Left, sqlBinaryOperand.Right),
Expand Down
54 changes: 38 additions & 16 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ protected virtual SqlExpression VisitSqlBinary(
// we assume that NullSemantics rewrite is only needed (on the current level)
// if the optimization didn't make any changes.
// Reason is that optimization can/will change the nullability of the resulting expression
// and that inforation is not tracked/stored anywhere
// and that information is not tracked/stored anywhere
// so we can no longer rely on nullabilities that we computed earlier (leftNullable, rightNullable)
// when performing null semantics rewrite.
// It should be fine because current optimizations *radically* change the expression
Expand Down Expand Up @@ -1590,10 +1590,10 @@ private SqlExpression RewriteNullSemantics(
}

var leftIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(left), leftNullable);
var leftIsNotNull = _sqlExpressionFactory.Not(leftIsNull);
var leftIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(leftIsNull));

var rightIsNull = ProcessNullNotNull(_sqlExpressionFactory.IsNull(right), rightNullable);
var rightIsNotNull = _sqlExpressionFactory.Not(rightIsNull);
var rightIsNotNull = OptimizeNotExpression(_sqlExpressionFactory.Not(rightIsNull));

SqlExpression body;
if (leftNegated == rightNegated)
Expand Down Expand Up @@ -1625,7 +1625,7 @@ private SqlExpression RewriteNullSemantics(
if (sqlBinaryExpression.OperatorType == ExpressionType.NotEqual)
{
// the factory takes care of simplifying using DeMorgan
body = _sqlExpressionFactory.Not(body);
body = OptimizeNotExpression(_sqlExpressionFactory.Not(body));
}

return body;
Expand All @@ -1643,18 +1643,40 @@ protected virtual SqlExpression OptimizeNotExpression(SqlExpression expression)
return expression;
}

// !(a > b) -> a <= b
// !(a >= b) -> a < b
// !(a < b) -> a >= b
// !(a <= b) -> a > b
if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand
&& TryNegate(sqlBinaryOperand.OperatorType, out var negated))
if (sqlUnaryExpression.Operand is SqlBinaryExpression sqlBinaryOperand)
{
return _sqlExpressionFactory.MakeBinary(
negated,
sqlBinaryOperand.Left,
sqlBinaryOperand.Right,
sqlBinaryOperand.TypeMapping)!;
// !(a > b) -> a <= b
// !(a >= b) -> a < b
// !(a < b) -> a >= b
// !(a <= b) -> a > b
if (TryNegate(sqlBinaryOperand.OperatorType, out var negated))
{
return _sqlExpressionFactory.MakeBinary(
negated,
sqlBinaryOperand.Left,
sqlBinaryOperand.Right,
sqlBinaryOperand.TypeMapping)!;
}

// use equality where possible - at this point (true == null) and (false == null) have been converted to
// IS NULL / IS NOT NULL (i.e. false), so this optimization is safe to do. See #35393
// !(a == true) -> a == false
// !(a == false) -> a == true
if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } })
{
return _sqlExpressionFactory.Equal(
sqlBinaryOperand.Left,
OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Right)));
}

// !(true == a) -> false == a
// !(false == a) -> true == a
if (sqlBinaryOperand is { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } })
{
return _sqlExpressionFactory.Equal(
OptimizeNotExpression(_sqlExpressionFactory.Not(sqlBinaryOperand.Left)),
sqlBinaryOperand.Right);
}
}

// the factory can optimize most `NOT` expressions
Expand Down Expand Up @@ -2039,7 +2061,7 @@ private SqlExpression ProcessNullNotNull(SqlExpression sqlExpression, bool opera
return result;
}
}
break;
break;
}

return sqlUnaryExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2459,6 +2459,33 @@ await AssertQueryScalar(
ss => ss.Set<NullSemanticsEntity1>().Where(e => true).Select(e => e.Id));
}

// Selects the Ids of entities for which !(true == NullableBoolA) holds.
// The bool constant is on the left-hand side and the compared column is nullable.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_nullable_column_negated(bool async)
{
    await AssertQueryScalar(
        async,
        ss => ss.Set<NullSemanticsEntity1>()
            .Where(x => !(true == x.NullableBoolA))
            .Select(x => x.Id));
}


// Selects the Ids of entities for which !(true == BoolA) holds.
// Same shape as the nullable-column variant, but the compared column is the non-nullable BoolA.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_non_nullable_column_negated(bool async)
{
    await AssertQueryScalar(
        async,
        ss => ss.Set<NullSemanticsEntity1>()
            .Where(x => !(true == x.BoolA))
            .Select(x => x.Id));
}

// Compares the constant true against a conditional expression that yields null on both branches
// (a null literal, or the captured null local), negates the comparison, and combines it with a
// NullableBoolA != null filter.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async)
{
    // Captured by the query lambda below; its value is always null.
    bool? prm = null;

    await AssertQueryScalar(
        async,
        ss => ss.Set<NullSemanticsEntity1>()
            .Where(x => x.NullableBoolA != null && !object.Equals(true, x.NullableBoolA == null ? null : prm))
            .Select(x => x.Id));
}

// We can't client-evaluate Like (for the expected results).
// However, since the test data has no LIKE wildcards, it effectively functions like equality - except that 'null like null' returns
// false instead of true. So we have this "lite" implementation which doesn't support wildcards.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,24 @@ public virtual Task GroupJoin_aggregate_nested_anonymous_key_selectors(bool asyn
(c, g) => new { c.CustomerID, Sum = g.Sum(x => x.CustomerID.Length) }),
elementSorter: e => e.CustomerID));

// Group-joins Customers with Orders using the constant key `true` on both sides, then projects each
// customer's id together with its grouped orders. Added with the fix for #35393.
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupJoin_on_true_equal_true(bool async)
{
    return AssertQuery(
        async,
        ss => ss.Set<Customer>()
            .GroupJoin(ss.Set<Order>(), x => true, x => true, (c, g) => new { c, g })
            .Select(x => new { x.c.CustomerID, Orders = x.g }),
        elementSorter: e => e.CustomerID,
        elementAsserter: (e, a) =>
        {
            // Check the customer id first, then the grouped orders (sorted by OrderID).
            Assert.Equal(e.CustomerID, a.CustomerID);
            AssertCollection(e.Orders, a.Orders, elementSorter: ee => ee.OrderID);
        });
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4452,7 +4452,7 @@ INNER JOIN (
FROM [Factions] AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand All @@ -4469,7 +4469,7 @@ LEFT JOIN (
FROM [Factions] AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -6978,7 +6978,7 @@ FROM [LocustLeaders] AS [l]
INNER JOIN [Factions] AS [f] ON [l].[Name] = [f].[CommanderName]
WHERE CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,22 @@ public override async Task GroupJoin_aggregate_nested_anonymous_key_selectors(bo
AssertSql();
}

// Provider-specific override: runs the base GroupJoin_on_true_equal_true test, then passes the
// expected SQL text to AssertSql.
public override async Task GroupJoin_on_true_equal_true(bool async)
{
await base.GroupJoin_on_true_equal_true(async);

// Expected shape: Customers combined with all Orders through an OUTER APPLY subquery that has no
// join predicate, ordered by CustomerID.
AssertSql(
"""
SELECT [c].[CustomerID], [o0].[OrderID], [o0].[CustomerID], [o0].[EmployeeID], [o0].[OrderDate]
FROM [Customers] AS [c]
OUTER APPLY (
SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate]
FROM [Orders] AS [o]
) AS [o0]
ORDER BY [c].[CustomerID]
""");
}

public override async Task Inner_join_with_tautology_predicate_converts_to_cross_join(bool async)
{
await base.Inner_join_with_tautology_predicate_converts_to_cross_join(async);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4568,6 +4568,42 @@ FROM [Entities1] AS [e]
""");
}

// Provider-specific override: runs the base Compare_constant_true_to_nullable_column_negated test,
// then asserts the expected SQL baseline.
public override async Task Compare_constant_true_to_nullable_column_negated(bool async)
{
    await base.Compare_constant_true_to_nullable_column_negated(async);

    // For the nullable column the expected SQL keeps the constant on the left, compares with <>,
    // and adds an explicit IS NULL term.
    AssertSql(
        """
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE CAST(1 AS bit) <> [e].[NullableBoolA] OR [e].[NullableBoolA] IS NULL
""");
}

// Provider-specific override: runs the base Compare_constant_true_to_non_nullable_column_negated test,
// then asserts the expected SQL baseline.
public override async Task Compare_constant_true_to_non_nullable_column_negated(bool async)
{
await base.Compare_constant_true_to_non_nullable_column_negated(async);

// For the non-nullable column the expected SQL is a single equality against CAST(0 AS bit),
// with no IS NULL term.
AssertSql(
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE [e].[BoolA] = CAST(0 AS bit)
""");
}

// Provider-specific override: runs the base Compare_constant_true_to_expression_which_evaluates_to_null
// test, then asserts the expected SQL baseline.
public override async Task Compare_constant_true_to_expression_which_evaluates_to_null(bool async)
{
await base.Compare_constant_true_to_expression_which_evaluates_to_null(async);

// The expected SQL contains only the NullableBoolA IS NOT NULL filter; no comparison against the
// constant or the parameter appears in the WHERE clause.
AssertSql(
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE [e].[NullableBoolA] IS NOT NULL
""");
}

// Forwards the expected SQL statements to the fixture's TestSqlLoggerFactory.AssertBaseline.
private void AssertSql(params string[] expected)
{
    Fixture.TestSqlLoggerFactory.AssertBaseline(expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6051,7 +6051,7 @@ INNER JOIN (
FROM [LocustHordes] AS [l1]
WHERE [l1].[Name] = N'Swarm'
) AS [l2] ON [u].[Name] = [l2].[CommanderName]
WHERE [l2].[Eradicated] = CAST(0 AS bit) OR [l2].[Eradicated] IS NULL
WHERE [l2].[Eradicated] <> CAST(1 AS bit) OR [l2].[Eradicated] IS NULL
""");
}

Expand All @@ -6074,7 +6074,7 @@ LEFT JOIN (
FROM [LocustHordes] AS [l1]
WHERE [l1].[Name] = N'Swarm'
) AS [l2] ON [u].[Name] = [l2].[CommanderName]
WHERE [l2].[Eradicated] = CAST(0 AS bit) OR [l2].[Eradicated] IS NULL
WHERE [l2].[Eradicated] <> CAST(1 AS bit) OR [l2].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -9333,7 +9333,7 @@ FROM [LocustCommanders] AS [l0]
INNER JOIN [LocustHordes] AS [l1] ON [u].[Name] = [l1].[CommanderName]
WHERE CASE
WHEN [l1].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [l1].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5129,7 +5129,7 @@ public override async Task Null_semantics_on_nullable_bool_from_inner_join_subqu
await base.Null_semantics_on_nullable_bool_from_inner_join_subquery_is_fully_applied(async);

AssertSql(
"""
"""
SELECT [s].[Id], [s].[CapitalName], [s].[Name], [s].[ServerAddress], [s].[CommanderName], [s].[Eradicated], [s].[Discriminator]
FROM [LocustLeaders] AS [l]
INNER JOIN (
Expand All @@ -5140,7 +5140,7 @@ FROM [Factions] AS [f]
LEFT JOIN [LocustHordes] AS [l0] ON [f].[Id] = [l0].[Id]
WHERE [l0].[Id] IS NOT NULL AND [f].[Name] = N'Swarm'
) AS [s] ON [l].[Name] = [s].[CommanderName]
WHERE [s].[Eradicated] = CAST(0 AS bit) OR [s].[Eradicated] IS NULL
WHERE [s].[Eradicated] <> CAST(1 AS bit) OR [s].[Eradicated] IS NULL
""");
}

Expand All @@ -5160,7 +5160,7 @@ FROM [Factions] AS [f]
LEFT JOIN [LocustHordes] AS [l0] ON [f].[Id] = [l0].[Id]
WHERE [l0].[Id] IS NOT NULL AND [f].[Name] = N'Swarm'
) AS [s] ON [l].[Name] = [s].[CommanderName]
WHERE [s].[Eradicated] = CAST(0 AS bit) OR [s].[Eradicated] IS NULL
WHERE [s].[Eradicated] <> CAST(1 AS bit) OR [s].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -7918,7 +7918,7 @@ public override async Task Join_inner_source_custom_projection_followed_by_filte
await base.Join_inner_source_custom_projection_followed_by_filter(async);

AssertSql(
"""
"""
SELECT CASE
WHEN [s].[Name] = N'Locust' THEN CAST(1 AS bit)
END AS [IsEradicated], [s].[CommanderName], [s].[Name]
Expand All @@ -7931,7 +7931,7 @@ WHERE [l0].[Id] IS NOT NULL
) AS [s] ON [l].[Name] = [s].[CommanderName]
WHERE CASE
WHEN [s].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [s].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2160,7 +2160,7 @@ SELECT CASE
INNER JOIN [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f] ON [l].[Name] = [f].[CommanderName]
WHERE CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END = CAST(0 AS bit) OR CASE
END <> CAST(1 AS bit) OR CASE
WHEN [f].[Name] = N'Locust' THEN CAST(1 AS bit)
END IS NULL
""");
Expand Down Expand Up @@ -3962,7 +3962,7 @@ INNER JOIN (
FROM [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand Down Expand Up @@ -5676,7 +5676,7 @@ LEFT JOIN (
FROM [Factions] FOR SYSTEM_TIME AS OF '2010-01-01T00:00:00.0000000' AS [f]
WHERE [f].[Name] = N'Swarm'
) AS [f0] ON [l].[Name] = [f0].[CommanderName]
WHERE [f0].[Eradicated] = CAST(0 AS bit) OR [f0].[Eradicated] IS NULL
WHERE [f0].[Eradicated] <> CAST(1 AS bit) OR [f0].[Eradicated] IS NULL
""");
}

Expand Down
Loading

0 comments on commit 4abc7f6

Please sign in to comment.