Skip to content

Commit f378a22

Browse files
committed
Change ExecuteUpdate to accept non-expression lambda
Closes #32018
1 parent 76d5bef commit f378a22

38 files changed

Lines changed: 931 additions & 383 deletions

File tree

src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,10 +1245,10 @@ private sealed class FakeFieldInfo(
12451245
public bool IsNonNullableReferenceType { get; } = isNonNullableReferenceType;
12461246

12471247
public override object[] GetCustomAttributes(bool inherit)
1248-
=> Array.Empty<object>();
1248+
=> [];
12491249

12501250
public override object[] GetCustomAttributes(Type attributeType, bool inherit)
1251-
=> Array.Empty<object>();
1251+
=> [];
12521252

12531253
public override bool IsDefined(Type attributeType, bool inherit)
12541254
=> false;
@@ -1289,10 +1289,10 @@ public override RuntimeFieldHandle FieldHandle
12891289
private sealed class FakeConstructorInfo(Type type, ParameterInfo[] parameters) : ConstructorInfo
12901290
{
12911291
public override object[] GetCustomAttributes(bool inherit)
1292-
=> Array.Empty<object>();
1292+
=> [];
12931293

12941294
public override object[] GetCustomAttributes(Type attributeType, bool inherit)
1295-
=> Array.Empty<object>();
1295+
=> [];
12961296

12971297
public override bool IsDefined(Type attributeType, bool inherit)
12981298
=> false;

src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs

Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -734,18 +734,52 @@ void ProcessCapturedVariables()
734734

735735
for (var i = 1; i < parameters.Length; i++)
736736
{
737-
var parameter = parameters[i];
737+
var (parameterName, parameterType) = (parameters[i].Name!, parameters[i].ParameterType);
738738

739-
if (parameter.ParameterType == typeof(CancellationToken))
739+
if (parameterType == typeof(CancellationToken))
740740
{
741741
continue;
742742
}
743743

744-
if (_funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i) is not ExpressionTreeFuncletizer.PathNode
745-
evaluatableRootPaths)
744+
ExpressionTreeFuncletizer.PathNode? evaluatableRootPaths;
745+
746+
// ExecuteUpdate requires really special handling: the function accepts a Func<SetPropertyCalls...> argument, but
747+
// we need to run funcletization on the setter lambdas added via that Func<>.
748+
if (operatorMethodCall.Method is
749+
{
750+
Name: nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate)
751+
or nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync),
752+
IsGenericMethod: true
753+
}
754+
&& operatorMethodCall.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions))
746755
{
747-
// There are no captured variables in this lambda argument - skip the argument
748-
continue;
756+
// First, statically convert the Func<SetPropertyCalls...> to a NewArrayExpression which represents all the
757+
// setters; since that's an expression, we can run the funcletizer on it.
758+
var settersExpression = ProcessExecuteUpdate(operatorMethodCall);
759+
evaluatableRootPaths = _funcletizer.CalculatePathsToEvaluatableRoots(settersExpression);
760+
761+
if (evaluatableRootPaths is null)
762+
{
763+
// There are no captured variables in this lambda argument - skip the argument
764+
continue;
765+
}
766+
767+
// If there were captured variables, generate code to evaluate and build the same NewArrayExpression at runtime,
768+
// and then fall through to the normal logic, generating variable extractors against that NewArrayExpression
769+
// (local var) instead of against the method argument.
770+
code.AppendLine(
771+
$"var setters = {parameterName}(new SetPropertyCalls<{sourceElementTypeName}>()).BuildSettersExpression();");
772+
parameterName = "setters";
773+
parameterType = typeof(NewArrayExpression);
774+
}
775+
else
776+
{
777+
evaluatableRootPaths = _funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i);
778+
if (evaluatableRootPaths is null)
779+
{
780+
// There are no captured variables in this lambda argument - skip the argument
781+
continue;
782+
}
749783
}
750784

751785
// We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code
@@ -756,11 +790,11 @@ void ProcessCapturedVariables()
756790
declaredQueryContextVariable = true;
757791
}
758792

