diff --git a/tests/parser/functions/test_interfaces.py b/tests/parser/functions/test_interfaces.py index e209582e87..ba270076ff 100644 --- a/tests/parser/functions/test_interfaces.py +++ b/tests/parser/functions/test_interfaces.py @@ -145,6 +145,57 @@ def foo() -> uint256: ) +def test_missing_event(assert_compile_failed): + interface_code = """ +event Foo: + a: uint256 + """ + + interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + + not_implemented_code = """ +import a as FooBarInterface + +implements: FooBarInterface + +@external +def bar() -> uint256: + return 1 + """ + + assert_compile_failed( + lambda: compile_code(not_implemented_code, interface_codes=interface_codes), + InterfaceViolation, + ) + + +def test_malformed_event(assert_compile_failed): + interface_code = """ +event Foo: + a: uint256 + """ + + interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}} + + not_implemented_code = """ +import a as FooBarInterface + +implements: FooBarInterface + +event Foo: + a: int128 + +@external +def bar() -> uint256: + return 1 + """ + + assert_compile_failed( + lambda: compile_code(not_implemented_code, interface_codes=interface_codes), + InterfaceViolation, + ) + + VALID_IMPORT_CODE = [ # import statement, import path without suffix ("import a as Foo", "a"), diff --git a/vyper/context/types/meta/event.py b/vyper/context/types/meta/event.py index b463f1c0bb..241afd1d22 100644 --- a/vyper/context/types/meta/event.py +++ b/vyper/context/types/meta/event.py @@ -1,10 +1,13 @@ from collections import OrderedDict -from typing import List +from typing import Dict, List from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.context.types.bases import DataLocation -from vyper.context.types.utils import get_type_from_annotation +from vyper.context.types.utils import ( + get_type_from_abi, + get_type_from_annotation, +) from vyper.context.validation.utils import validate_expected_type from vyper.exceptions import ( EventDeclarationException, @@ -39,6 +42,26 @@ def __init__(self, name: str, arguments: OrderedDict, indexed: List) -> None: signature = f"{name}({','.join(v.canonical_type for v in arguments.values())})" self.event_id = int(keccak256(signature.encode()).hex(), 16) + @classmethod + def from_abi(cls, abi: Dict) -> "Event": + """ + Generate an `Event` object from an ABI interface. + + Arguments + --------- + abi : dict + An object from a JSON ABI interface, representing an event. + + Returns + ------- + Event object. + """ + members: OrderedDict = OrderedDict() + indexed: List = [i["indexed"] for i in abi["inputs"]] + for item in abi["inputs"]: + members[item["name"]] = get_type_from_abi(item) + return Event(abi["name"], members, indexed) + @classmethod def from_EventDef(cls, base_node: vy_ast.EventDef) -> "Event": """ diff --git a/vyper/context/types/meta/interface.py b/vyper/context/types/meta/interface.py index b298e21c7e..ee52a7b70d 100644 --- a/vyper/context/types/meta/interface.py +++ b/vyper/context/types/meta/interface.py @@ -1,11 +1,12 @@ from collections import OrderedDict -from typing import Union +from typing import Dict, Tuple, Union from vyper import ast as vy_ast from vyper.ast.validation import validate_call_args from vyper.context.namespace import get_namespace from vyper.context.types.bases import DataLocation, MemberTypeDefinition from vyper.context.types.function import ContractFunction +from vyper.context.types.meta.event import Event from vyper.context.types.value.address import AddressDefinition from vyper.context.validation.utils import validate_expected_type from vyper.exceptions import ( @@ -38,9 +39,10 @@ class InterfacePrimitive: _is_callable = True _as_array = True - def __init__(self, _id, members): + def __init__(self, _id, members, events): self._id = _id self.members = members + self.events = events def __repr__(self): return f"{self._id} declaration object" @@ -66,6 +68,7 @@ def fetch_call_return(self, node: vy_ast.Call) -> InterfaceDefinition: def validate_implements(self, node: vy_ast.AnnAssign) -> None: namespace = get_namespace() + # check for missing functions unimplemented = [ name for name, type_ in self.members.items() @@ -73,9 +76,18 @@ def validate_implements(self, node: vy_ast.AnnAssign) -> None: or not hasattr(namespace["self"].members[name], "compare_signature") or not namespace["self"].members[name].compare_signature(type_) ] + # check for missing events + unimplemented += [ + name + for name, event in self.events.items() + if name not in namespace + or not isinstance(namespace[name], Event) + or namespace[name].event_id != event.event_id + ] if unimplemented: + missing_str = ", ".join(sorted(unimplemented)) raise InterfaceViolation( - f"Contract does not implement all interface functions: {', '.join(unimplemented)}", + f"Contract does not implement all interface functions or events: {missing_str}", node, ) @@ -97,16 +109,22 @@ def build_primitive_from_abi(name: str, abi: dict) -> InterfacePrimitive: primitive interface type """ members: OrderedDict = OrderedDict() + events: Dict = {} + + names = [i["name"] for i in abi if i.get("type") in ("event", "function")] + collisions = set(i for i in names if names.count(i) > 1) + if collisions: + collision_list = ", ".join(sorted(collisions)) + raise NamespaceCollision( + f"ABI '{name}' has multiple functions or events with the same name: {collision_list}" + ) + for item in [i for i in abi if i.get("type") == "function"]: - func = ContractFunction.from_abi(item) - if func.name in members: - # TODO overloaded functions - raise NamespaceCollision( - f"ABI '{name}' contains multiple functions named '{func.name}'" - ) - members[func.name] = func + members[item["name"]] = ContractFunction.from_abi(item) + for item in [i for i in abi if i.get("type") == "event"]: + events[item["name"]] = Event.from_abi(item) - return InterfacePrimitive(name, members) + return InterfacePrimitive(name, members, events) def build_primitive_from_node( @@ -125,22 +143,24 @@ def build_primitive_from_node( primitive interface type """ if isinstance(node, vy_ast.Module): - members = _get_module_functions(node) + members, events = _get_module_definitions(node) elif isinstance(node, vy_ast.InterfaceDef): members = _get_class_functions(node) + events = {} else: raise StructureException("Invalid syntax for interface definition", node) namespace = get_namespace() - for func in members.values(): - if func.name in namespace: - raise NamespaceCollision(func.name, func.node) + for item in list(members.values()) + list(events.values()): + if item.name in namespace: + raise NamespaceCollision(item.name, item.node) - return InterfacePrimitive(node.name, members) + return InterfacePrimitive(node.name, members, events) -def _get_module_functions(base_node: vy_ast.Module) -> OrderedDict: +def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[OrderedDict, Dict]: functions: OrderedDict = OrderedDict() + events: Dict = {} for node in base_node.get_children(vy_ast.FunctionDef): if "external" in [i.id for i in node.decorator_list]: func = ContractFunction.from_FunctionDef(node) @@ -167,7 +187,15 @@ def _get_module_functions(base_node: vy_ast.Module) -> OrderedDict: f"Interface contains multiple functions named '{name}'", base_node ) functions[name] = ContractFunction.from_AnnAssign(node) - return functions + for node in base_node.get_children(vy_ast.EventDef): + name = node.name + if name in functions or name in events: + raise NamespaceCollision( + f"Interface contains multiple objects named '{name}'", base_node + ) + events[name] = Event.from_EventDef(node) + + return functions, events def _get_class_functions(base_node: vy_ast.InterfaceDef) -> OrderedDict: