Skip to content

Commit d444d52

Browse files
authored
[ty] Infer lambda expressions with Callable type context (#22633)
Infer lambda expressions eagerly as part of their parent scope, and with type context. This allows us to infer more precise types for lambda expressions, as well as perform check assignability against `Callable` annotations. Note that this does not change the inferred type of a lambda parameter with the body of the lambda, even if it is annotated. That part is a little more tricky, so will be addressed in a followup PR.
1 parent 9622285 commit d444d52

File tree

7 files changed

+185
-35
lines changed

7 files changed

+185
-35
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# This test would previously panic with: `infer_definition_types(Id(1406)): execute: too many cycle iterations`.
2+
3+
lambda: name_4
4+
5+
@lambda: name_5
6+
class name_1: ...
7+
8+
name_2 = [lambda: name_4, name_1]
9+
10+
if name_2:
11+
@(*name_2,)
12+
class name_3: ...
13+
assert unique_name_19
14+
15+
@lambda: name_3
16+
class name_4[*name_2](0, name_1=name_3): ...
17+
18+
try:
19+
[name_5, name_4] = *name_4, = name_4
20+
except* 0:
21+
...
22+
else:
23+
async def name_4(): ...
24+
25+
for name_3 in name_4: ...

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,59 @@ def _(flag: bool):
397397
reveal_type(x2) # revealed: list[int | None]
398398
```
399399

400+
## Lambda expressions
401+
402+
If a lambda expression is annotated as a `Callable` type, the body of the lambda is inferred with
403+
the annotated return type as type context, and the annotated parameter types are respected:
404+
405+
```py
406+
from typing import Callable, TypedDict
407+
408+
class Bar(TypedDict):
409+
bar: int
410+
411+
f1 = lambda x: {"bar": 1}
412+
reveal_type(f1) # revealed: (x) -> dict[str, int]
413+
414+
f2: Callable[[int], Bar] = lambda x: {"bar": 1}
415+
reveal_type(f2) # revealed: (x: int) -> Bar
416+
417+
# error: [missing-typed-dict-key] "Missing required key 'bar' in TypedDict `Bar` constructor"
418+
# error: [invalid-assignment] "Object of type `(x: int) -> dict[Unknown, Unknown]` is not assignable to `(int, /) -> Bar`"
419+
f3: Callable[[int], Bar] = lambda x: {}
420+
reveal_type(f3) # revealed: (int, /) -> Bar
421+
422+
# TODO: This should reveal `str`.
423+
f4: Callable[[str], str] = lambda x: reveal_type(x) # revealed: Unknown
424+
reveal_type(f4) # revealed: (x: str) -> Unknown
425+
426+
# TODO: This should not error once we support `Unpack`.
427+
# error: [invalid-assignment]
428+
f5: Callable[[*tuple[int, ...]], None] = lambda x, y, z: None
429+
reveal_type(f5) # revealed: (tuple[int, ...], /) -> None
430+
431+
f6: Callable[[int, str], None] = lambda *args: None
432+
reveal_type(f6) # revealed: (*args) -> None
433+
434+
# N.B. `Callable` annotations only support positional parameters.
435+
# error: [invalid-assignment]
436+
f7: Callable[[int], None] = lambda *, x=1: None
437+
reveal_type(f7) # revealed: (int, /) -> None
438+
439+
# TODO: This should reveal `(*args: int, *, x=1) -> None` once we support `Unpack`.
440+
f8: Callable[[*tuple[int, ...], int], None] = lambda *args, x=1: None
441+
reveal_type(f8) # revealed: (*args, *, x=1) -> None
442+
```
443+
444+
We do not currently account for type annotations present later in the scope:
445+
446+
```py
447+
f9 = lambda: [1]
448+
# TODO: This should not error.
449+
_: list[int | str] = f9() # error: [invalid-assignment]
450+
reveal_type(f9) # revealed: () -> list[int]
451+
```
452+
400453
## Dunder Calls
401454

402455
The key and value parameters types are used as type context for `__setitem__` dunder calls:

crates/ty_python_semantic/resources/mdtest/cycle.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,16 @@ class C:
128128
self.c = lambda positional_only=self.c, /: positional_only
129129
self.d = lambda *, kw_only=self.d: kw_only
130130

131-
# revealed: (positional=...) -> Unknown
131+
# revealed: (positional=...) -> Unknown | ((positional=...) -> Divergent) | ((positional=...) -> Unknown) | ((positional=...) -> Divergent)
132132
reveal_type(self.a)
133133

134-
# revealed: (*, kw_only=...) -> Unknown
134+
# revealed: (*, kw_only=...) -> Unknown | ((*, kw_only=...) -> Divergent) | ((*, kw_only=...) -> Unknown) | ((*, kw_only=...) -> Divergent)
135135
reveal_type(self.b)
136136

137-
# revealed: (positional_only=..., /) -> Unknown
137+
# revealed: (positional_only=..., /) -> Unknown | ((positional_only=..., /) -> Divergent) | ((positional_only=..., /) -> Unknown) | ((positional_only=..., /) -> Divergent)
138138
reveal_type(self.c)
139139

140-
# revealed: (*, kw_only=...) -> Unknown
140+
# revealed: (*, kw_only=...) -> Unknown | ((*, kw_only=...) -> Divergent) | ((*, kw_only=...) -> Unknown) | ((*, kw_only=...) -> Divergent)
141141
reveal_type(self.d)
142142
```
143143

crates/ty_python_semantic/resources/mdtest/expression/lambda.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
`lambda` expressions can be defined without any parameters.
66

77
```py
8-
reveal_type(lambda: 1) # revealed: () -> Unknown
8+
reveal_type(lambda: 1) # revealed: () -> Literal[1]
99

1010
# error: [unresolved-reference]
1111
reveal_type(lambda: a) # revealed: () -> Unknown
@@ -24,7 +24,7 @@ reveal_type(lambda a, b: a + b) # revealed: (a, b) -> Unknown
2424
But, it can have default values:
2525

2626
```py
27-
reveal_type(lambda a=1: a) # revealed: (a=1) -> Unknown
27+
reveal_type(lambda a=1: a) # revealed: (a=1) -> Unknown | Literal[1]
2828
reveal_type(lambda a, b=2: a) # revealed: (a, b=2) -> Unknown
2929
```
3030

@@ -37,25 +37,25 @@ reveal_type(lambda a, b, /, c: c) # revealed: (a, b, /, c) -> Unknown
3737
And, keyword-only parameters:
3838

3939
```py
40-
reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=2, c) -> Unknown
40+
reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=2, c) -> Unknown | Literal[2]
4141
```
4242

4343
And, variadic parameter:
4444

4545
```py
46-
reveal_type(lambda *args: args) # revealed: (*args) -> Unknown
46+
reveal_type(lambda *args: args) # revealed: (*args) -> tuple[Unknown, ...]
4747
```
4848

4949
And, keyword-varidic parameter:
5050

5151
```py
52-
reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> Unknown
52+
reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> dict[str, Unknown]
5353
```
5454

5555
Mixing all of them together:
5656

5757
```py
58-
# revealed: (a, b, /, c=True, *args, *, d="default", e=5, **kwargs) -> Unknown
58+
# revealed: (a, b, /, c=True, *args, *, d="default", e=5, **kwargs) -> None
5959
reveal_type(lambda a, b, /, c=True, *args, d="default", e=5, **kwargs: None)
6060
```
6161

@@ -94,7 +94,7 @@ Here, a `lambda` expression is used as the default value for a parameter in anot
9494
expression.
9595

9696
```py
97-
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=...) -> Unknown
97+
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=...) -> Literal[2]
9898
```
9999

100100
## Assignment
@@ -114,6 +114,9 @@ a4: Callable[[int, int], None] = lambda *args: None
114114
a5: Callable[[], None] = lambda x: None
115115
# error: [invalid-assignment]
116116
a6: Callable[[int], None] = lambda: None
117+
118+
# error: [invalid-assignment]
119+
a7: Callable[[], str] = lambda: 1
117120
```
118121

