Skip to content
Merged
4 changes: 4 additions & 0 deletions mypyc/analysis/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
RegisterOp,
Return,
Expand Down Expand Up @@ -234,6 +235,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
def visit_call_c(self, op: CallC) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]:
return self.visit_register_op(op)

def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
return self.visit_register_op(op)

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/ircheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -381,6 +382,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
def visit_call_c(self, op: CallC) -> None:
pass

def visit_primitive_op(self, op: PrimitiveOp) -> None:
pass

def visit_truncate(self, op: Truncate) -> None:
pass

Expand Down
4 changes: 4 additions & 0 deletions mypyc/analysis/selfleaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
LoadStatic,
MethodCall,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
RegisterOp,
Expand Down Expand Up @@ -149,6 +150,9 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
def visit_call_c(self, op: CallC) -> GenAndKill:
return self.check_register_op(op)

def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill:
return self.check_register_op(op)

def visit_truncate(self, op: Truncate) -> GenAndKill:
return CLEAN

Expand Down
6 changes: 6 additions & 0 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -629,6 +630,11 @@ def visit_call_c(self, op: CallC) -> None:
args = ", ".join(self.reg(arg) for arg in op.args)
self.emitter.emit_line(f"{dest}{op.function_name}({args});")

def visit_primitive_op(self, op: PrimitiveOp) -> None:
raise RuntimeError(
f"unexpected PrimitiveOp {op.desc.name}: they must be lowered before codegen"
)

def visit_truncate(self, op: Truncate) -> None:
dest = self.reg(op)
value = self.reg(op.src)
Expand Down
3 changes: 3 additions & 0 deletions mypyc/codegen/emitmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from mypyc.transform.copy_propagation import do_copy_propagation
from mypyc.transform.exceptions import insert_exception_handling
from mypyc.transform.flag_elimination import do_flag_elimination
from mypyc.transform.lower import lower_ir
from mypyc.transform.refcount import insert_ref_count_opcodes
from mypyc.transform.uninit import insert_uninit_checks

Expand Down Expand Up @@ -235,6 +236,8 @@ def compile_scc_to_ir(
insert_exception_handling(fn)
# Insert refcount handling.
insert_ref_count_opcodes(fn)
# Switch to lower abstraction level IR.
lower_ir(fn, compiler_options)
# Perform optimizations.
do_copy_propagation(fn, compiler_options)
do_flag_elimination(fn, compiler_options)
Expand Down
79 changes: 78 additions & 1 deletion mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,78 @@ def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_method_call(self)


class PrimitiveDescription:
"""Description of a primitive op.

Primitives get lowered into lower-level ops before code generation.

If c_function_name is provided, a primitive will be lowered into a CallC op.
Otherwise custom logic will need to be implemented to transform the
primitive into lower-level ops.
"""

def __init__(
self,
name: str,
arg_types: list[RType],
return_type: RType, # TODO: What about generic?
var_arg_type: RType | None,
truncated_type: RType | None,
c_function_name: str | None,
error_kind: int,
steals: StealsDescription,
is_borrowed: bool,
ordering: list[int] | None,
extra_int_constants: list[tuple[int, RType]],
priority: int,
) -> None:
# Each primitive much have a distinct name, but otherwise they are arbitrary.
self.name: Final = name
self.arg_types: Final = arg_types
self.return_type: Final = return_type
self.var_arg_type: Final = var_arg_type
self.truncated_type: Final = truncated_type
# If non-None, this will map to a call of a C helper function; if None,
# there must be a custom handler function that gets invoked during the lowering
# pass to generate low-level IR for the primitive (in the mypyc.lower package)
self.c_function_name: Final = c_function_name
self.error_kind: Final = error_kind
self.steals: Final = steals
self.is_borrowed: Final = is_borrowed
self.ordering: Final = ordering
self.extra_int_constants: Final = extra_int_constants
self.priority: Final = priority

def __repr__(self) -> str:
return f"<PrimitiveDescription {self.name}>"


class PrimitiveOp(RegisterOp):
"""A higher-level primitive operation.

Some of these have special compiler support. These will be lowered
(transformed) into lower-level IR ops before code generation, and after
reference counting op insertion. Others will be transformed into CallC
ops.

Tagged integer equality is a typical primitive op with non-trivial
lowering. It gets transformed into a tag check, followed by different
code paths for short and long representations.
"""

def __init__(self, args: list[Value], desc: PrimitiveDescription, line: int = -1) -> None:
self.args = args
self.type = desc.return_type
self.error_kind = desc.error_kind
self.desc = desc

def sources(self) -> list[Value]:
return self.args

def accept(self, visitor: OpVisitor[T]) -> T:
return visitor.visit_primitive_op(self)


class LoadErrorValue(RegisterOp):
"""Load an error value.

Expand Down Expand Up @@ -1446,7 +1518,8 @@ class Unborrow(RegisterOp):

error_kind = ERR_NEVER

def __init__(self, src: Value) -> None:
def __init__(self, src: Value, line: int = -1) -> None:
super().__init__(line)
assert src.is_borrowed
self.src = src
self.type = src.type
Expand Down Expand Up @@ -1555,6 +1628,10 @@ def visit_raise_standard_error(self, op: RaiseStandardError) -> T:
def visit_call_c(self, op: CallC) -> T:
raise NotImplementedError

@abstractmethod
def visit_primitive_op(self, op: PrimitiveOp) -> T:
raise NotImplementedError

@abstractmethod
def visit_truncate(self, op: Truncate) -> T:
raise NotImplementedError
Expand Down
17 changes: 17 additions & 0 deletions mypyc/ir/pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
Expand Down Expand Up @@ -217,6 +218,22 @@ def visit_call_c(self, op: CallC) -> str:
else:
return self.format("%r = %s(%s)", op, op.function_name, args_str)

def visit_primitive_op(self, op: PrimitiveOp) -> str:
args = []
arg_index = 0
type_arg_index = 0
for arg_type in zip(op.desc.arg_types):
if arg_type:
args.append(self.format("%r", op.args[arg_index]))
arg_index += 1
else:
assert op.type_args
args.append(self.format("%r", op.type_args[type_arg_index]))
type_arg_index += 1

args_str = ", ".join(args)
return self.format("%r = %s %s ", op, op.desc.name, args_str)

def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)

Expand Down
6 changes: 5 additions & 1 deletion mypyc/irbuild/ast_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ def maybe_process_conditional_comparison(
self.add_bool_branch(reg, true, false)
else:
# "left op right" for two tagged integers
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
if op in ("==", "!="):
reg = self.builder.binary_op(left, right, op, e.line)
self.add_bool_branch(reg, true, false)
else:
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
return True


Expand Down
4 changes: 2 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
set_literal = precompute_set_literal(builder, e.operands[1])
if set_literal is not None:
lhs = e.operands[0]
result = builder.builder.call_c(
result = builder.builder.primitive_op(
set_in_op, [builder.accept(lhs), set_literal], e.line, bool_rprimitive
)
if first_op == "not in":
Expand All @@ -778,7 +778,7 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
borrow_left = is_borrow_friendly_expr(builder, right_expr)
left = builder.accept(left_expr, can_borrow=borrow_left)
right = builder.accept(right_expr, can_borrow=True)
return builder.compare_tagged(left, right, first_op, e.line)
return builder.binary_op(left, right, first_op, e.line)

# TODO: Don't produce an expression when used in conditional context
# All of the trickiness here is due to support for chained conditionals
Expand Down
Loading