Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4491,6 +4491,26 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
# Non-tuple iterable.
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]

def analyze_iterable_item_type_without_expression(
self, type: Type, context: Context
) -> tuple[Type, Type]:
"""Analyse iterable type and return iterator and iterator item types."""
echk = self.expr_checker
iterable = get_proper_type(type)
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]

if isinstance(iterable, TupleType):
joined: Type = UninhabitedType()
for item in iterable.items:
joined = join_types(joined, item)
return iterator, joined
else:
# Non-tuple iterable.
return (
iterator,
echk.check_method_call_by_name("__next__", iterator, [], [], context)[0],
)

def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
"""Try to infer native int item type from arguments to range(...).

Expand Down
140 changes: 89 additions & 51 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2887,75 +2887,115 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
"""
result: Type | None = None
sub_result: Type | None = None
sub_result: Type

# Check each consecutive operand pair and their operator
for left, right, operator in zip(e.operands, e.operands[1:], e.operators):
left_type = self.accept(left)

method_type: mypy.types.Type | None = None

if operator == "in" or operator == "not in":
"""
This case covers both iterables and containers, which have different meanings.
For a container, the in operator calls the __contains__ method.
For an iterable, the in operator iterates over the iterable, and compares each item one-by-one.
We allow `in` for a union of containers and iterables as long as at least one of them matches the
type of the left operand, as the operation will simply return False if the union's container/iterator
type doesn't match the left operand.
"""
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nit: We use # comments also for block comments.


# If the right operand has partial type, look it up without triggering
# a "Need type annotation ..." message, as it would be noise.
right_type = self.find_partial_type_ref_fast_path(right)
if right_type is None:
right_type = self.accept(right) # Validate the right operand

# Keep track of whether we get type check errors (these won't be reported, they
# are just to verify whether something is valid typing wise).
with self.msg.filter_errors(save_filtered_errors=True) as local_errors:
_, method_type = self.check_method_call_by_name(
method="__contains__",
base_type=right_type,
args=[left],
arg_kinds=[ARG_POS],
context=e,
)
right_type = get_proper_type(right_type)
item_types: Sequence[Type] = [right_type]
if isinstance(right_type, UnionType):
item_types = list(right_type.items)

sub_result = self.bool_type()
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(right_type)
if isinstance(right_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
pass
elif (
local_errors.has_new_errors()
and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(right_type)
):
_, itertype = self.chk.analyze_iterable_item_type(right)
method_type = CallableType(
[left_type],
[nodes.ARG_POS],
[None],
self.bool_type(),
self.named_type("builtins.function"),
)
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types("in", left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (
not local_errors.has_new_errors()
and cont_type
and self.dangerous_comparison(
left_type, cont_type, original_container=right_type
)
):
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
else:
self.msg.add_errors(local_errors.filtered_errors())

container_types: list[Type] = []
iterable_types: list[Type] = []
failed_out = False
encountered_partial_type = False

for item_type in item_types:
# Keep track of whether we get type check errors (these won't be reported, they
# are just to verify whether something is valid typing wise).
with self.msg.filter_errors(save_filtered_errors=True) as container_errors:
_, method_type = self.check_method_call_by_name(
method="__contains__",
base_type=item_type,
args=[left],
arg_kinds=[ARG_POS],
context=e,
original_type=right_type,
)
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(item_type)

if isinstance(item_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
encountered_partial_type = True
pass
elif (
container_errors.has_new_errors()
and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(item_type)
):
# it's not a container, but it is an iterable
with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors:
_, itertype = self.chk.analyze_iterable_item_type_without_expression(
item_type, e
)
if iterable_errors.has_new_errors():
self.msg.add_errors(iterable_errors.filtered_errors())
failed_out = True
else:
method_type = CallableType(
[left_type],
[nodes.ARG_POS],
[None],
self.bool_type(),
self.named_type("builtins.function"),
)
e.method_types.append(method_type)
iterable_types.append(itertype)
elif not container_errors.has_new_errors() and cont_type:
container_types.append(cont_type)
e.method_types.append(method_type)
else:
self.msg.add_errors(container_errors.filtered_errors())
failed_out = True

if not encountered_partial_type and not failed_out:
iterable_type = UnionType.make_union(iterable_types)
if not is_subtype(left_type, iterable_type):
if len(container_types) == 0:
self.msg.unsupported_operand_types("in", left_type, right_type, e)
else:
container_type = UnionType.make_union(container_types)
if self.dangerous_comparison(
left_type, container_type, original_container=right_type
):
self.msg.dangerous_comparison(
left_type, container_type, "container", e
)

elif operator in operators.op_methods:
method = operators.op_methods[operator]

with ErrorWatcher(self.msg.errors) as w:
sub_result, method_type = self.check_op(
method, left_type, right, e, allow_reverse=True
)
e.method_types.append(method_type)

# Only show dangerous overlap if there are no other errors. See
# testCustomEqCheckStrictEquality for an example.
Expand Down Expand Up @@ -2983,12 +3023,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
right_type = try_getting_literal(right_type)
if self.dangerous_comparison(left_type, right_type):
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
method_type = None
e.method_types.append(None)
else:
raise RuntimeError(f"Unknown comparison operator {operator}")

e.method_types.append(method_type)

# Determine type of boolean-and of result and sub_result
if result is None:
result = sub_result
Expand Down
17 changes: 17 additions & 0 deletions test-data/unit/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1183,3 +1183,20 @@ def foo(
yield i
foo([1])
[builtins fixtures/list.pyi]

[case testUnionIterableContainer]
from typing import Iterable, Container, Union

i: Iterable[str]
c: Container[str]
u: Union[Iterable[str], Container[str]]
ni: Union[Iterable[str], int]
nc: Union[Container[str], int]

'x' in i
'x' in c
'x' in u
'x' in ni # E: Unsupported right operand type for in ("Union[Iterable[str], int]")
'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]")
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]