Skip to content

Commit 8804a2c

Browse files
committed
add statement-level inference to avoid scope cycles
1 parent c49de7a commit 8804a2c

7 files changed

Lines changed: 472 additions & 31 deletions

File tree

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ from typing import Callable, TypedDict
428428
class Bar(TypedDict):
429429
bar: int
430430

431+
def id[T](x: T) -> T:
432+
return x
433+
431434
f1 = lambda x: {"bar": 1}
432435
reveal_type(f1) # revealed: (x) -> dict[str, int]
433436

@@ -442,34 +445,46 @@ reveal_type(f3) # revealed: (int, /) -> Bar
442445
f4: Callable[[str], str] = lambda x: reveal_type(x) # revealed: str
443446
reveal_type(f4) # revealed: (x: str) -> str
444447

448+
f5: Callable[[str], str] = id(lambda x: reveal_type(x)) # revealed: str
449+
reveal_type(f5) # revealed: (x: str) -> str
450+
445451
# TODO: This should not error once we support `Unpack`.
446452
# error: [invalid-assignment]
447-
f5: Callable[[*tuple[int, ...]], None] = lambda x, y, z: None
448-
reveal_type(f5) # revealed: (tuple[int, ...], /) -> None
453+
f6: Callable[[*tuple[int, ...]], None] = lambda x, y, z: None
454+
reveal_type(f6) # revealed: (tuple[int, ...], /) -> None
449455

450-
f6: Callable[[int, str], None] = lambda *args: None
451-
reveal_type(f6) # revealed: (*args) -> None
456+
f7: Callable[[int, str], None] = lambda *args: None
457+
reveal_type(f7) # revealed: (*args) -> None
452458

453459
# N.B. `Callable` annotations only support positional parameters.
454460
# error: [invalid-assignment]
455-
f7: Callable[[int], None] = lambda *, x=1: None
456-
reveal_type(f7) # revealed: (int, /) -> None
461+
f8: Callable[[int], None] = lambda *, x=1: None
462+
reveal_type(f8) # revealed: (int, /) -> None
457463

458464
# TODO: This should reveal `(*args: int, *, x=1) -> None` once we support `Unpack`.
459-
f8: Callable[[*tuple[int, ...], int], None] = lambda *args, x=1: None
460-
reveal_type(f8) # revealed: (*args, *, x=1) -> None
465+
f9: Callable[[*tuple[int, ...], int], None] = lambda *args, x=1: None
466+
reveal_type(f9) # revealed: (*args, *, x=1) -> None
461467

462-
f9: Callable[[str, int, str], tuple[str, int, str]] = lambda x, y, z: reveal_type((x, y, z)) # revealed: tuple[str, int, str]
463-
reveal_type(f9) # revealed: (x: str, y: int, z: str) -> tuple[str, int, str]
468+
f10: Callable[[str, int, str], tuple[str, int, str]] = lambda x, y, z: reveal_type((x, y, z)) # revealed: tuple[str, int, str]
469+
reveal_type(f10) # revealed: (x: str, y: int, z: str) -> tuple[str, int, str]
464470

465471
# TODO: This should reveal `tuple[int, ...]` once we support `Unpack`.
466-
f10: Callable[[*tuple[int, ...]], tuple[int, ...]] = lambda *args: reveal_type(args) # revealed: tuple[Unknown, ...]
467-
reveal_type(f10) # revealed: (*args) -> tuple[Unknown, ...]
472+
f11: Callable[[*tuple[int, ...]], tuple[int, ...]] = lambda *args: reveal_type(args) # revealed: tuple[Unknown, ...]
473+
reveal_type(f11) # revealed: (*args) -> tuple[Unknown, ...]
468474

