Skip to content

Commit 6f3b72c

Browse files
authored
Fixed bug in type narrowing logic for sequence pattern matching. This addresses #11056. (#11085)
1 parent 89825e1 commit 6f3b72c

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

packages/pyright-internal/src/analyzer/patternMatching.ts

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1414,23 +1414,28 @@ function getSequencePatternInfo(
14141414
{ type: UnknownType.create(), isUnbounded: true },
14151415
];
14161416

1417-
const tupleIndeterminateIndex = typeArgs.findIndex(
1417+
let tupleIndeterminateIndex = typeArgs.findIndex(
14181418
(t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type) || isUnpackedTypeVar(t.type)
14191419
);
14201420

14211421
let tupleDeterminateEntryCount = typeArgs.length;
14221422

14231423
// If the tuple contains an indeterminate entry, expand or remove that
14241424
// entry to match the length of the pattern if possible.
1425+
let expandedIndeterminate = false;
14251426
if (tupleIndeterminateIndex >= 0) {
14261427
tupleDeterminateEntryCount--;
14271428

14281429
while (typeArgs.length < patternEntryCount) {
14291430
typeArgs.splice(tupleIndeterminateIndex, 0, typeArgs[tupleIndeterminateIndex]);
1431+
tupleDeterminateEntryCount++;
1432+
tupleIndeterminateIndex++;
1433+
expandedIndeterminate = true;
14301434
}
14311435

14321436
if (typeArgs.length > patternEntryCount && patternStarEntryIndex === undefined) {
14331437
typeArgs.splice(tupleIndeterminateIndex, 1);
1438+
tupleIndeterminateIndex = -1;
14341439
}
14351440
}
14361441

@@ -1449,6 +1454,20 @@ function getSequencePatternInfo(
14491454
(t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type) || isUnpackedTypeVar(t.type)
14501455
),
14511456
});
1457+
1458+
tupleDeterminateEntryCount -= entriesToCombine;
1459+
if (!typeArgs[patternStarEntryIndex].isUnbounded) {
1460+
tupleDeterminateEntryCount++;
1461+
}
1462+
1463+
// If the collapsed range included the tupleIndeterminateIndex, adjust
1464+
// it to reflect the new collapsed entry.
1465+
if (
1466+
tupleIndeterminateIndex >= patternStarEntryIndex &&
1467+
tupleIndeterminateIndex < patternStarEntryIndex + entriesToCombine
1468+
) {
1469+
tupleIndeterminateIndex = patternStarEntryIndex;
1470+
}
14521471
}
14531472

14541473
if (typeArgs.length === patternEntryCount) {
@@ -1459,6 +1478,7 @@ function getSequencePatternInfo(
14591478
// indeterminate-length entry that aligns to the star entry, we can
14601479
// assume it will always match.
14611480
if (
1481+
!expandedIndeterminate &&
14621482
patternStarEntryIndex !== undefined &&
14631483
tupleIndeterminateIndex >= 0 &&
14641484
pattern.d.entries.length - 1 === tupleDeterminateEntryCount &&

packages/pyright-internal/src/tests/samples/matchSequence1.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,38 @@ def test_unbounded_tuple_6(subj: tuple[str, ...]):
627627
reveal_type(r, expected_text="tuple[str, ...]")
628628

629629

630+
def test_unbound_tuple_7(subj: tuple[str, Unpack[tuple[object, ...]], int]):
631+
match subj:
632+
case (*args,):
633+
reveal_type(args, expected_text="list[str | object | int]")
634+
case a:
635+
reveal_type(a, expected_text="Never")
636+
637+
match subj:
638+
case (*args, last):
639+
reveal_type(args, expected_text="list[str | object]")
640+
reveal_type(last, expected_text="int")
641+
case a:
642+
reveal_type(a, expected_text="Never")
643+
644+
match subj:
645+
case (first, *args, last):
646+
reveal_type(first, expected_text="str")
647+
reveal_type(args, expected_text="list[object]")
648+
reveal_type(last, expected_text="int")
649+
case a:
650+
reveal_type(a, expected_text="Never")
651+
652+
match subj:
653+
case (first, second, *args, last):
654+
reveal_type(first, expected_text="str")
655+
reveal_type(second, expected_text="object")
656+
reveal_type(args, expected_text="list[object]")
657+
reveal_type(last, expected_text="int")
658+
case a:
659+
reveal_type(a, expected_text="tuple[str, *tuple[object, ...], int]")
660+
661+
630662
def test_variadic_tuple(subj: tuple[int, Unpack[Ts]]) -> tuple[Unpack[Ts]]:
631663
match subj:
632664
case _, *rest:

0 commit comments

Comments
 (0)