759-
if (!parameter.ParameterType.IsSubclassOf(typeof(Expression)))
793+
if (!parameterType.IsSubclassOf(typeof(Expression)))
760794
{
761795
// Special case: this is a non-lambda argument (Skip/Take/FromSql).
762796
// Simply add the argument directly as a parameter
763-
code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameter.Name});""");
797+
code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameterName});""");
764798
continue;
765799
}
766800

@@ -769,7 +803,7 @@ void ProcessCapturedVariables()
769803
// Lambda argument. Recurse through evaluatable path trees.
770804
foreach (var child in evaluatableRootPaths.Children!)
771805
{
772-
GenerateCapturedVariableExtractors(parameter.Name!, parameter.ParameterType, child);
806+
GenerateCapturedVariableExtractors(parameterName, parameterType, child);
773807

774808
void GenerateCapturedVariableExtractors(
775809
string currentIdentifier,
@@ -786,12 +820,13 @@ void GenerateCapturedVariableExtractors(
786820

787821
var variableName = capturedVariablesPathTree.ExpressionType.Name;
788822
variableName = char.ToLower(variableName[0]) + variableName[1..^"Expression".Length] + ++variableCounter;
789-
code.AppendLine(
790-
$"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");
791823

792824
if (capturedVariablesPathTree.Children?.Count > 0)
793825
{
794826
// This is an intermediate node which has captured variables in the children. Continue recursing down.
827+
code.AppendLine(
828+
$"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");
829+
795830
foreach (var child in capturedVariablesPathTree.Children)
796831
{
797832
GenerateCapturedVariableExtractors(variableName, capturedVariablesPathTree.ExpressionType, child);
@@ -816,7 +851,7 @@ void GenerateCapturedVariableExtractors(
816851
{
817852
code
818853
.Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",")
819-
.AppendLine($"Expression.Lambda<Func<object?>>(Expression.Convert({variableName}, typeof(object)))")
854+
.AppendLine($"Expression.Lambda<Func<object?>>(Expression.Convert({roslynPathSegment}, typeof(object)))")
820855
.AppendLine(".Compile(preferInterpretation: true)")
821856
.AppendLine(".Invoke());");
822857
}
@@ -1073,15 +1108,23 @@ or nameof(EntityFrameworkQueryableExtensions.ToListAsync)
10731108
QueryableMethods.GetSumWithSelector(
10741109
method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])),
10751110

1076-
// ExecuteDelete/Update behave just like other scalar-returning operators
1111+
// ExecuteDelete behaves just like other scalar-returning operators
10771112
nameof(EntityFrameworkQueryableExtensions.ExecuteDeleteAsync) when method.DeclaringType
10781113
== typeof(EntityFrameworkQueryableExtensions)
10791114
=> RewriteToSync(
10801115
typeof(EntityFrameworkQueryableExtensions).GetMethod(nameof(EntityFrameworkQueryableExtensions.ExecuteDelete))),
1081-
nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync) when method.DeclaringType
1082-
== typeof(EntityFrameworkQueryableExtensions)
1083-
=> RewriteToSync(
1084-
typeof(EntityFrameworkQueryableExtensions).GetMethod(nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate))),
1116+
1117+
// ExecuteUpdate is special; it accepts a non-expression-tree argument (Func<SetPropertyCalls, SetPropertyCalls>),
1118+
// evaluates it immediately, and injects a different MethodCall node into the expression tree with the resulting setter
1119+
// expressions.
1120+
// When statically analyzing ExecuteUpdate, we have to manually perform the same thing.
1121+
nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate) or nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync)
1122+
when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
1123+
=> Expression.Call(
1124+
EntityFrameworkQueryableExtensions.ExecuteUpdateMethodInfo.MakeGenericMethod(
1125+
terminatingOperator.Arguments[0].Type.GetSequenceType()),
1126+
penultimateOperator,
1127+
ProcessExecuteUpdate(terminatingOperator)),
10851128

10861129
// In the regular case (sync terminating operator which needs to stay in the query tree), simply compose the terminating
10871130
// operator over the penultimate and return that.
@@ -1116,6 +1159,56 @@ MethodCallExpression RewriteToSync(MethodInfo? syncMethod)
11161159
}
11171160
}
11181161