469475
# TODO: Better generic call inference.
470476
def _(x: list[int]):
471-
f11 = list(map(lambda y: y + 1, x))
472-
reveal_type(f11) # revealed: list[Unknown]
477+
f12 = list(map(lambda y: y + 1, x))
478+
reveal_type(f12) # revealed: list[Unknown]
479+
480+
def _() -> Callable[[int], int]:
481+
return id(lambda x: reveal_type(x)) # revealed: int
482+
483+
def _():
484+
def takes_callable(_: Callable[[int], int]): ...
485+
486+
takes_callable(lambda x: reveal_type(x)) # revealed: int
487+
takes_callable(id(id(lambda x: reveal_type(x)))) # revealed: int
473488
```
474489

475490
We do not currently account for type annotations present later in the scope:

crates/ty_python_semantic/src/semantic_index.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub use crate::semantic_index::scope::FileScopeId;
2727
use crate::semantic_index::scope::{
2828
NodeWithScopeKey, NodeWithScopeRef, Scope, ScopeId, ScopeKind, ScopeLaziness,
2929
};
30+
use crate::semantic_index::statement::{Statement, StatementNodeKey};
3031
use crate::semantic_index::symbol::ScopedSymbolId;
3132
use crate::semantic_index::use_def::{EnclosingSnapshotKey, ScopedEnclosingSnapshotId, UseDefMap};
3233
use crate::semantic_model::HasTrackedScope;
@@ -42,6 +43,7 @@ pub(crate) mod predicate;
4243
mod re_exports;
4344
mod reachability_constraints;
4445
pub(crate) mod scope;
46+
pub mod statement;
4547
pub(crate) mod symbol;
4648
mod use_def;
4749

@@ -321,9 +323,15 @@ pub(crate) struct SemanticIndex<'db> {
321323
/// Map from a standalone expression to its [`Expression`] ingredient.
322324
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
323325

326+
/// Map from a standalone statement to its [`Statement`] ingredient.
327+
statements_by_node: FxHashMap<StatementNodeKey, Statement<'db>>,
328+
324329
/// Map from nodes that create a scope to the scope they create.
325330
scopes_by_node: FxHashMap<NodeWithScopeKey, FileScopeId>,
326331

332+
/// Maps from lambda expressions to their containing statement.
333+
enclosing_lambda_statements: FxHashMap<ExpressionNodeKey, Statement<'db>>,
334+
327335
/// Map from the file-local [`FileScopeId`] to the salsa-ingredient [`ScopeId`].
328336
scope_ids_by_scope: IndexVec<FileScopeId, ScopeId<'db>>,
329337

@@ -473,6 +481,13 @@ impl<'db> SemanticIndex<'db> {
473481
.map(|node_ref| self.expect_single_definition(node_ref))
474482
}
475483

484+
pub(crate) fn enclosing_lambda_statement(
485+
&self,
486+
lambda: ExpressionNodeKey,
487+
) -> Option<Statement<'db>> {
488+
self.enclosing_lambda_statements.get(&lambda).copied()
489+
}
490+
476491
pub(crate) fn is_scope_reachable(&self, db: &'db dyn Db, scope_id: FileScopeId) -> bool {
477492
self.parent_scope_id(scope_id)
478493
.is_none_or(|parent_scope_id| {
@@ -601,6 +616,13 @@ impl<'db> SemanticIndex<'db> {
601616
.contains_key(&expression_key.into())
602617
}
603618

619+
pub(crate) fn try_statement(
620+
&self,
621+
statement_key: impl Into<StatementNodeKey>,
622+
) -> Option<Statement<'db>> {
623+
self.statements_by_node.get(&statement_key.into()).copied()
624+
}
625+
604626
/// Returns the id of the scope that `node` creates.
605627
/// This is different from [`definition::Definition::scope`] which
606628
/// returns the scope in which that definition is defined in.

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ use crate::semantic_index::scope::{
4747
FileScopeId, NodeWithScopeKey, NodeWithScopeKind, NodeWithScopeRef,
4848
};
4949
use crate::semantic_index::scope::{Scope, ScopeId, ScopeKind, ScopeLaziness};
50+
use crate::semantic_index::statement::{Statement, StatementNodeKey};
5051
use crate::semantic_index::symbol::{ScopedSymbolId, Symbol};
5152
use crate::semantic_index::use_def::{
5253
EnclosingSnapshotKey, FlowSnapshot, PreviousDefinitions, ScopedDefinitionId,
@@ -100,6 +101,8 @@ pub(super) struct SemanticIndexBuilder<'db, 'ast> {
100101
/// The assignments we're currently visiting, with
101102
/// the most recent visit at the end of the Vec
102103
current_assignments: Vec<CurrentAssignment<'ast, 'db>>,
104+
/// The statement we're currently visiting.
105+
current_statement: Option<(&'ast ast::Stmt, FileScopeId)>,
103106
/// The match case we're currently visiting.
104107
current_match_case: Option<CurrentMatchCase<'ast>>,
105108
/// The name of the first function parameter of the innermost function that we're currently visiting.
@@ -129,6 +132,8 @@ pub(super) struct SemanticIndexBuilder<'db, 'ast> {
129132
scopes_by_expression: ExpressionsScopeMapBuilder,
130133
definitions_by_node: FxHashMap<DefinitionNodeKey, Definitions<'db>>,
131134
expressions_by_node: FxHashMap<ExpressionNodeKey, Expression<'db>>,
135+
statements_by_node: FxHashMap<StatementNodeKey, Statement<'db>>,
136+
enclosing_lambda_statements: FxHashMap<ExpressionNodeKey, Statement<'db>>,
132137
imported_modules: FxHashSet<ModuleName>,
133138
seen_submodule_imports: FxHashSet<String>,
134139
/// Hashset of all [`FileScopeId`]s that correspond to [generator functions].
@@ -150,6 +155,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
150155
module: module_ref,
151156
scope_stack: Vec::new(),
152157
current_assignments: vec![],
158+
current_statement: None,
153159
current_match_case: None,
154160
current_first_parameter_name: None,
155161
try_node_context_stack_manager: TryNodeContextStackManager::default(),
@@ -167,6 +173,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
167173
scopes_by_node: FxHashMap::default(),
168174
definitions_by_node: FxHashMap::default(),
169175
expressions_by_node: FxHashMap::default(),
176+
statements_by_node: FxHashMap::default(),
177+
enclosing_lambda_statements: FxHashMap::default(),
170178

171179
seen_submodule_imports: FxHashSet::default(),
172180
imported_modules: FxHashSet::default(),
@@ -1493,6 +1501,22 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
14931501
expression
14941502
}
14951503

1504+
fn add_standalone_statement(
1505+
&mut self,
1506+
statement_node: &ast::Stmt,
1507+
scope: FileScopeId,
1508+
) -> Statement<'db> {
1509+
let statement = Statement::new(
1510+
self.db,
1511+
self.file,
1512+
scope,
1513+
AstNodeRef::new(self.module, statement_node),
1514+
);
1515+
self.statements_by_node
1516+
.insert(statement_node.into(), statement);
1517+
statement
1518+
}
1519+
14961520
fn with_type_params(
14971521
&mut self,
14981522
with_scope: NodeWithScopeRef,
@@ -1858,6 +1882,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
18581882
use_def_maps.shrink_to_fit();
18591883
ast_ids.shrink_to_fit();
18601884
self.definitions_by_node.shrink_to_fit();
1885+
self.statements_by_node.shrink_to_fit();
1886+
self.enclosing_lambda_statements.shrink_to_fit();
18611887

18621888
self.scope_ids_by_scope.shrink_to_fit();
18631889
self.scopes_by_node.shrink_to_fit();
@@ -1869,11 +1895,13 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
18691895
scopes: self.scopes,
18701896
definitions_by_node: self.definitions_by_node,
18711897
expressions_by_node: self.expressions_by_node,
1898+
statements_by_node: self.statements_by_node,
18721899
scope_ids_by_scope: self.scope_ids_by_scope,
18731900
ast_ids,
18741901
scopes_by_expression: self.scopes_by_expression.build(),
18751902
scopes_by_node: self.scopes_by_node,
18761903
use_def_maps,
1904+
enclosing_lambda_statements: self.enclosing_lambda_statements,
18771905
imported_modules: Arc::new(self.imported_modules),
18781906
has_future_annotations: self.has_future_annotations,
18791907
enclosing_snapshots: self.enclosing_snapshots,
@@ -1892,10 +1920,8 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
18921920
self.source_text
18931921
.get_or_init(|| source_text(self.db, self.file))
18941922
}
1895-
}
18961923

1897-
impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
1898-
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
1924+
fn visit_stmt_impl(&mut self, stmt: &'ast ast::Stmt) {
18991925
self.with_semantic_checker(|semantic, context| semantic.visit_stmt(stmt, context));
19001926

19011927
self.current_use_def_map_mut()
@@ -3107,6 +3133,14 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
31073133
}
31083134
}
31093135
}
3136+
}
3137+
3138+
impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
3139+
fn visit_stmt(&mut self, stmt: &'ast ast::Stmt) {
3140+
self.current_statement = Some((stmt, self.current_scope()));
3141+
self.visit_stmt_impl(stmt);
3142+
self.current_statement = None;
3143+
}
31103144

31113145
fn visit_keyword(&mut self, keyword: &'ast ast::Keyword) {
31123146
walk_keyword(self, keyword);
@@ -3238,6 +3272,15 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
32383272
}
32393273
}
32403274
ast::Expr::Lambda(lambda) => {
3275+
// The body of a lambda expression needs access to the `Callable` type
3276+
// context the lambda is being inferred with, and so any statement
3277+
// containing a lambda must be inferable as a standalone statement.
3278+
if let Some((stmt, scope)) = self.current_statement {
3279+
let standalone_stmt = self.add_standalone_statement(stmt, scope);
3280+
self.enclosing_lambda_statements
3281+
.insert(lambda.into(), standalone_stmt);
3282+
}
3283+
32413284
if let Some(parameters) = &lambda.parameters {
32423285
// The default value of the parameters needs to be evaluated in the
32433286
// enclosing scope.
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use crate::ast_node_ref::AstNodeRef;
2+
use crate::db::Db;
3+
use crate::node_key::NodeKey;
4+
use crate::semantic_index::scope::{FileScopeId, ScopeId};
5+
use ruff_db::files::File;
6+
use ruff_python_ast as ast;
7+
use salsa;
8+
9+
/// An independently type-inferable statement.
10+
///
11+
/// ## Module-local type
12+
/// This type should not be used as part of any cross-module API because
13+
/// it holds a reference to the AST node. Range-offset changes
14+
/// then propagate through all usages, and deserialization requires
15+
/// reparsing the entire module.
16+
///
17+
/// E.g. don't use this type in:
18+
///
19+
/// * a return type of a cross-module query
20+
/// * a field of a type that is a return type of a cross-module query
21+
/// * an argument of a cross-module query
22+
#[salsa::tracked(debug, heap_size=ruff_memory_usage::heap_size)]
23+
pub(crate) struct Statement<'db> {
24+
/// The file in which the statement occurs.
25+
pub(crate) file: File,
26+
27+
/// The scope in which the statement occurs.
28+
pub(crate) file_scope: FileScopeId,
29+
30+
/// The statement node.
31+
#[no_eq]
32+
#[tracked]
33+
#[returns(ref)]
34+
pub(crate) node_ref: AstNodeRef<ast::Stmt>,
35+
}
36+
37+
// The Salsa heap is tracked separately.
38+
impl get_size2::GetSize for Statement<'_> {}
39+
40+
impl<'db> Statement<'db> {
41+
pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
42+
self.file_scope(db).to_scope_id(db, self.file(db))
43+
}
44+
}
45+
46+
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update, get_size2::GetSize)]
47+
pub(crate) struct StatementNodeKey(NodeKey);
48+
49+
impl From<&ast::Stmt> for StatementNodeKey {
50+
fn from(node: &ast::Stmt) -> Self {
51+
Self(NodeKey::from_node(node))
52+
}
53+
}

0 commit comments

Comments
 (0)