Skip to content

Commit 7c6dcd9

Browse files
[ty] Add caching for pattern match narrowing (#25613)
## Summary

When analyzing a `match` statement, we currently rebuild the narrowed subject for every case by collecting all preceding unguarded patterns into a union and intersecting the subject with the negation of that union. As pattern types become richer, this can repeatedly distribute the same intersections and make a chain of adjacent cases exponentially expensive.

For example, without _this_ change, #25493 starts to surface significant slowdowns in select projects. As a concrete example, this small `match` blows up with a 9x regression after #25493, without the caching introduced here: https://github.com/kornia/kornia/blob/7c2fee7216599a5e6ef149d4ab2fe33dd70f18c3/kornia/io/io.py#L132-L156.

Without caching, every branch `k` has to recompute the prefix in `subject & ~(pattern_1 | pattern_2 | ... | pattern_{k-1})`; with caching, we turn it into:

```
T_0 = subject
T_i = T_{i-1} & ~pattern_i
```

Case `k` is then analyzed against the cached `T_{k-1}`; a guarded pattern leaves the prefix unchanged (`T_i = T_{i-1}`), since a failing guard means the pattern excludes nothing.

(Prior to #25493, the sequence patterns generally contributed `Never` here, making it inexpensive.)
1 parent 27058fc commit 7c6dcd9

1 file changed

Lines changed: 49 additions & 21 deletions

File tree

crates/ty_python_semantic/src/reachability.rs

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ use crate::{
199199
place::{DefinedPlace, Definedness, Place, RequiresExplicitReExport, imported_symbol},
200200
types::{
201201
CallableTypes, ClassLiteral, IntersectionBuilder, KnownClass, NarrowingConstraint, Type,
202-
TypeContext, UnionBuilder, UnionType, enum_metadata, infer_expression_type,
203-
infer_narrowing_constraint,
202+
TypeContext, UnionType, enum_metadata, infer_expression_type, infer_narrowing_constraint,
204203
},
205204
};
206205
use ruff_index::IndexSlice;
@@ -297,21 +296,56 @@ fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>)
297296
}
298297
}
299298

300-
/// Go through the list of previous match cases, and accumulate a union of all types that were already
301-
/// matched by these patterns.
302-
fn type_excluded_by_previous_patterns<'db>(
299+
/// Narrow `subject_ty` by all preceding unguarded match patterns.
300+
///
301+
/// Caching each prefix lets the next case reuse the already-normalized subject instead of
302+
/// rebuilding it from the union of all preceding patterns, which can repeatedly distribute the
303+
/// same intersections.
304+
#[salsa::tracked(
305+
cycle_initial = |_, id, _, _| Type::divergent(id),
306+
cycle_fn = |db, cycle, previous: &Type<'db>, result: Type<'db>, _, _| {
307+
result.cycle_normalized(db, *previous, cycle)
308+
},
309+
heap_size = ruff_memory_usage::heap_size
310+
)]
311+
fn type_narrowed_by_previous_patterns<'db>(
303312
db: &'db dyn Db,
304-
mut predicate: PatternPredicate<'db>,
313+
predicate: PatternPredicate<'db>,
314+
subject_ty: Type<'db>,
305315
) -> Type<'db> {
306-
let mut builder = UnionBuilder::new(db);
307-
while let Some(previous) = predicate.previous_predicate(db) {
308-
predicate = *previous;
316+
let Some(previous) = predicate.previous_predicate(db) else {
317+
return subject_ty;
318+
};
319+
let previous = *previous;
320+
let narrowed_by_previous_patterns =
321+
type_narrowed_by_previous_patterns(db, previous, subject_ty);
309322

310-
if predicate.guard(db).is_none() {
311-
builder = builder.add(pattern_kind_to_type(db, predicate.kind(db)));
312-
}
323+
if previous.guard(db).is_some() {
324+
narrowed_by_previous_patterns
325+
} else {
326+
type_narrowed_by_pattern(db, previous, narrowed_by_previous_patterns)
313327
}
314-
builder.build()
328+
}
329+
330+
/// Narrow `subject_ty` by a match pattern.
331+
///
332+
/// This result is also the preceding-pattern prefix for the next unguarded case.
333+
#[salsa::tracked(
334+
cycle_initial = |_, id, _, _| Type::divergent(id),
335+
cycle_fn = |db, cycle, previous: &Type<'db>, result: Type<'db>, _, _| {
336+
result.cycle_normalized(db, *previous, cycle)
337+
},
338+
heap_size = ruff_memory_usage::heap_size
339+
)]
340+
fn type_narrowed_by_pattern<'db>(
341+
db: &'db dyn Db,
342+
predicate: PatternPredicate<'db>,
343+
subject_ty: Type<'db>,
344+
) -> Type<'db> {
345+
IntersectionBuilder::new(db)
346+
.add_positive(subject_ty)
347+
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
348+
.build()
315349
}
316350

317351
/// Return the enum class and canonical member names represented by an enum-literal subject type.
@@ -506,11 +540,7 @@ fn analyze_pattern_predicate<'db>(db: &'db dyn Db, predicate: PatternPredicate<'
506540
return truthiness;
507541
}
508542

509-
let narrowed_subject = IntersectionBuilder::new(db)
510-
.add_positive(subject_ty)
511-
.add_negative(type_excluded_by_previous_patterns(db, predicate));
512-
513-
let narrowed_subject_ty = narrowed_subject.clone().build();
543+
let narrowed_subject_ty = type_narrowed_by_previous_patterns(db, predicate, subject_ty);
514544

515545
// Consider a case where we match on a subject type of `Self` with an upper bound of `Answer`,
516546
// where `Answer` is a {YES, NO} enum. After a previous pattern matching on `NO`, the narrowed
@@ -521,9 +551,7 @@ fn analyze_pattern_predicate<'db>(db: &'db dyn Db, predicate: PatternPredicate<'
521551
// means that subsequent patterns can never match. And we know that if we reach this point,
522552
// the current pattern will have to match. We return `AlwaysTrue` here, since the call to
523553
// `analyze_single_pattern_predicate_kind` below would return `Ambiguous` in this case.
524-
let next_narrowed_subject_ty = narrowed_subject
525-
.add_negative(pattern_kind_to_type(db, predicate.kind(db)))
526-
.build();
554+
let next_narrowed_subject_ty = type_narrowed_by_pattern(db, predicate, narrowed_subject_ty);
527555
if !narrowed_subject_ty.is_never() && next_narrowed_subject_ty.is_never() {
528556
return Truthiness::AlwaysTrue;
529557
}

0 commit comments

Comments
 (0)