Skip to content

Commit 2583e1a

Browse files
committed
infer lambda parameter types with Callable type context
1 parent edf274f commit 2583e1a

8 files changed

Lines changed: 122 additions & 52 deletions

File tree

crates/ty_ide/src/semantic_tokens.rs

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ use ruff_python_ast::{
4343
};
4444
use ruff_text_size::{Ranged, TextLen, TextRange, TextSize};
4545
use std::ops::Deref;
46-
use ty_python_semantic::semantic_index::definition::Definition;
46+
use ty_python_semantic::semantic_index::definition::{Definition, ParameterDefinitionNodeKind};
4747
use ty_python_semantic::types::TypeVarKind;
4848
use ty_python_semantic::{
4949
HasType, SemanticModel, semantic_index::definition::DefinitionKind, types::Type,
@@ -308,7 +308,7 @@ impl<'db> SemanticTokenVisitor<'db> {
308308
}
309309
DefinitionKind::Class(_) => Some((SemanticTokenType::Class, modifiers)),
310310
DefinitionKind::TypeVar(_) => Some((SemanticTokenType::TypeParameter, modifiers)),
311-
DefinitionKind::Parameter(parameter) => {
311+
DefinitionKind::Parameter(ParameterDefinitionNodeKind::Parameter(parameter)) => {
312312
let parsed = parsed_module(db, definition.file(db));
313313
let ty = parameter.node(&parsed.load(db)).inferred_type(&model);
314314

@@ -336,12 +336,7 @@ impl<'db> SemanticTokenVisitor<'db> {
336336

337337
Some((SemanticTokenType::Parameter, modifiers))
338338
}
339-
DefinitionKind::VariadicPositionalParameter(_) => {
340-
Some((SemanticTokenType::Parameter, modifiers))
341-
}
342-
DefinitionKind::VariadicKeywordParameter(_) => {
343-
Some((SemanticTokenType::Parameter, modifiers))
344-
}
339+
DefinitionKind::Parameter(_) => Some((SemanticTokenType::Parameter, modifiers)),
345340
DefinitionKind::TypeAlias(_) => Some((SemanticTokenType::TypeParameter, modifiers)),
346341
DefinitionKind::Import(_)
347342
| DefinitionKind::ImportFrom(_)
@@ -1787,7 +1782,7 @@ b: list["int | str"] | None
17871782
c: "list[int | str] | None"
17881783
d: "list[int | str]" | "None"
17891784
e: 'list["int | str"] | "None"'
1790-
f: """'list["int | str"]' | 'None'"""
1785+
f: """'list["int | str"]' | 'None'"""
17911786
"#,
17921787
);
17931788

@@ -2148,7 +2143,7 @@ class Baz:
21482143
prop: str = \"hello\"
21492144
21502145
baz = Baz()
2151-
s = baz.method
2146+
s = baz.method
21522147
t = baz.CONSTANT
21532148
r = baz.prop
21542149
q = Baz.prop
@@ -2383,7 +2378,7 @@ class MyClass:
23832378
def __init__(self): pass
23842379
23852380
"""unrelated string"""
2386-
2381+
23872382
x: str = "hello"
23882383
"#,
23892384
);
@@ -2414,7 +2409,7 @@ What a good module wooo
24142409
def my_func(): pass
24152410
24162411
"""unrelated string"""
2417-
2412+
24182413
x: str = "hello"
24192414
"#,
24202415
);

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,8 @@ reveal_type(f2) # revealed: (x: int) -> Bar
419419
f3: Callable[[int], Bar] = lambda x: {}
420420
reveal_type(f3) # revealed: (int, /) -> Bar
421421

422-
# TODO: This should reveal `str`.
423-
f4: Callable[[str], str] = lambda x: reveal_type(x) # revealed: Unknown
424-
reveal_type(f4) # revealed: (x: str) -> Unknown
422+
f4: Callable[[str], str] = lambda x: reveal_type(x) # revealed: str
423+
reveal_type(f4) # revealed: (x: str) -> str
425424

