@@ -6781,7 +6781,7 @@ def narrow_type_by_identity_equality(
67816781 continue
67826782 expr = operands [j ]
67836783
6784- current_type_range = self .get_isinstance_type (expr )
6784+ current_type_range = self .get_isinstance_type (expr , flatten_tuples = False )
67856785 if current_type_range is not None :
67866786 target_type = get_proper_type (
67876787 make_simplified_union ([tr .item for tr in current_type_range ])
@@ -7876,22 +7876,27 @@ def is_writable_attribute(self, node: Node) -> bool:
78767876 return first_item .var .is_settable_property
78777877 return False
78787878
7879- def get_isinstance_type (self , expr : Expression ) -> list [TypeRange ] | None :
7879+ def get_isinstance_type (self , expr : Expression , flatten_tuples : bool = True ) -> list [TypeRange ] | None :
78807880 """Get the type(s) resulting from an isinstance check.
78817881
78827882 Returns an empty list for isinstance(x, ()).
78837883 """
78847884 if isinstance (expr , OpExpr ) and expr .op == "|" :
7885- left = self .get_isinstance_type (expr .left )
7885+ left = self .get_isinstance_type (expr .left , flatten_tuples = False )
78867886 if left is None and is_literal_none (expr .left ):
78877887 left = [TypeRange (NoneType (), is_upper_bound = False )]
7888- right = self .get_isinstance_type (expr .right )
7888+ right = self .get_isinstance_type (expr .right , flatten_tuples = False )
78897889 if right is None and is_literal_none (expr .right ):
78907890 right = [TypeRange (NoneType (), is_upper_bound = False )]
78917891 if left is None or right is None :
78927892 return None
78937893 return left + right
7894- all_types = get_proper_types (flatten_types (self .lookup_type (expr )))
7894+
7895+ if flatten_tuples :
7896+ all_types = get_proper_types (flatten_types_if_tuple (self .lookup_type (expr )))
7897+ else :
7898+ all_types = [get_proper_type (self .lookup_type (expr ))]
7899+
78957900 types : list [TypeRange ] = []
78967901 for typ in all_types :
78977902 if isinstance (typ , FunctionLike ) and typ .is_type_obj ():
@@ -8619,11 +8624,11 @@ def flatten(t: Expression) -> list[Expression]:
86198624 return [t ]
86208625
86218626
8622- def flatten_types (t : Type ) -> list [Type ]:
8627+ def flatten_types_if_tuple (t : Type ) -> list [Type ]:
86238628 """Flatten a nested sequence of tuples into one list of nodes."""
86248629 t = get_proper_type (t )
86258630 if isinstance (t , TupleType ):
8626- return [b for a in t .items for b in flatten_types (a )]
8631+ return [b for a in t .items for b in flatten_types_if_tuple (a )]
86278632 elif is_named_instance (t , "builtins.tuple" ):
86288633 return [t .args [0 ]]
86298634 else :
0 commit comments