119122
## Function-like behavior of lambdas

crates/ty_python_semantic/src/semantic_index/scope.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ impl<'db> ScopeId<'db> {
3838
pub(crate) fn accepts_type_context(self, db: &dyn Db) -> bool {
3939
matches!(
4040
self.node(db),
41-
NodeWithScopeKind::ListComprehension(_)
41+
NodeWithScopeKind::Lambda(_)
42+
| NodeWithScopeKind::ListComprehension(_)
4243
| NodeWithScopeKind::SetComprehension(_)
4344
| NodeWithScopeKind::DictComprehension(_)
4445
)

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 90 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
579579
NodeWithScopeKind::Function(function) => {
580580
self.infer_function_body(function.node(self.module()));
581581
}
582-
NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node(self.module())),
582+
NodeWithScopeKind::Lambda(lambda) => {
583+
self.infer_lambda_body(lambda.node(self.module()), tcx);
584+
}
583585
NodeWithScopeKind::Class(class) => self.infer_class_body(class.node(self.module())),
584586
NodeWithScopeKind::ClassTypeParameters(class) => {
585587
self.infer_class_type_params(class.node(self.module()));
@@ -5468,7 +5470,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
54685470
ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript),
54695471
ast::Expr::Slice(slice) => self.infer_slice_expression(slice),
54705472
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx),
5471-
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
5473+
ast::Expr::Lambda(lambda_expression) => {
5474+
self.infer_lambda_expression(lambda_expression, tcx)
5475+
}
54725476
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
54735477
ast::Expr::Starred(starred) => self.infer_starred_expression(starred, tcx),
54745478
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
@@ -6724,11 +6728,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
67246728
}
67256729
}
67266730

