Skip to content

Commit 3abfc78

Browse files
committed
Extract and strengthen signature-compatibility logic
Move continuation signature-compatibility checks into a dedicated mockito.sameish subsystem and make Mock delegate to it. The new logic compares signatures structurally (including matcher internals) instead of calling matcher .matches() across stub signatures. This avoids executing user arg_that predicates during bookkeeping and prevents side effects/crashes from predicate evaluation against matcher objects. Add focused unit coverage in tests/sameish_test.py. Add chaining regressions that assert branch sharing for equivalent any_(int) signatures and reused arg_that predicates. These were the original fatal regressions that motivated the effort.
1 parent 28bc30b commit 3abfc78

File tree

4 files changed

+297
-12
lines changed

4 files changed

+297
-12
lines changed

mockito/mocking.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from dataclasses import dataclass
2929
from typing import Any, AsyncIterator, Callable, Iterable, Iterator, cast
3030

31-
from . import invocation, signature, utils
31+
from . import invocation, sameish, signature, utils
3232
from . import verification as verificationModule
3333
from .mock_registry import mock_registry
3434
from .patching import Patch, patcher
@@ -407,12 +407,12 @@ def set_continuation(self, continuation: invocation.ConfiguredContinuation) -> N
407407
def _sameish_invocations(
408408
self, same: invocation.StubbedInvocation
409409
) -> list[invocation.StubbedInvocation]:
410-
"""Find prior stubs that are *mutually* signature-compatible.
410+
"""Find prior stubs that are signature-compatible.
411411
412412
This is used only for continuation bookkeeping (value-vs-chain mode),
413-
not for runtime call dispatch. We intentionally do a symmetric check
414-
(`a.matches(b)` and `b.matches(a)`) to approximate "same signature"
415-
despite one-way matchers like `any()`.
413+
not for runtime call dispatch. The comparison is structural and avoids
414+
executing matcher predicates, so `arg_that(...)` and other custom
415+
matchers cannot crash internal equivalence probing.
416416
417417
Why this exists: repeated selectors such as
418418
@@ -439,13 +439,7 @@ def _invocations_are_sameish(
439439
left: invocation.StubbedInvocation,
440440
right: invocation.StubbedInvocation,
441441
) -> bool:
442-
# Be conservative in internal equivalence probing: user predicates from
443-
# `arg_that` can throw when evaluated against matcher/sentinel objects.
444-
# In this phase, exceptions should mean "not equivalent", not failure.
445-
try:
446-
return left.matches(right) and right.matches(left)
447-
except Exception:
448-
return False
442+
return sameish.invocations_are_sameish(left, right)
449443

450444
def get_original_method(self, method_name: str) -> object | None:
451445
return self._original_methods.get(method_name, None)

