@@ -6770,7 +6770,7 @@ def narrow_type_by_identity_equality(
67706770 continue
67716771 expr = operands [j ]
67726772
6773- current_type_range = self .get_isinstance_type (expr )
6773+ current_type_range = self .get_isinstance_type (expr , flatten_tuples = False )
67746774 if current_type_range is not None :
67756775 target_type = get_proper_type (
67766776 make_simplified_union ([tr .item for tr in current_type_range ])
@@ -7865,22 +7865,27 @@ def is_writable_attribute(self, node: Node) -> bool:
78657865 return first_item .var .is_settable_property
78667866 return False
78677867
7868- def get_isinstance_type (self , expr : Expression ) -> list [TypeRange ] | None :
7868+ def get_isinstance_type (self , expr : Expression , flatten_tuples : bool = True ) -> list [TypeRange ] | None :
78697869 """Get the type(s) resulting from an isinstance check.
78707870
78717871 Returns an empty list for isinstance(x, ()).
78727872 """
78737873 if isinstance (expr , OpExpr ) and expr .op == "|" :
7874- left = self .get_isinstance_type (expr .left )
7874+ left = self .get_isinstance_type (expr .left , flatten_tuples = False )
78757875 if left is None and is_literal_none (expr .left ):
78767876 left = [TypeRange (NoneType (), is_upper_bound = False )]
7877- right = self .get_isinstance_type (expr .right )
7877+ right = self .get_isinstance_type (expr .right , flatten_tuples = False )
78787878 if right is None and is_literal_none (expr .right ):
78797879 right = [TypeRange (NoneType (), is_upper_bound = False )]
78807880 if left is None or right is None :
78817881 return None
78827882 return left + right
7883- all_types = get_proper_types (flatten_types (self .lookup_type (expr )))
7883+
7884+ if flatten_tuples :
7885+ all_types = get_proper_types (flatten_types_if_tuple (self .lookup_type (expr )))
7886+ else :
7887+ all_types = [get_proper_type (self .lookup_type (expr ))]
7888+
78847889 types : list [TypeRange ] = []
78857890 for typ in all_types :
78867891 if isinstance (typ , FunctionLike ) and typ .is_type_obj ():
@@ -8605,11 +8610,11 @@ def flatten(t: Expression) -> list[Expression]:
86058610 return [t ]
86068611
86078612
8608- def flatten_types (t : Type ) -> list [Type ]:
8613+ def flatten_types_if_tuple (t : Type ) -> list [Type ]:
86098614 """Flatten a nested sequence of tuples into one list of nodes."""
86108615 t = get_proper_type (t )
86118616 if isinstance (t , TupleType ):
8612- return [b for a in t .items for b in flatten_types (a )]
8617+ return [b for a in t .items for b in flatten_types_if_tuple (a )]
86138618 elif is_named_instance (t , "builtins.tuple" ):
86148619 return [t .args [0 ]]
86158620 else :
0 commit comments