426425
# TODO: This should not error once we support `Unpack`.
427426
# error: [invalid-assignment]
@@ -439,15 +438,27 @@ reveal_type(f7) # revealed: (int, /) -> None
439438
# TODO: This should reveal `(*args: int, *, x=1) -> None` once we support `Unpack`.
440439
f8: Callable[[*tuple[int, ...], int], None] = lambda *args, x=1: None
441440
reveal_type(f8) # revealed: (*args, *, x=1) -> None
441+
442+
f9: Callable[[str, int, str], tuple[str, int, str]] = lambda x, y, z: reveal_type((x, y, z)) # revealed: tuple[str, int, str]
443+
reveal_type(f9) # revealed: (x: str, y: int, z: str) -> tuple[str, int, str]
444+
445+
# TODO: This should reveal `tuple[int, ...]` once we support `Unpack`.
446+
f10: Callable[[*tuple[int, ...]], tuple[int, ...]] = lambda *args: reveal_type(args) # revealed: tuple[Unknown, ...]
447+
reveal_type(f10) # revealed: (*args) -> tuple[Unknown, ...]
448+
449+
# TODO: Better generic call inference.
450+
def _(x: list[int]):
451+
f11 = list(map(lambda y: y + 1, x))
452+
reveal_type(f11) # revealed: list[Unknown]
442453
```
443454

444455
We do not currently account for type annotations present later in the scope:
445456

446457
```py
447-
f9 = lambda: [1]
458+
f12 = lambda: [1]
448459
# TODO: This should not error.
449-
_: list[int | str] = f9() # error: [invalid-assignment]
450-
reveal_type(f9) # revealed: () -> list[int]
460+
_: list[int | str] = f12() # error: [invalid-assignment]
461+
reveal_type(f12) # revealed: () -> list[int]
451462
```
452463

453464
## Dunder Calls

crates/ty_python_semantic/src/semantic_index.rs

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,9 @@ mod tests {
865865
use crate::Db;
866866
use crate::db::tests::{TestDb, TestDbBuilder};
867867
use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId};
868-
use crate::semantic_index::definition::{Definition, DefinitionKind};
868+
use crate::semantic_index::definition::{
869+
Definition, DefinitionKind, LambdaParameterDefinitionNodeKind, ParameterDefinitionNodeKind,
870+
};
869871
use crate::semantic_index::place::PlaceTable;
870872
use crate::semantic_index::scope::{FileScopeId, Scope, ScopeKind};
871873
use crate::semantic_index::symbol::ScopedSymbolId;
@@ -1156,14 +1158,14 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
11561158
.unwrap();
11571159
assert!(matches!(
11581160
args_binding.kind(&db),
1159-
DefinitionKind::VariadicPositionalParameter(_)
1161+
DefinitionKind::Parameter(ParameterDefinitionNodeKind::VariadicPositionalParameter(_))
11601162
));
11611163
let kwargs_binding = use_def
11621164
.first_public_binding(function_table.symbol_id("kwargs").expect("symbol exists"))
11631165
.unwrap();
11641166
assert!(matches!(
11651167
kwargs_binding.kind(&db),
1166-
DefinitionKind::VariadicKeywordParameter(_)
1168+
DefinitionKind::Parameter(ParameterDefinitionNodeKind::VariadicKeywordParameter(_))
11671169
));
11681170
}
11691171

@@ -1186,29 +1188,44 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs):
11861188
let lambda_table = index.place_table(lambda_scope_id);
11871189
assert_eq!(
11881190
names(lambda_table),
1189-
vec!["a", "b", "c", "d", "args", "kwargs"],
1191+
vec!["a", "b", "c", "args", "d", "kwargs"],
11901192
);
11911193

11921194
let use_def = index.use_def_map(lambda_scope_id);
11931195
for name in ["a", "b", "c", "d"] {
11941196
let binding = use_def
11951197
.first_public_binding(lambda_table.symbol_id(name).expect("symbol exists"))
11961198
.unwrap();
1197-
assert!(matches!(binding.kind(&db), DefinitionKind::Parameter(_)));
1199+
assert!(matches!(
1200+
binding.kind(&db),
1201+
DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
1202+
index: _,
1203+
lambda: _,
1204+
parameter: ParameterDefinitionNodeKind::Parameter(_)
1205+
})
1206+
));
11981207
}
11991208
let args_binding = use_def
12001209
.first_public_binding(lambda_table.symbol_id("args").expect("symbol exists"))
12011210
.unwrap();
12021211
assert!(matches!(
12031212
args_binding.kind(&db),
1204-
DefinitionKind::VariadicPositionalParameter(_)
1213+
DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
1214+
index: 3,
1215+
lambda: _,
1216+
parameter: ParameterDefinitionNodeKind::VariadicPositionalParameter(_)
1217+
})
12051218
));
12061219
let kwargs_binding = use_def
12071220
.first_public_binding(lambda_table.symbol_id("kwargs").expect("symbol exists"))
12081221
.unwrap();
12091222
assert!(matches!(
12101223
kwargs_binding.kind(&db),
1211-
DefinitionKind::VariadicKeywordParameter(_)
1224+
DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
1225+
index: 5,
1226+
lambda: _,
1227+
parameter: ParameterDefinitionNodeKind::VariadicKeywordParameter(_)
1228+
})
12121229
));
12131230
}
12141231

crates/ty_python_semantic/src/semantic_index/ast_ids.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,12 @@ pub(crate) mod node_key {
149149
}
150150
}
151151

152+
impl From<&ast::ExprLambda> for ExpressionNodeKey {
153+
fn from(value: &ast::ExprLambda) -> Self {
154+
Self(NodeKey::from_node(value))
155+
}
156+
}
157+
152158
impl From<&ast::Identifier> for ExpressionNodeKey {
153159
fn from(value: &ast::Identifier) -> Self {
154160
Self(NodeKey::from_node(value))

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,14 +1673,18 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
16731673
debug_assert_eq!(existing_definition, None);
16741674
}
16751675

1676-
fn declare_lambda_parameters(&mut self, parameters: &'ast ast::Parameters) {
1676+
fn declare_lambda_parameters(
1677+
&mut self,
1678+
parameters: &'ast ast::Parameters,
1679+
lambda: &'ast ast::ExprLambda,
1680+
) {
16771681
let mut index = 0;
16781682
for parameter in &parameters.posonlyargs {
1679-
self.declare_lambda_parameter(index, parameter);
1683+
self.declare_lambda_parameter(index, parameter, lambda);
16801684
index += 1;
16811685
}
16821686
for parameter in &parameters.args {
1683-
self.declare_lambda_parameter(index, parameter);
1687+
self.declare_lambda_parameter(index, parameter, lambda);
16841688
index += 1;
16851689
}
16861690
if let Some(vararg) = parameters.vararg.as_ref() {
@@ -1692,13 +1696,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
16921696
symbol.into(),
16931697
LambdaParameterDefinitionNodeRef {
16941698
index,
1699+
lambda,
16951700
parameter: ParameterDefinitionNodeRef::VariadicPositionalParameter(vararg),
16961701
},
16971702
);
16981703
index += 1;
16991704
}
17001705
for parameter in &parameters.kwonlyargs {
1701-
self.declare_lambda_parameter(index, parameter);
1706+
self.declare_lambda_parameter(index, parameter, lambda);
17021707
index += 1;
17031708
}
17041709
if let Some(kwarg) = parameters.kwarg.as_ref() {
@@ -1710,6 +1715,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
17101715
symbol.into(),
17111716
LambdaParameterDefinitionNodeRef {
17121717
index,
1718+
lambda,
17131719
parameter: ParameterDefinitionNodeRef::VariadicKeywordParameter(kwarg),
17141720
},
17151721
);
@@ -1720,13 +1726,15 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
17201726
&mut self,
17211727
index: usize,
17221728
parameter: &'ast ast::ParameterWithDefault,
1729+
lambda: &'ast ast::ExprLambda,
17231730
) {
17241731
let symbol = self.add_symbol(parameter.name().id().clone());
17251732

17261733
let definition = self.add_definition(
17271734
symbol.into(),
17281735
LambdaParameterDefinitionNodeRef {
17291736
index,
1737+
lambda,
17301738
parameter: ParameterDefinitionNodeRef::Parameter(parameter),
17311739
},
17321740
);
@@ -3244,7 +3252,7 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
32443252

32453253
// Add symbols and definitions for the parameters to the lambda scope.
32463254
if let Some(parameters) = lambda.parameters.as_ref() {
3247-
self.declare_lambda_parameters(parameters);
3255+
self.declare_lambda_parameters(parameters, lambda);
32483256
}
32493257

32503258
self.visit_expr(lambda.body.as_ref());

crates/ty_python_semantic/src/semantic_index/definition.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,13 +392,13 @@ impl<'ast, 'db> From<ComprehensionDefinitionNodeRef<'ast, 'db>> for DefinitionNo
392392
}
393393
}
394394

395-
impl<'ast, 'db> From<ParameterDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, 'db> {
395+
impl<'ast> From<ParameterDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
396396
fn from(node: ParameterDefinitionNodeRef<'ast>) -> Self {
397397
Self::Parameter(node)
398398
}
399399
}
400400

401-
impl<'ast, 'db> From<LambdaParameterDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, 'db> {
401+
impl<'ast> From<LambdaParameterDefinitionNodeRef<'ast>> for DefinitionNodeRef<'ast, '_> {
402402
fn from(node: LambdaParameterDefinitionNodeRef<'ast>) -> Self {
403403
Self::LambdaParameter(node)
404404
}
@@ -546,8 +546,9 @@ impl ParameterDefinitionNodeRef<'_> {
546546

547547
#[derive(Copy, Clone, Debug)]
548548
pub(crate) struct LambdaParameterDefinitionNodeRef<'ast> {
549-
index: usize,
550-
parameter: ParameterDefinitionNodeRef<'ast>,
549+
pub(crate) index: usize,
550+
pub(crate) parameter: ParameterDefinitionNodeRef<'ast>,
551+
pub(crate) lambda: &'ast ast::ExprLambda,
551552
}
552553

553554
#[derive(Copy, Clone, Debug)]
@@ -671,9 +672,11 @@ impl<'db> DefinitionNodeRef<'_, 'db> {
671672
DefinitionNodeRef::LambdaParameter(LambdaParameterDefinitionNodeRef {
672673
index,
673674
parameter,
675+
lambda,
674676
}) => DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
675677
index,
676678
parameter: parameter.into_owned(parsed),
679+
lambda: AstNodeRef::new(parsed, lambda),
677680
}),
678681
DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef {
679682
unpack,
@@ -1228,6 +1231,7 @@ impl ParameterDefinitionNodeKind {
12281231
#[derive(Clone, Debug, get_size2::GetSize)]
12291232
pub struct LambdaParameterDefinitionNodeKind {
12301233
pub(crate) index: usize,
1234+
pub(crate) lambda: AstNodeRef<ast::ExprLambda>,
12311235
pub(crate) parameter: ParameterDefinitionNodeKind,
12321236
}
12331237

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -824,31 +824,34 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
824824
}
825825
DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
826826
index,
827+
lambda,
827828
parameter: ParameterDefinitionNodeKind::VariadicPositionalParameter(parameter),
828829
}) => {
829830
self.infer_variadic_positional_lambda_parameter_definition(
830831
*index,
831832
parameter.node(self.module()),
833+
lambda.node(self.module()),
832834
definition,
833835
);
834836
}
835837
DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
836-
index,
837838
parameter: ParameterDefinitionNodeKind::VariadicKeywordParameter(parameter),
839+
..
838840
}) => {
839841
self.infer_variadic_keyword_lambda_parameter_definition(
840-
*index,
841842
parameter.node(self.module()),
842843
definition,
843844
);
844845
}
845846
DefinitionKind::LambdaParameter(LambdaParameterDefinitionNodeKind {
846847
index,
848+
lambda,
847849
parameter: ParameterDefinitionNodeKind::Parameter(parameter_with_default),
848850
}) => {
849851
self.infer_lambda_parameter_definition(
850852
*index,
851853
parameter_with_default.node(self.module()),
854+
lambda.node(self.module()),
852855
definition,
853856
);
854857
}

0 commit comments

Comments
 (0)