Skip to content

Commit 8414b25

Browse files
committed
Python: Support ternary comparisons in PatternMatchingComparator
1 parent bbede7e commit 8414b25

2 files changed

Lines changed: 81 additions & 0 deletions

File tree

rewrite-python/rewrite/src/rewrite/python/template/comparator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def _compare(
158158
return self._compare_assignment(pattern, cast(j.Assignment, target), cursor)
159159
elif isinstance(pattern, j.Parentheses):
160160
return self._compare_parentheses(pattern, cast(j.Parentheses, target), cursor)
161+
elif isinstance(pattern, j.Ternary):
162+
return self._compare_ternary(pattern, cast(j.Ternary, target), cursor)
161163
elif isinstance(pattern, j.Return):
162164
return self._compare_return(pattern, cast(j.Return, target), cursor)
163165
elif isinstance(pattern, py.ExpressionStatement):
@@ -344,6 +346,21 @@ def _compare_assignment(
344346

345347
return self._compare(pattern.assignment, target.assignment, cursor)
346348

349+
def _compare_ternary(
350+
self,
351+
pattern: j.Ternary,
352+
target: j.Ternary,
353+
cursor: 'Cursor'
354+
) -> bool:
355+
"""Compare two ternary (conditional) expressions."""
356+
if not self._compare(pattern.condition, target.condition, cursor):
357+
return False
358+
359+
if not self._compare(pattern.true_part, target.true_part, cursor):
360+
return False
361+
362+
return self._compare(pattern.false_part, target.false_part, cursor)
363+
347364
def _compare_parentheses(
348365
self,
349366
pattern: j.Parentheses,

rewrite-python/rewrite/tests/python/template/test_comparator.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,70 @@ def test_dict_element_count_mismatch(self):
590590
assert result is None
591591

592592

593+
class TestTernaryMatching:
594+
"""Tests for ternary (conditional) expression comparison."""
595+
596+
def setup_method(self):
597+
TemplateEngine.clear_cache()
598+
599+
def teardown_method(self):
600+
TemplateEngine.clear_cache()
601+
602+
def test_placeholder_ternary_captures(self):
603+
"""{a} if {cond} else {b} should capture all three parts."""
604+
captures = {'a': capture('a'), 'cond': capture('cond'), 'b': capture('b')}
605+
pattern_tree = TemplateEngine.get_template_tree("{a} if {cond} else {b}", captures)
606+
target_tree = TemplateEngine.get_template_tree("x if flag else y", {})
607+
cursor = _make_cursor(target_tree)
608+
609+
comparator = PatternMatchingComparator(captures)
610+
result = comparator.match(pattern_tree, target_tree, cursor)
611+
assert result is not None
612+
assert 'a' in result
613+
assert 'cond' in result
614+
assert 'b' in result
615+
616+
def test_concrete_ternary_match(self):
617+
"""x if True else y should match x if True else y."""
618+
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
619+
target_tree = TemplateEngine.get_template_tree("x if True else y", {})
620+
cursor = _make_cursor(target_tree)
621+
622+
comparator = PatternMatchingComparator({})
623+
result = comparator.match(pattern_tree, target_tree, cursor)
624+
assert result is not None
625+
626+
def test_ternary_condition_mismatch(self):
627+
"""x if True else y should not match x if False else y."""
628+
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
629+
target_tree = TemplateEngine.get_template_tree("x if False else y", {})
630+
cursor = _make_cursor(target_tree)
631+
632+
comparator = PatternMatchingComparator({})
633+
result = comparator.match(pattern_tree, target_tree, cursor)
634+
assert result is None
635+
636+
def test_ternary_true_part_mismatch(self):
637+
"""x if True else y should not match z if True else y."""
638+
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
639+
target_tree = TemplateEngine.get_template_tree("z if True else y", {})
640+
cursor = _make_cursor(target_tree)
641+
642+
comparator = PatternMatchingComparator({})
643+
result = comparator.match(pattern_tree, target_tree, cursor)
644+
assert result is None
645+
646+
def test_ternary_false_part_mismatch(self):
647+
"""x if True else y should not match x if True else z."""
648+
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
649+
target_tree = TemplateEngine.get_template_tree("x if True else z", {})
650+
cursor = _make_cursor(target_tree)
651+
652+
comparator = PatternMatchingComparator({})
653+
result = comparator.match(pattern_tree, target_tree, cursor)
654+
assert result is None
655+
656+
593657
class TestDefaultFallthrough:
594658
"""Tests for the default comparison behavior on unrecognized types."""
595659

0 commit comments

Comments
 (0)