6727-
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {
6728-
self.infer_expression(&lambda_expression.body, TypeContext::default());
6731+
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda, tcx: TypeContext<'db>) {
6732+
self.infer_expression(&lambda_expression.body, tcx);
67296733
}
67306734

6731-
fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> {
6735+
fn infer_lambda_expression(
6736+
&mut self,
6737+
lambda_expression: &ast::ExprLambda,
6738+
tcx: TypeContext<'db>,
6739+
) -> Type<'db> {
67326740
let ast::ExprLambda {
67336741
range: _,
67346742
node_index: _,
@@ -6740,27 +6748,64 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
67406748
let in_stub = self.in_stub();
67416749
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into());
67426750

6751+
let callable_tcx = if let Some(tcx) = tcx.annotation
6752+
// TODO: We could perform multi-inference here if there are multiple `Callable` annotations
6753+
// in the union.
6754+
&& let Some(callable) = tcx
6755+
.filter_union(self.db(), Type::is_callable_type)
6756+
.as_callable()
6757+
{
6758+
let [signature] = callable.signatures(self.db()).overloads.as_slice() else {
6759+
panic!("`Callable` type annotations cannot be overloaded");
6760+
};
6761+
6762+
Some(signature)
6763+
} else {
6764+
None
6765+
};
6766+
6767+
// Extract the annotated parameter types.
6768+
//
6769+
// Note that `Callable` annotations are only valid for positional parameters.
6770+
let mut parameter_types = match callable_tcx {
6771+
None => [].iter(),
6772+
Some(signature) => signature.parameters().into_iter(),
6773+
}
6774+
.map(Parameter::annotated_type);
6775+
67436776
let parameters = if let Some(parameters) = parameters {
67446777
let positional_only = parameters
67456778
.posonlyargs
67466779
.iter()
67476780
.map(|param| {
6748-
Parameter::positional_only(Some(param.name().id.clone()))
6781+
let parameter = Parameter::positional_only(Some(param.name().id.clone()))
67496782
.with_optional_default_type(param.default().map(|default_expr| {
67506783
self.infer_expression(default_expr, TypeContext::default())
67516784
.replace_parameter_defaults(self.db())
6752-
}))
6785+
}));
6786+
6787+
if let Some(annotated_type) = parameter_types.next() {
6788+
parameter.with_annotated_type(annotated_type)
6789+
} else {
6790+
parameter
6791+
}
67536792
})
67546793
.collect::<Vec<_>>();
67556794
let positional_or_keyword = parameters
67566795
.args
67576796
.iter()
67586797
.map(|param| {
6759-
Parameter::positional_or_keyword(param.name().id.clone())
6798+
let parameter = Parameter::positional_or_keyword(param.name().id.clone())
67606799
.with_optional_default_type(param.default().map(|default_expr| {
67616800
self.infer_expression(default_expr, TypeContext::default())
67626801
.replace_parameter_defaults(self.db())
6763-
}))
6802+
}));
6803+
6804+
if let Some(annotated_type) = parameter_types.next() {
6805+
parameter.with_annotated_type(annotated_type)
6806+
} else {
6807+
parameter
6808+
}
67646809
})
67656810
.collect::<Vec<_>>();
67666811
let variadic = parameters
@@ -6784,25 +6829,48 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
67846829
.as_ref()
67856830
.map(|param| Parameter::keyword_variadic(param.name().id.clone()));
67866831

