Skip to content

Commit 12cc8f8

Browse files
committed
Python: Fix cross-type literal matching in pattern comparator
Literal comparison (`_compare_literal`) previously only checked value equality, which could produce false positives when different literal types share the same internal representation (e.g., `None`, `...`, and unicode-escaped strings all store `value=None`). The fix adds two-level comparison: 1. Reject immediately when value types differ (NoneType vs bytes, etc.) 2. Fall back to `value_source` comparison when both values are `None` to distinguish `None` from `...` and unicode-escaped literals. This prevents patterns like `{x} == None` from matching byte string comparisons like `x == b""`.
1 parent 7736e1b commit 12cc8f8

2 files changed

Lines changed: 93 additions & 1 deletion

File tree

rewrite-python/rewrite/src/rewrite/python/template/comparator.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,19 @@ def _compare_identifier(self, pattern: j.Identifier, target: j.Identifier) -> bo
222222
return pattern.simple_name == target.simple_name
223223

224224
def _compare_literal(self, pattern: j.Literal, target: j.Literal) -> bool:
225-
"""Compare two literals."""
225+
"""Compare two literals.
226+
227+
Uses a two-level comparison:
228+
1. Value types must match (prevents cross-type false positives like
229+
None vs b"" where both might serialize to the same representation).
230+
2. When both values are None (Ellipsis, None keyword, or literals with
231+
unicode escapes all store value=None), fall back to value_source
232+
comparison to distinguish them.
233+
"""
234+
if type(pattern.value) != type(target.value):
235+
return False
236+
if pattern.value is None:
237+
return pattern.value_source == target.value_source
226238
return pattern.value == target.value
227239

228240
def _compare_method_invocation(

rewrite-python/rewrite/tests/python/template/test_comparator.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,86 @@ def test_mismatched_string_literals_no_match(self):
128128
assert result is None
129129

130130

131+
class TestCrossTypeLiteralMatching:
132+
"""Tests that literals of different Python types never match each other."""
133+
134+
def setup_method(self):
135+
TemplateEngine.clear_cache()
136+
137+
def teardown_method(self):
138+
TemplateEngine.clear_cache()
139+
140+
def test_none_does_not_match_bytes_literal(self):
141+
"""None should not match b''."""
142+
pattern_tree = TemplateEngine.get_template_tree("None", {})
143+
target_tree = TemplateEngine.get_template_tree('b""', {})
144+
cursor = _make_cursor(target_tree)
145+
146+
comparator = PatternMatchingComparator({})
147+
result = comparator.match(pattern_tree, target_tree, cursor)
148+
assert result is None
149+
150+
def test_none_does_not_match_nonempty_bytes(self):
151+
"""None should not match b'hello'."""
152+
pattern_tree = TemplateEngine.get_template_tree("None", {})
153+
target_tree = TemplateEngine.get_template_tree('b"hello"', {})
154+
cursor = _make_cursor(target_tree)
155+
156+
comparator = PatternMatchingComparator({})
157+
result = comparator.match(pattern_tree, target_tree, cursor)
158+
assert result is None
159+
160+
def test_none_does_not_match_empty_string(self):
161+
"""None should not match ''."""
162+
pattern_tree = TemplateEngine.get_template_tree("None", {})
163+
target_tree = TemplateEngine.get_template_tree('""', {})
164+
cursor = _make_cursor(target_tree)
165+
166+
comparator = PatternMatchingComparator({})
167+
result = comparator.match(pattern_tree, target_tree, cursor)
168+
assert result is None
169+
170+
def test_none_does_not_match_zero(self):
171+
"""None should not match 0."""
172+
pattern_tree = TemplateEngine.get_template_tree("None", {})
173+
target_tree = TemplateEngine.get_template_tree("0", {})
174+
cursor = _make_cursor(target_tree)
175+
176+
comparator = PatternMatchingComparator({})
177+
result = comparator.match(pattern_tree, target_tree, cursor)
178+
assert result is None
179+
180+
def test_none_matches_none(self):
181+
"""None should match None."""
182+
pattern_tree = TemplateEngine.get_template_tree("None", {})
183+
target_tree = TemplateEngine.get_template_tree("None", {})
184+
cursor = _make_cursor(target_tree)
185+
186+
comparator = PatternMatchingComparator({})
187+
result = comparator.match(pattern_tree, target_tree, cursor)
188+
assert result is not None
189+
190+
def test_none_does_not_match_ellipsis(self):
191+
"""None should not match ... (Ellipsis) — both have value=None internally."""
192+
pattern_tree = TemplateEngine.get_template_tree("None", {})
193+
target_tree = TemplateEngine.get_template_tree("...", {})
194+
cursor = _make_cursor(target_tree)
195+
196+
comparator = PatternMatchingComparator({})
197+
result = comparator.match(pattern_tree, target_tree, cursor)
198+
assert result is None
199+
200+
def test_bytes_does_not_match_string(self):
201+
"""b'hello' should not match 'hello'."""
202+
pattern_tree = TemplateEngine.get_template_tree('b"hello"', {})
203+
target_tree = TemplateEngine.get_template_tree('"hello"', {})
204+
cursor = _make_cursor(target_tree)
205+
206+
comparator = PatternMatchingComparator({})
207+
result = comparator.match(pattern_tree, target_tree, cursor)
208+
assert result is None
209+
210+
131211
class TestMethodInvocationMatching:
132212
"""Tests for method invocation comparison."""
133213

0 commit comments

Comments
 (0)