mockito/sameish.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from . import matchers
6+
7+
if TYPE_CHECKING:
8+
from .invocation import StubbedInvocation
9+
10+
11+
def invocations_are_sameish(
12+
left: StubbedInvocation,
13+
right: StubbedInvocation,
14+
) -> bool:
15+
"""Structural signature-compatibility checks for continuation bookkeeping.
16+
17+
Intentionally avoids executing user-provided matcher predicates
18+
(e.g. `arg_that(...)) while comparing stub signatures.
19+
"""
20+
21+
return (
22+
_params_are_sameish(left.params, right.params)
23+
and _named_params_are_sameish(
24+
left.named_params,
25+
right.named_params,
26+
)
27+
)
28+
29+
30+
def _params_are_sameish(left: tuple, right: tuple) -> bool:
31+
if len(left) != len(right):
32+
return False
33+
34+
return all(
35+
_values_are_sameish(left_value, right_value)
36+
for left_value, right_value in zip(left, right)
37+
)
38+
39+
40+
def _named_params_are_sameish(left: dict, right: dict) -> bool:
41+
if set(left) != set(right):
42+
return False
43+
44+
return all(
45+
_values_are_sameish(left[key], right[key])
46+
for key in left
47+
)
48+
49+
50+
def _values_are_sameish(left: object, right: object) -> bool:
51+
if left is right:
52+
return True
53+
54+
if left is Ellipsis or right is Ellipsis:
55+
return left is right
56+
57+
if isinstance(left, matchers.Matcher) and isinstance(right, matchers.Matcher):
58+
return _matchers_are_sameish(left, right)
59+
60+
if isinstance(left, matchers.Matcher) or isinstance(right, matchers.Matcher):
61+
return False
62+
63+
return _equals_or_identity(left, right)
64+
65+
66+
def _matchers_are_sameish( # noqa: C901
67+
left: matchers.Matcher,
68+
right: matchers.Matcher,
69+
) -> bool:
70+
if left is right:
71+
return True
72+
73+
if type(left) is not type(right):
74+
return False
75+
76+
if isinstance(left, matchers.Any) and isinstance(right, matchers.Any):
77+
return _equals_or_identity(left.wanted_type, right.wanted_type)
78+
79+
if (
80+
isinstance(left, matchers.ValueMatcher)
81+
and isinstance(right, matchers.ValueMatcher)
82+
):
83+
return _values_are_sameish(left.value, right.value)
84+
85+
if (
86+
isinstance(left, (matchers.And, matchers.Or))
87+
and isinstance(right, (matchers.And, matchers.Or))
88+
):
89+
return _params_are_sameish(
90+
tuple(left.matchers),
91+
tuple(right.matchers),
92+
)
93+
94+
if isinstance(left, matchers.Not) and isinstance(right, matchers.Not):
95+
return _values_are_sameish(left.matcher, right.matcher)
96+
97+
if isinstance(left, matchers.ArgThat) and isinstance(right, matchers.ArgThat):
98+
return left.predicate is right.predicate
99+
100+
if isinstance(left, matchers.Contains) and isinstance(right, matchers.Contains):
101+
return _values_are_sameish(left.sub, right.sub)
102+
103+
if isinstance(left, matchers.Matches) and isinstance(right, matchers.Matches):
104+
return (
105+
left.regex.pattern == right.regex.pattern
106+
and left.flags == right.flags
107+
)
108+
109+
if (
110+
isinstance(left, matchers.ArgumentCaptor)
111+
and isinstance(right, matchers.ArgumentCaptor)
112+
):
113+
return _values_are_sameish(left.matcher, right.matcher)
114+
115+
if (
116+
isinstance(left, matchers.CallCaptor)
117+
and isinstance(right, matchers.CallCaptor)
118+
):
119+
return False
120+
121+
return _equals_or_identity(left, right)
122+
123+
124+
def _equals_or_identity(left: object, right: object) -> bool:
125+
try:
126+
return left == right
127+
except Exception:
128+
return left is right

tests/chaining_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,29 @@ def test_multiple_chain_branches_on_same_root_are_supported():
3737
assert cat_that_meowed.roll() == "playful"
3838

3939

40+
def test_multiple_chain_branches_with_equivalent_typed_any_matchers_share_root():
41+
cat = mock()
42+
43+
when(cat).meow(any_(int)).purr().thenReturn("friendly")
44+
when(cat).meow(any_(int)).roll().thenReturn("playful")
45+
46+
cat_that_meowed = cat.meow(1)
47+
assert cat_that_meowed.purr() == "friendly"
48+
assert cat_that_meowed.roll() == "playful"
49+
50+
51+
def test_multiple_chain_branches_with_same_arg_that_matcher_share_root():
52+
cat = mock()
53+
pred = arg_that(lambda value: value > 0)
54+
55+
when(cat).meow(pred).purr().thenReturn("friendly")
56+
when(cat).meow(pred).roll().thenReturn("playful")
57+
58+
cat_that_meowed = cat.meow(1)
59+
assert cat_that_meowed.purr() == "friendly"
60+
assert cat_that_meowed.roll() == "playful"
61+
62+
4063
def test_unstub_child_chain_then_reconfigure_does_not_leave_stale_root_stub():
4164
cat = mock()
4265

