Skip to content

Commit cddf350

Browse files
expand interface output comparison logic
1 parent bb89c40 commit cddf350

1 file changed

Lines changed: 26 additions & 8 deletions

File tree

vyper/signatures/interface.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@
3434
InterfaceImports,
3535
SourceCode,
3636
)
37+
from vyper.types.types import (
38+
ByteArrayLike,
39+
TupleLike,
40+
)
3741

3842

3943
# Populate built-in interfaces.
@@ -64,12 +68,12 @@ def abi_type_to_ast(atype):
6468
elif atype == 'bytes':
6569
return ast.Subscript(
6670
value=ast.Name(id='bytes'),
67-
slice=ast.Index(256)
71+
slice=ast.Index(value=ast.Num(n=256))
6872
)
6973
elif atype == 'string':
7074
return ast.Subscript(
7175
value=ast.Name(id='string'),
72-
slice=ast.Index(256)
76+
slice=ast.Index(value=ast.Num(n=256))
7377
)
7478
else:
7579
raise ParserException(f'Type {atype} not supported by vyper.')
@@ -294,16 +298,14 @@ def check_valid_contract_interface(global_ctx, contract_sigs):
294298

295299
for sig, func_sig in contract_sigs.items():
296300
if isinstance(func_sig, FunctionSignature):
297-
# Remove units, as inteface signatures should not enforce units.
301+
if sig not in funcs_left or func_sig.private:
302+
continue
303+
# Remove units, as interface signatures should not enforce units.
298304
clean_sig_output_type = func_sig.output_type
299305
if func_sig.output_type:
300306
clean_sig_output_type = copy.deepcopy(func_sig.output_type)
301307
clean_sig_output_type.unit = {}
302-
if (
303-
sig in funcs_left and # noqa: W504
304-
not func_sig.private and # noqa: W504
305-
funcs_left[sig].output_type == clean_sig_output_type
306-
):
308+
if _compare_outputs(funcs_left[sig].output_type, clean_sig_output_type):
307309
del funcs_left[sig]
308310
if isinstance(func_sig, EventSignature) and func_sig.sig in funcs_left:
309311
del funcs_left[func_sig.sig]
@@ -329,3 +331,19 @@ def check_valid_contract_interface(global_ctx, contract_sigs):
329331
err_join = "\n\t".join(missing_events)
330332
error_message += f'Missing interface events:\n\t{err_join}'
331333
raise StructureException(error_message)
334+
335+
336+
def _compare_outputs(a, b):
337+
if isinstance(a, TupleLike):
338+
# for tuples and structs, compare the length and individual members
339+
if type(a) != type(b):
340+
return False
341+
if len(a.tuple_members()) != len(b.tuple_members()):
342+
return False
343+
compare = zip(a.tuple_members(), b.tuple_members())
344+
return next((False for i in compare if not _compare_outputs(*i)), True)
345+
if isinstance(a, ByteArrayLike):
346+
# for string and bytes, only the type matters (not the length)
347+
return type(a) == type(b)
348+
# for all other types, check strict equality
349+
return a == b

0 commit comments

Comments
 (0)