Skip to content

Commit 55f76d8

Browse files
authored
Implement equality constraint (#2927)
1 parent 6054748 commit 55f76d8

3 files changed

Lines changed: 220 additions & 17 deletions

File tree

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ What's New in astroid 4.1.0?
77
============================
88
Release date: TBA
99

10+
* Add support for equality constraints (``==``, ``!=``) in inference.
11+
Closes pylint-dev/pylint#3632
12+
Closes pylint-dev/pylint#3633
13+
1014
* Ensure ``ast.JoinedStr`` nodes are ``Uninferable`` when the ``ast.FormattedValue`` is
1115
``Uninferable``. This prevents ``unexpected-keyword-arg`` messages in Pylint
1216
where the ``Uninferable`` string appeared in function arguments that were

astroid/constraint.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,62 @@ def satisfied_by(self, inferred: InferenceResult) -> bool:
174174
return True
175175

176176

177+
class EqualityConstraint(Constraint):
178+
"""Represents a "==" or "!=" constraint."""
179+
180+
def __init__(self, node: nodes.NodeNG, operand: nodes.NodeNG, negate: bool) -> None:
181+
super().__init__(node=node, negate=negate)
182+
self.operand = operand
183+
184+
@classmethod
185+
def match(
186+
cls, node: _NameNodes, expr: nodes.NodeNG, negate: bool = False
187+
) -> Self | None:
188+
"""Return a new constraint for node if expr matches one of these patterns:
189+
190+
- "node == operand" or "operand == node": use given negate value
191+
- "node != operand" or "operand != node": flip negate value
192+
193+
Return None if no pattern matches.
194+
"""
195+
if isinstance(expr, nodes.Compare) and len(expr.ops) == 1:
196+
left = expr.left
197+
op, right = expr.ops[0]
198+
matches_left = _matches(left, node)
199+
200+
if op in {"==", "!="} and (matches_left or _matches(right, node)):
201+
operand = right if matches_left else left
202+
negate = (op == "==" and negate) or (op == "!=" and not negate)
203+
return cls(node=node, operand=operand, negate=negate)
204+
205+
return None
206+
207+
def satisfied_by(self, inferred: InferenceResult) -> bool:
208+
"""Return True for uninferable/ambiguous results, or depending on negate flag:
209+
210+
- negate=False: satisfied when both operands are equal.
211+
- negate=True: satisfied when both operands are not equal.
212+
213+
Only comparisons between constants and callables are supported.
214+
"""
215+
if inferred is util.Uninferable:
216+
return True
217+
218+
operand_inferred = util.safe_infer(self.operand)
219+
if operand_inferred is util.Uninferable or operand_inferred is None:
220+
return True
221+
222+
if isinstance(inferred, nodes.Const) and isinstance(
223+
operand_inferred, nodes.Const
224+
):
225+
return self.negate ^ (inferred.value == operand_inferred.value)
226+
227+
if inferred.callable() and operand_inferred.callable():
228+
return self.negate ^ (inferred is operand_inferred)
229+
230+
return True
231+
232+
177233
def get_constraints(
178234
expr: _NameNodes, frame: nodes.LocalsDictNodeNG
179235
) -> dict[nodes.If | nodes.IfExp, set[Constraint]]:
@@ -209,6 +265,7 @@ def get_constraints(
209265
NoneConstraint,
210266
BooleanConstraint,
211267
TypeConstraint,
268+
EqualityConstraint,
212269
)
213270
)
214271
"""All supported constraint types."""

tests/test_constraint.py

Lines changed: 159 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
from astroid.util import Uninferable
1616

1717

