Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions rewrite-python/rewrite/src/rewrite/python/template/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def _compare(
return self._compare_assignment(pattern, cast(j.Assignment, target), cursor)
elif isinstance(pattern, j.Parentheses):
return self._compare_parentheses(pattern, cast(j.Parentheses, target), cursor)
elif isinstance(pattern, j.Ternary):
return self._compare_ternary(pattern, cast(j.Ternary, target), cursor)
elif isinstance(pattern, j.Return):
return self._compare_return(pattern, cast(j.Return, target), cursor)
elif isinstance(pattern, py.ExpressionStatement):
Expand Down Expand Up @@ -344,6 +346,21 @@ def _compare_assignment(

return self._compare(pattern.assignment, target.assignment, cursor)

def _compare_ternary(
self,
pattern: j.Ternary,
target: j.Ternary,
cursor: 'Cursor'
) -> bool:
"""Compare two ternary (conditional) expressions."""
if not self._compare(pattern.condition, target.condition, cursor):
return False

if not self._compare(pattern.true_part, target.true_part, cursor):
return False

return self._compare(pattern.false_part, target.false_part, cursor)

def _compare_parentheses(
self,
pattern: j.Parentheses,
Expand Down
64 changes: 64 additions & 0 deletions rewrite-python/rewrite/tests/python/template/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,70 @@ def test_dict_element_count_mismatch(self):
assert result is None


class TestTernaryMatching:
"""Tests for ternary (conditional) expression comparison."""

def setup_method(self):
TemplateEngine.clear_cache()

def teardown_method(self):
TemplateEngine.clear_cache()

def test_placeholder_ternary_captures(self):
"""{a} if {cond} else {b} should capture all three parts."""
captures = {'a': capture('a'), 'cond': capture('cond'), 'b': capture('b')}
pattern_tree = TemplateEngine.get_template_tree("{a} if {cond} else {b}", captures)
target_tree = TemplateEngine.get_template_tree("x if flag else y", {})
cursor = _make_cursor(target_tree)

comparator = PatternMatchingComparator(captures)
result = comparator.match(pattern_tree, target_tree, cursor)
assert result is not None
assert 'a' in result
assert 'cond' in result
assert 'b' in result

def test_concrete_ternary_match(self):
"""x if True else y should match x if True else y."""
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
target_tree = TemplateEngine.get_template_tree("x if True else y", {})
cursor = _make_cursor(target_tree)

comparator = PatternMatchingComparator({})
result = comparator.match(pattern_tree, target_tree, cursor)
assert result is not None

def test_ternary_condition_mismatch(self):
"""x if True else y should not match x if False else y."""
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
target_tree = TemplateEngine.get_template_tree("x if False else y", {})
cursor = _make_cursor(target_tree)

comparator = PatternMatchingComparator({})
result = comparator.match(pattern_tree, target_tree, cursor)
assert result is None

def test_ternary_true_part_mismatch(self):
"""x if True else y should not match z if True else y."""
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
target_tree = TemplateEngine.get_template_tree("z if True else y", {})
cursor = _make_cursor(target_tree)

comparator = PatternMatchingComparator({})
result = comparator.match(pattern_tree, target_tree, cursor)
assert result is None

def test_ternary_false_part_mismatch(self):
"""x if True else y should not match x if True else z."""
pattern_tree = TemplateEngine.get_template_tree("x if True else y", {})
target_tree = TemplateEngine.get_template_tree("x if True else z", {})
cursor = _make_cursor(target_tree)

comparator = PatternMatchingComparator({})
result = comparator.match(pattern_tree, target_tree, cursor)
assert result is None


class TestDefaultFallthrough:
"""Tests for the default comparison behavior on unrecognized types."""

Expand Down