tests/sameish_test.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
from dataclasses import dataclass, field
2+
3+
from mockito import and_, any as any_, arg_that, call_captor, eq, gt, neq, or_
4+
from mockito import sameish
5+
6+
7+
@dataclass
8+
class FakeInvocation:
9+
params: tuple = ()
10+
named_params: dict = field(default_factory=dict)
11+
12+
13+
def bar(*params, **named_params):
14+
return FakeInvocation(params=params, named_params=named_params)
15+
16+
17+
def test_concrete_values_must_match_exactly():
18+
assert sameish.invocations_are_sameish(
19+
bar(1, "x"),
20+
bar(1, "x"),
21+
)
22+
assert not sameish.invocations_are_sameish(
23+
bar(1, "x"),
24+
bar(2, "x"),
25+
)
26+
27+
28+
def test_keyword_names_must_match_independent_of_order():
29+
assert sameish.invocations_are_sameish(
30+
bar(a=1, b=2),
31+
bar(b=2, a=1),
32+
)
33+
assert not sameish.invocations_are_sameish(
34+
bar(a=1),
35+
bar(a=1, b=2),
36+
)
37+
38+
39+
def test_any_matchers_are_compared_structurally():
40+
assert sameish.invocations_are_sameish(
41+
bar(any_(int)),
42+
bar(any_(int)),
43+
)
44+
assert not sameish.invocations_are_sameish(
45+
bar(any_(int)),
46+
bar(any_()),
47+
)
48+
assert not sameish.invocations_are_sameish(
49+
bar(any_()),
50+
bar(1),
51+
)
52+
53+
54+
def test_composite_matchers_are_compared_recursively():
55+
assert sameish.invocations_are_sameish(
56+
bar(and_(any_(int), gt(1))),
57+
bar(and_(any_(int), gt(1))),
58+
)
59+
assert not sameish.invocations_are_sameish(
60+
bar(and_(any_(int), gt(1))),
61+
bar(and_(any_(int), gt(2))),
62+
)
63+
64+
65+
def test_distinct_matcher_types_are_not_sameish_even_with_equal_payload():
66+
assert not sameish.invocations_are_sameish(
67+
bar(eq(1)),
68+
bar(neq(1)),
69+
)
70+
assert not sameish.invocations_are_sameish(
71+
bar(and_(any_(int), gt(1))),
72+
bar(or_(any_(int), gt(1))),
73+
)
74+
75+
76+
def test_arg_that_uses_predicate_identity_and_does_not_execute_predicate():
77+
calls = []
78+
79+
def predicate(value):
80+
calls.append(value)
81+
raise RuntimeError("must not be executed")
82+
83+
assert sameish.invocations_are_sameish(
84+
bar(arg_that(predicate)),
85+
bar(arg_that(predicate)),
86+
)
87+
assert calls == []
88+
89+
90+
def test_arg_that_with_different_predicates_is_not_sameish():
91+
assert not sameish.invocations_are_sameish(
92+
bar(arg_that(lambda value: value > 0)),
93+
bar(arg_that(lambda value: value > 0)),
94+
)
95+
96+
97+
def test_arg_that_predicate_side_effects_are_not_triggered():
98+
seen = []
99+
100+
def predicate(value):
101+
seen.append(value)
102+
return True
103+
104+
assert sameish.invocations_are_sameish(
105+
bar(arg_that(predicate)),
106+
bar(arg_that(predicate)),
107+
)
108+
assert seen == []
109+
110+
111+
def test_call_captor_instances_are_not_interchangeable():
112+
left = call_captor()
113+
right = call_captor()
114+
115+
assert sameish.invocations_are_sameish(
116+
bar(left),
117+
bar(left),
118+
)
119+
assert not sameish.invocations_are_sameish(
120+
bar(left),
121+
bar(right),
122+
)
123+
124+
125+
def test_eq_failures_fallback_to_identity():
126+
class EqBoom:
127+
def __eq__(self, other):
128+
raise RuntimeError("boom")
129+
130+
first = EqBoom()
131+
second = EqBoom()
132+
133+
assert sameish.invocations_are_sameish(
134+
bar(first),
135+
bar(first),
136+
)
137+
assert not sameish.invocations_are_sameish(
138+
bar(first),
139+
bar(second),
140+
)

0 commit comments

Comments
 (0)