18+
def node_info(node: nodes.NodeNG) -> str:
19+
return f"Inference of {node.as_string()!r} at line {node.lineno}"
20+
21+
1822
def common_params(node: str) -> pytest.MarkDecorator:
1923
return pytest.mark.parametrize(
2024
("condition", "satisfy_val", "fail_val"),
@@ -25,6 +29,10 @@ def common_params(node: str) -> pytest.MarkDecorator:
2529
(f"not {node}", None, 3),
2630
(f"isinstance({node}, int)", 3, None),
2731
(f"isinstance({node}, (int, str))", 3, None),
32+
(f"{node} == 3", 3, None),
33+
(f"{node} != 3", None, 3),
34+
(f"3 == {node}", 3, None),
35+
(f"3 != {node}", None, 3),
2836
),
2937
)
3038

@@ -267,17 +275,19 @@ def f2(y, x = {satisfy_val}):
267275
)
268276
""")
269277
for node in (node1, node2):
278+
msg = node_info(node)
270279
inferred = node.inferred()
271-
assert len(inferred) == 2
272-
assert isinstance(inferred[0], nodes.Const)
273-
assert inferred[0].value == fail_val
280+
assert len(inferred) == 2, msg
281+
assert isinstance(inferred[0], nodes.Const), msg
282+
assert inferred[0].value == fail_val, msg
274283

275-
assert inferred[1] is Uninferable
284+
assert inferred[1] is Uninferable, msg
276285

277286
for node in (node3, node4):
287+
msg = node_info(node)
278288
inferred = node.inferred()
279-
assert len(inferred) == 1
280-
assert inferred[0] is Uninferable
289+
assert len(inferred) == 1, msg
290+
assert inferred[0] is Uninferable, msg
281291

282292

283293
@common_params(node="x")
@@ -839,11 +849,12 @@ class C(A, B):
839849
""")
840850

841851
for node in (n1, n2, n3):
852+
msg = node_info(node)
842853
inferred = node.inferred()
843-
assert len(inferred) == 1
844-
assert isinstance(inferred[0], Instance)
845-
assert isinstance(inferred[0]._proxied, nodes.ClassDef)
846-
assert inferred[0].name == "C"
854+
assert len(inferred) == 1, msg
855+
assert isinstance(inferred[0], Instance), msg
856+
assert isinstance(inferred[0]._proxied, nodes.ClassDef), msg
857+
assert inferred[0].name == "C", msg
847858

848859

849860
def test_isinstance_diamond_inheritance():
@@ -879,11 +890,12 @@ class D(B, C):
879890
""")
880891

881892
for node in (n1, n2, n3, n4):
893+
msg = node_info(node)
882894
inferred = node.inferred()
883-
assert len(inferred) == 1
884-
assert isinstance(inferred[0], Instance)
885-
assert isinstance(inferred[0]._proxied, nodes.ClassDef)
886-
assert inferred[0].name == "D"
895+
assert len(inferred) == 1, msg
896+
assert isinstance(inferred[0], Instance), msg
897+
assert isinstance(inferred[0]._proxied, nodes.ClassDef), msg
898+
assert inferred[0].name == "D", msg
887899

888900

889901
def test_isinstance_keyword_arguments():
@@ -901,10 +913,11 @@ def test_isinstance_keyword_arguments():
901913
""")
902914

903915
for node in (n1, n2):
916+
msg = node_info(node)
904917
inferred = node.inferred()
905-
assert len(inferred) == 1
906-
assert isinstance(inferred[0], nodes.Const)
907-
assert inferred[0].value == 3
918+
assert len(inferred) == 1, msg
919+
assert isinstance(inferred[0], nodes.Const), msg
920+
assert inferred[0].value == 3, msg
908921

909922

