Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 25 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True)
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

if self.dtype.is_integer():
def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor:
if integer:
return op.Mod(self, other)

# TODO(justinchuby): Improve fp16 precision by following the logic in
Expand All @@ -7673,6 +7670,29 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
return op.Sub(self, op.Mul(rounded_quotient, other))


@torch_op("aten::remainder.Tensor", trace_only=True)
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

return _aten_remainder(self, other, integer=self.dtype.is_integer())


@torch_op("aten::remainder.Scalar", trace_only=True)
def aten_remainder_scalar(self: TTensor, other: float) -> TTensor:
"""remainder.Scalar(Tensor self, Scalar other) -> Tensor"""

other_tensor = ir.tensor(other, dtype=self.dtype)
return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer())


@torch_op("aten::remainder.Scalar_Tensor", trace_only=True)
def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor:
"""remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""

self_tensor = ir.tensor(self, dtype=other.dtype)
return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer())


@torch_op("_operator::mod", trace_only=True)
def operator_mod(self: TTensor, other: TTensor) -> TTensor:
# Modulus operator % on SymInt
Expand Down
Loading