@@ -6812,7 +6812,7 @@ def narrow_type_by_identity_equality(
68126812 continue
68136813 expr = operands [j ]
68146814
6815- current_type_range = self .get_isinstance_type (expr )
6815+ current_type_range = self .get_isinstance_type (expr , flatten_tuples = False )
68166816 if current_type_range is not None :
68176817 target_type = get_proper_type (
68186818 make_simplified_union ([tr .item for tr in current_type_range ])
@@ -7907,22 +7907,27 @@ def is_writable_attribute(self, node: Node) -> bool:
79077907 return first_item .var .is_settable_property
79087908 return False
79097909
7910- def get_isinstance_type (self , expr : Expression ) -> list [TypeRange ] | None :
7910+ def get_isinstance_type (self , expr : Expression , flatten_tuples : bool = True ) -> list [TypeRange ] | None :
79117911 """Get the type(s) resulting from an isinstance check.
79127912
79137913 Returns an empty list for isinstance(x, ()).
79147914 """
79157915 if isinstance (expr , OpExpr ) and expr .op == "|" :
7916- left = self .get_isinstance_type (expr .left )
7916+ left = self .get_isinstance_type (expr .left , flatten_tuples = False )
79177917 if left is None and is_literal_none (expr .left ):
79187918 left = [TypeRange (NoneType (), is_upper_bound = False )]
7919- right = self .get_isinstance_type (expr .right )
7919+ right = self .get_isinstance_type (expr .right , flatten_tuples = False )
79207920 if right is None and is_literal_none (expr .right ):
79217921 right = [TypeRange (NoneType (), is_upper_bound = False )]
79227922 if left is None or right is None :
79237923 return None
79247924 return left + right
7925- all_types = get_proper_types (flatten_types (self .lookup_type (expr )))
7925+
7926+ if flatten_tuples :
7927+ all_types = get_proper_types (flatten_types_if_tuple (self .lookup_type (expr )))
7928+ else :
7929+ all_types = [get_proper_type (self .lookup_type (expr ))]
7930+
79267931 types : list [TypeRange ] = []
79277932 for typ in all_types :
79287933 if isinstance (typ , FunctionLike ) and typ .is_type_obj ():
@@ -8654,11 +8659,11 @@ def flatten(t: Expression) -> list[Expression]:
86548659 return [t ]
86558660
86568661
8657- def flatten_types (t : Type ) -> list [Type ]:
8662+ def flatten_types_if_tuple (t : Type ) -> list [Type ]:
86588663 """Flatten a nested sequence of tuples into one list of nodes."""
86598664 t = get_proper_type (t )
86608665 if isinstance (t , TupleType ):
8661- return [b for a in t .items for b in flatten_types (a )]
8666+ return [b for a in t .items for b in flatten_types_if_tuple (a )]
86628667 elif is_named_instance (t , "builtins.tuple" ):
86638668 return [t .args [0 ]]
86648669 else :
0 commit comments