910923
def test_isinstance_extra_argument():
@@ -999,3 +1012,132 @@ def test_isinstance_uninferable():
9991012
assert len(inferred) == 1
10001013
assert isinstance(inferred[0], nodes.Const)
10011014
assert inferred[0].value == 3
1015+
1016+
1017+
def test_equality_callable():
1018+
"""Test constraint for equality of callables."""
1019+
node1, node2, node3, node4, node5, node6 = builder.extract_node("""
1020+
class Foo:
1021+
pass
1022+
1023+
def bar():
1024+
pass
1025+
1026+
baz = lambda i : i
1027+
1028+
x, y, z = Foo, bar, baz
1029+
1030+
if x == Foo:
1031+
x #@
1032+
if x != Foo:
1033+
x #@
1034+
1035+
if y == bar:
1036+
y #@
1037+
if y != bar:
1038+
y #@
1039+
1040+
if z == baz:
1041+
z #@
1042+
if z != baz:
1043+
z #@
1044+
""")
1045+
1046+
inferred = node1.inferred()
1047+
assert len(inferred) == 1
1048+
assert isinstance(inferred[0], nodes.ClassDef)
1049+
assert inferred[0].name == "Foo"
1050+
1051+
inferred = node3.inferred()
1052+
assert len(inferred) == 1
1053+
assert isinstance(inferred[0], nodes.FunctionDef)
1054+
assert inferred[0].name == "bar"
1055+
1056+
inferred = node5.inferred()
1057+
assert len(inferred) == 1
1058+
assert isinstance(inferred[0], nodes.Lambda)
1059+
1060+
for node in (node2, node4, node6):
1061+
msg = node_info(node)
1062+
inferred = node.inferred()
1063+
assert len(inferred) == 1, msg
1064+
assert inferred[0] is Uninferable, msg
1065+
1066+
1067+
def test_equality_uninferable_operand():
1068+
"""Test that equality constraint is satisfied when either operand is uninferable."""
1069+
node1, node2, node3, node4 = builder.extract_node("""
1070+
def f1(x):
1071+
if x == 3:
1072+
x #@
1073+
1074+
if x != 3:
1075+
x #@
1076+
1077+
def f2(y):
1078+
x = 3
1079+
if x == y:
1080+
x #@
1081+
1082+
if x != y:
1083+
x #@
1084+
""")
1085+
1086+
for node in (node1, node2):
1087+
msg = node_info(node)
1088+
inferred = node.inferred()
1089+
assert len(inferred) == 1, msg
1090+
assert inferred[0] is Uninferable, msg
1091+
1092+
for node in (node3, node4):
1093+
msg = node_info(node)
1094+
inferred = node.inferred()
1095+
assert len(inferred) == 1, msg
1096+
assert isinstance(inferred[0], nodes.Const), msg
1097+
assert inferred[0].value == 3, msg
1098+
1099+
1100+
def test_equality_ambiguous_operand():
1101+
"""Test that equality constraint is satisfied when the compared operand has multiple inferred values."""
1102+
node1, node2 = builder.extract_node("""
1103+
def f(y = 1):
1104+
x = 3
1105+
if x == y:
1106+
x #@
1107+
1108+
if x != y:
1109+
x #@
1110+
""")
1111+
1112+
for node in (node1, node2):
1113+
msg = node_info(node)
1114+
inferred = node.inferred()
1115+
assert len(inferred) == 1, msg
1116+
assert isinstance(inferred[0], nodes.Const), msg
1117+
assert inferred[0].value == 3, msg
1118+
1119+
1120+
def test_equality_fractions():
1121+
"""Test that equality constraint is satisfied when both operands are fractions."""
1122+
node1, node2, node3, node4 = builder.extract_node("""
1123+
from fractions import Fraction
1124+
1125+
x = Fraction(1, 3)
1126+
y = Fraction(1, 3)
1127+
1128+
if x == y:
1129+
x #@
1130+
y #@
1131+
1132+
if x != y:
1133+
x #@
1134+
y #@
1135+
""")
1136+
1137+
for node in (node1, node2, node3, node4):
1138+
msg = node_info(node)
1139+
inferred = node.inferred()
1140+
assert len(inferred) == 1, msg
1141+
assert isinstance(inferred[0], Instance), msg
1142+
assert isinstance(inferred[0]._proxied, nodes.ClassDef), msg
1143+
assert inferred[0]._proxied.name == "Fraction", msg

0 commit comments

Comments
 (0)