1162+
// Accepts an expression tree representing a series of SetProperty() calls, parses them and passes them through the SetPropertyCalls
1163+
// builder; returns the resulting NewArrayExpression representing all the setters.
1164+
private static NewArrayExpression ProcessExecuteUpdate(MethodCallExpression executeUpdateCall)
1165+
{
1166+
var setPropertyCalls = Activator.CreateInstance<SetPropertyCalls>();
1167+
var settersLambda = (LambdaExpression)executeUpdateCall.Arguments[1];
1168+
var settersParameter = settersLambda.Parameters.Single();
1169+
var expression = settersLambda.Body;
1170+
1171+
while (expression != settersParameter)
1172+
{
1173+
if (expression is MethodCallExpression
1174+
{
1175+
Method:
1176+
{
1177+
IsGenericMethod: true,
1178+
Name: nameof(SetPropertyCalls<int>.SetProperty),
1179+
DeclaringType.IsGenericType: true,
1180+
},
1181+
Arguments:
1182+
[
1183+
UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression propertySelector },
1184+
Expression valueSelector
1185+
]
1186+
} methodCallExpression
1187+
&& methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>))
1188+
{
1189+
if (valueSelector is UnaryExpression
1190+
{
1191+
NodeType: ExpressionType.Quote,
1192+
Operand: LambdaExpression unwrappedValueSelector
1193+
})
1194+
{
1195+
setPropertyCalls.SetProperty(propertySelector, unwrappedValueSelector);
1196+
}
1197+
else
1198+
{
1199+
setPropertyCalls.SetProperty(propertySelector, valueSelector);
1200+
}
1201+
1202+
expression = methodCallExpression.Object;
1203+
continue;
1204+
}
1205+
1206+
throw new InvalidOperationException(RelationalStrings.InvalidArgumentToExecuteUpdate);
1207+
}
1208+
1209+
return setPropertyCalls.BuildSettersExpression();
1210+
}
1211+
11191212
/// <summary>
11201213
/// Contains information on a failure to precompile a specific query in the user's source code.
11211214
/// Includes information about the query, its location, and the exception that occured.

src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.ExecuteUpdate.cs

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,22 @@ public partial class RelationalQueryableMethodTranslatingExpressionVisitor
1616
typeof(RelationalSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor))!;
1717

