Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5284,10 +5284,19 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
# TODO: Follow keyword args or *args, **kwargs
# TODO: Follow *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
# the first argument might be used as a kwarg
called_type = get_proper_type(self.lookup_type(node.callee))
assert isinstance(called_type, CallableType)
Comment thread
A5rocks marked this conversation as resolved.
Outdated
name = called_type.arg_names[0]
if name in node.arg_names:
idx = node.arg_names.index(name)
# we want the idx-th variable to be narrowed
expr = collapse_walrus(node.args[idx])
Comment thread
A5rocks marked this conversation as resolved.
else:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
return {}, {}
if literal(expr) == LITERAL_TYPE:
# Note: we wrap the target type, so that we can special case later.
# Namely, for isinstance() we use a normal meet, while TypeGuard is
Expand Down
14 changes: 14 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,20 @@ def analyze_func_def(self, defn: FuncDef) -> None:
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
# type guards need to have a positional argument, to spec
Comment thread
A5rocks marked this conversation as resolved.
if (
result.type_guard
and ARG_POS not in result.arg_kinds[self.is_class_scope() :]
and not defn.is_static
):
self.fail(
"TypeGuard functions must have a positional argument",
result,
code=codes.VALID_TYPE,
)
# in this case, we just kind of just ... remove the type guard.
result = result.copy_modified(type_guard=None)

result = self.remove_unpack_kwargs(defn, result)
if has_self_type and self.type is not None:
info = self.type
Expand Down
57 changes: 50 additions & 7 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ reveal_type(foo) # N: Revealed type is "def (a: builtins.object) -> TypeGuard[b
[case testTypeGuardCallArgsNone]
from typing_extensions import TypeGuard
class Point: pass
# TODO: error on the 'def' line (insufficient args for type guard)
def is_point() -> TypeGuard[Point]: pass

def is_point() -> TypeGuard[Point]: pass # E: TypeGuard functions must have a positional argument
def main(a: object) -> None:
if is_point():
reveal_type(a) # N: Revealed type is "builtins.object"
Expand Down Expand Up @@ -227,13 +227,13 @@ def main(a: object) -> None:
from typing_extensions import TypeGuard
def is_float(a: object, b: object = 0) -> TypeGuard[float]: pass
def main1(a: object) -> None:
# This is debatable -- should we support these cases?
if is_float(a=a, b=1):
reveal_type(a) # N: Revealed type is "builtins.float"

if is_float(a=a, b=1): # E: Type guard requires positional argument
reveal_type(a) # N: Revealed type is "builtins.object"
if is_float(b=1, a=a):
reveal_type(a) # N: Revealed type is "builtins.float"

if is_float(b=1, a=a): # E: Type guard requires positional argument
reveal_type(a) # N: Revealed type is "builtins.object"
# This is debatable -- should we support these cases?

ta = (a,)
if is_float(*ta): # E: Type guard requires positional argument
Expand Down Expand Up @@ -597,3 +597,46 @@ def func(names: Tuple[str, ...]):
if is_two_element_tuple(names):
reveal_type(names) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeGuardErroneousDefinitionFails]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, *, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...

def bad_typeguard(*, x: object) -> TypeGuard[int]: # E: TypeGuard functions must have a positional argument
...

# make sure not to break other things

class Y:
@staticmethod
def typeguard(h: object) -> TypeGuard[int]:
...

x: object
if Y().typeguard(x):
reveal_type(x) # N: Revealed type is "builtins.int"
if Y.typeguard(x):
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
[builtins fixtures/classmethod.pyi]

[case testTypeGuardWithKeywordArg]
from typing_extensions import TypeGuard

class Z:
def typeguard(self, x: object) -> TypeGuard[int]:
...

def typeguard(x: object) -> TypeGuard[int]:
...

n: object
if typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"

if Z().typeguard(x=n):
reveal_type(n) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]