Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions tests/parser/functions/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,57 @@ def foo() -> uint256:
)


def test_missing_event(assert_compile_failed):
interface_code = """
event Foo:
a: uint256
"""

interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}}

not_implemented_code = """
import a as FooBarInterface

implements: FooBarInterface

@external
def bar() -> uint256:
return 1
"""

assert_compile_failed(
lambda: compile_code(not_implemented_code, interface_codes=interface_codes),
InterfaceViolation,
)


def test_malformed_event(assert_compile_failed):
interface_code = """
event Foo:
a: uint256
"""

interface_codes = {"FooBarInterface": {"type": "vyper", "code": interface_code}}

not_implemented_code = """
import a as FooBarInterface

implements: FooBarInterface

event Foo:
a: int128

@external
def bar() -> uint256:
return 1
"""

assert_compile_failed(
lambda: compile_code(not_implemented_code, interface_codes=interface_codes),
InterfaceViolation,
)


VALID_IMPORT_CODE = [
# import statement, import path without suffix
("import a as Foo", "a"),
Expand Down
27 changes: 25 additions & 2 deletions vyper/context/types/meta/event.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections import OrderedDict
from typing import List
from typing import Dict, List

from vyper import ast as vy_ast
from vyper.ast.validation import validate_call_args
from vyper.context.types.bases import DataLocation
from vyper.context.types.utils import get_type_from_annotation
from vyper.context.types.utils import (
get_type_from_abi,
get_type_from_annotation,
)
from vyper.context.validation.utils import validate_expected_type
from vyper.exceptions import (
EventDeclarationException,
Expand Down Expand Up @@ -39,6 +42,26 @@ def __init__(self, name: str, arguments: OrderedDict, indexed: List) -> None:
signature = f"{name}({','.join(v.canonical_type for v in arguments.values())})"
self.event_id = int(keccak256(signature.encode()).hex(), 16)

@classmethod
def from_abi(cls, abi: Dict) -> "Event":
"""
Generate an `Event` object from an ABI interface.

Arguments
---------
abi : dict
An object from a JSON ABI interface, representing an event.

Returns
-------
Event object.
"""
members: OrderedDict = OrderedDict()
indexed: List = [i["indexed"] for i in abi["inputs"]]
for item in abi["inputs"]:
members[item["name"]] = get_type_from_abi(item)
return Event(abi["name"], members, indexed)

@classmethod
def from_EventDef(cls, base_node: vy_ast.EventDef) -> "Event":
"""
Expand Down
64 changes: 46 additions & 18 deletions vyper/context/types/meta/interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import OrderedDict
from typing import Union
from typing import Dict, Tuple, Union

from vyper import ast as vy_ast
from vyper.ast.validation import validate_call_args
from vyper.context.namespace import get_namespace
from vyper.context.types.bases import DataLocation, MemberTypeDefinition
from vyper.context.types.function import ContractFunction
from vyper.context.types.meta.event import Event
from vyper.context.types.value.address import AddressDefinition
from vyper.context.validation.utils import validate_expected_type
from vyper.exceptions import (
Expand Down Expand Up @@ -38,9 +39,10 @@ class InterfacePrimitive:
_is_callable = True
_as_array = True

def __init__(self, _id, members):
def __init__(self, _id, members, events):
self._id = _id
self.members = members
self.events = events

def __repr__(self):
return f"{self._id} declaration object"
Expand All @@ -66,16 +68,26 @@ def fetch_call_return(self, node: vy_ast.Call) -> InterfaceDefinition:

def validate_implements(self, node: vy_ast.AnnAssign) -> None:
namespace = get_namespace()
# check for missing functions
unimplemented = [
name
for name, type_ in self.members.items()
if name not in namespace["self"].members
or not hasattr(namespace["self"].members[name], "compare_signature")
or not namespace["self"].members[name].compare_signature(type_)
]
# check for missing events
unimplemented += [
name
for name, event in self.events.items()
if name not in namespace
or not isinstance(namespace[name], Event)
or namespace[name].event_id != event.event_id
]
if unimplemented:
missing_str = ", ".join(sorted(unimplemented))
raise InterfaceViolation(
f"Contract does not implement all interface functions: {', '.join(unimplemented)}",
f"Contract does not implement all interface functions or events: {missing_str}",
node,
)

Expand All @@ -97,16 +109,22 @@ def build_primitive_from_abi(name: str, abi: dict) -> InterfacePrimitive:
primitive interface type
"""
members: OrderedDict = OrderedDict()
events: Dict = {}

names = [i["name"] for i in abi if i.get("type") in ("event", "function")]
collisions = set(i for i in names if names.count(i) > 1)
if collisions:
collision_list = ", ".join(sorted(collisions))
raise NamespaceCollision(
f"ABI '{name}' has multiple functions or events with the same name: {collision_list}"
)

for item in [i for i in abi if i.get("type") == "function"]:
func = ContractFunction.from_abi(item)
if func.name in members:
# TODO overloaded functions
raise NamespaceCollision(
f"ABI '{name}' contains multiple functions named '{func.name}'"
)
members[func.name] = func
members[item["name"]] = ContractFunction.from_abi(item)
for item in [i for i in abi if i.get("type") == "event"]:
events[item["name"]] = Event.from_abi(item)

return InterfacePrimitive(name, members)
return InterfacePrimitive(name, members, events)


def build_primitive_from_node(
Expand All @@ -125,22 +143,24 @@ def build_primitive_from_node(
primitive interface type
"""
if isinstance(node, vy_ast.Module):
members = _get_module_functions(node)
members, events = _get_module_definitions(node)
elif isinstance(node, vy_ast.InterfaceDef):
members = _get_class_functions(node)
events = {}
else:
raise StructureException("Invalid syntax for interface definition", node)

namespace = get_namespace()
for func in members.values():
if func.name in namespace:
raise NamespaceCollision(func.name, func.node)
for item in list(members.values()) + list(events.values()):
if item.name in namespace:
raise NamespaceCollision(item.name, item.node)

return InterfacePrimitive(node.name, members)
return InterfacePrimitive(node.name, members, events)


def _get_module_functions(base_node: vy_ast.Module) -> OrderedDict:
def _get_module_definitions(base_node: vy_ast.Module) -> Tuple[OrderedDict, Dict]:
functions: OrderedDict = OrderedDict()
events: Dict = {}
for node in base_node.get_children(vy_ast.FunctionDef):
if "external" in [i.id for i in node.decorator_list]:
func = ContractFunction.from_FunctionDef(node)
Expand All @@ -167,7 +187,15 @@ def _get_module_functions(base_node: vy_ast.Module) -> OrderedDict:
f"Interface contains multiple functions named '{name}'", base_node
)
functions[name] = ContractFunction.from_AnnAssign(node)
return functions
for node in base_node.get_children(vy_ast.EventDef):
name = node.name
if name in functions or name in events:
raise NamespaceCollision(
f"Interface contains multiple objects named '{name}'", base_node
)
events[name] = Event.from_EventDef(node)

return functions, events


def _get_class_functions(base_node: vy_ast.InterfaceDef) -> OrderedDict:
Expand Down