1818
/// <inheritdoc />
19-
protected override UpdateExpression? TranslateExecuteUpdate(ShapedQueryExpression source, LambdaExpression setPropertyCalls)
19+
protected override UpdateExpression? TranslateExecuteUpdate(ShapedQueryExpression source, IReadOnlyList<ExecuteUpdateSetter> setters)
2020
{
21+
if (setters.Count == 0)
22+
{
23+
throw new UnreachableException("Empty setters list");
24+
}
25+
2126
// Our source may have IncludeExpressions because of owned entities or auto-include; unwrap these, as they're meaningless for
2227
// ExecuteUpdate's lambdas. Note that we don't currently support updates across tables.
2328
source = source.UpdateShaperExpression(new IncludePruner().Visit(source.ShaperExpression));
2429

25-
var setters = new List<(LambdaExpression PropertySelector, Expression ValueExpression)>();
26-
PopulateSetPropertyCalls(setPropertyCalls.Body, setters, setPropertyCalls.Parameters[0]);
2730
if (TranslationErrorDetails != null)
2831
{
2932
return null;
3033
}
3134

32-
if (setters.Count == 0)
33-
{
34-
AddTranslationErrorDetails(RelationalStrings.NoSetPropertyInvocation);
35-
return null;
36-
}
37-
3835
// Translate the setters: the left (property) selectors get translated to ColumnExpressions, the right (value) selectors to
3936
// arbitrary SqlExpressions.
4037
// Note that if the query isn't natively supported, we'll do a pushdown (see PushdownWithPkInnerJoinPredicate below); if that
@@ -67,42 +64,9 @@ public partial class RelationalQueryableMethodTranslatingExpressionVisitor
6764

6865
return PushdownWithPkInnerJoinPredicate();
6966

70-
void PopulateSetPropertyCalls(
71-
Expression expression,
72-
List<(LambdaExpression, Expression)> list,
73-
ParameterExpression parameter)
74-
{
75-
switch (expression)
76-
{
77-
case ParameterExpression p
78-
when parameter == p:
79-
break;
80-
81-
case MethodCallExpression
82-
{
83-
Method:
84-
{
85-
IsGenericMethod: true,
86-
Name: nameof(SetPropertyCalls<int>.SetProperty),
87-
DeclaringType.IsGenericType: true
88-
}
89-
} methodCallExpression
90-
when methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>):
91-
list.Add(((LambdaExpression)methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));
92-
93-
PopulateSetPropertyCalls(methodCallExpression.Object!, list, parameter);
94-
95-
break;
96-
97-
default:
98-
AddTranslationErrorDetails(RelationalStrings.InvalidArgumentToExecuteUpdate);
99-
break;
100-
}
101-
}
102-
10367
bool TranslateSetters(
10468
ShapedQueryExpression source,
105-
List<(LambdaExpression PropertySelector, Expression ValueExpression)> setters,
69+
IReadOnlyList<ExecuteUpdateSetter> setters,
10670
[NotNullWhen(true)] out List<ColumnValueSetter>? translatedSetters,
10771
[NotNullWhen(true)] out TableExpressionBase? targetTable)
10872
{
@@ -464,7 +428,7 @@ SqlParameterExpression parameter
464428
var inner = source;
465429
var outerParameter = Expression.Parameter(entityType.ClrType);
466430
var outerKeySelector = Expression.Lambda(outerParameter.CreateKeyValuesExpression(pk.Properties), outerParameter);
467-
var firstPropertyLambdaExpression = setters[0].Item1;
431+
var firstPropertyLambdaExpression = setters[0].PropertySelector;
468432
var entitySource = GetEntitySource(RelationalDependencies.Model, firstPropertyLambdaExpression.Body);
469433
var innerKeySelector = Expression.Lambda(
470434
entitySource.CreateKeyValuesExpression(pk.Properties), firstPropertyLambdaExpression.Parameters);
@@ -481,6 +445,7 @@ SqlParameterExpression parameter
481445

482446
var propertyReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer");
483447
var valueReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner");
448+
var rewrittenSetters = new ExecuteUpdateSetter[setters.Count];
484449
for (var i = 0; i < setters.Count; i++)
485450
{
486451
var (propertyExpression, valueExpression) = setters[i];
@@ -499,14 +464,14 @@ SqlParameterExpression parameter
499464
transparentIdentifierParameter)
500465
: valueExpression;
501466

502-
setters[i] = (propertyExpression, valueExpression);
467+
rewrittenSetters[i] = new(propertyExpression, valueExpression);
503468
}
504469

505470
tableExpression = (TableExpression)outerSelectExpression.Tables[0];
506471

507472
// Re-translate the property selectors to get column expressions pointing to the new outer select expression (the original one
508473
// has been pushed down into a subquery).
509-
if (!TranslateSetters(outer, setters, out var translatedSetters, out _))
474+
if (!TranslateSetters(outer, rewrittenSetters, out var translatedSetters, out _))
510475
{
511476
return null;
512477
}

src/EFCore.Relational/Query/SqlNullabilityProcessor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1316,7 +1316,12 @@ protected virtual SqlExpression VisitSqlParameter(
13161316
bool allowOptimizedExpansion,
13171317
out bool nullable)
13181318
{
1319-
var parameterValue = ParameterValues[sqlParameterExpression.Name];
1319+
if (!ParameterValues.TryGetValue(sqlParameterExpression.Name, out var parameterValue))
1320+
{
1321+
throw new UnreachableException(
1322+
$"Encountered SqlParameter with name '{sqlParameterExpression.Name}', but such a parameter does not exist.");
1323+
}
1324+
13201325
nullable = parameterValue == null;
13211326

13221327
if (nullable)

0 commit comments

Comments
 (0)