6787-
Parameters::new(
6788-
self.db(),
6789-
positional_only
6790-
.into_iter()
6791-
.chain(positional_or_keyword)
6792-
.chain(variadic)
6793-
.chain(keyword_only)
6794-
.chain(keyword_variadic),
6795-
)
6832+
let parameters = positional_only
6833+
.into_iter()
6834+
.chain(positional_or_keyword)
6835+
.chain(variadic)
6836+
.chain(keyword_only)
6837+
.chain(keyword_variadic);
6838+
6839+
Parameters::new(self.db(), parameters)
67966840
} else {
67976841
Parameters::empty()
67986842
};
67996843

68006844
self.deferred_state = previous_deferred_state;
68016845

6802-
// TODO: Useful inference of a lambda's return type will require a different approach,
6803-
// which does the inference of the body expression based on arguments at each call site,
6804-
// rather than eagerly computing a return type without knowing the argument types.
6805-
Type::function_like_callable(self.db(), Signature::new(parameters, Type::unknown()))
6846+
let Some(scope_id) = self
6847+
.index
6848+
.try_node_scope(NodeWithScopeRef::Lambda(lambda_expression))
6849+
else {
6850+
return Type::unknown();
6851+
};
6852+
6853+
let scope = scope_id.to_scope_id(self.db(), self.file());
6854+
6855+
// If we have a direct `Callable` type context, we can infer the body with the annotated
6856+
// return type as type context.
6857+
let return_tcx = if let Some(signature) = callable_tcx {
6858+
match signature.return_ty {
6859+
Type::Dynamic(DynamicType::Unknown) => TypeContext::new(None),
6860+
_ => TypeContext::new(Some(signature.return_ty)),
6861+
}
6862+
} else {
6863+
// TODO: Useful inference of a lambda's return type will require a different approach,
6864+
// which does the inference of the body expression based on arguments at each call site,
6865+
// rather than eagerly computing a return type without knowing the argument types.
6866+
TypeContext::new(None)
6867+
};
6868+
6869+
let inference = infer_scope_types(self.db(), scope, return_tcx);
6870+
self.extend_scope(inference);
6871+
6872+
let return_ty = inference.expression_type(lambda_expression.body.as_ref());
6873+
Type::function_like_callable(self.db(), Signature::new(parameters, return_ty))
68066874
}
68076875

68086876
/// Attempt to narrow a splatted dictionary argument based on the narrowed types of individual

crates/ty_python_semantic/src/types/infer/builder/type_expression.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
461461

462462
ast::Expr::Lambda(lambda_expression) => {
463463
if !self.deferred_state.in_string_annotation() {
464-
self.infer_lambda_expression(lambda_expression);
464+
self.infer_lambda_expression(lambda_expression, TypeContext::default());
465465
}
466466
self.report_invalid_type_expression(
467467
expression,

0 commit comments

Comments
 (0)