diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 099b786d74..254378bf09 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True)
-def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
-    """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
-
-    if self.dtype.is_integer():
+def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor:
+    if integer:
         return op.Mod(self, other)
 
     # TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7673,6 +7670,29 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
     return op.Sub(self, op.Mul(rounded_quotient, other))
 
 
+@torch_op("aten::remainder.Tensor", trace_only=True)
+def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
+    """remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
+
+    return _aten_remainder(self, other, integer=self.dtype.is_integer())
+
+
+@torch_op("aten::remainder.Scalar", trace_only=True)
+def aten_remainder_scalar(self: TTensor, other: float) -> TTensor:
+    """remainder.Scalar(Tensor self, Scalar other) -> Tensor"""
+
+    other_tensor = ir.tensor(other, dtype=self.dtype)
+    return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer())
+
+
+@torch_op("aten::remainder.Scalar_Tensor", trace_only=True)
+def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor:
+    """remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
+
+    self_tensor = ir.tensor(self, dtype=other.dtype)
+    return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer())
+
+
 @torch_op("_operator::mod", trace_only=True)
 def operator_mod(self: TTensor, other: TTensor) -> TTensor:
     # Modulus operator % on SymInt
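
For reference, a minimal sketch (not part of the patch; the helper names below are hypothetical) of the two remainder semantics the shared _aten_remainder helper dispatches between: the integer path uses ONNX Mod, which with the default fmod=0 behaves like Python's %, while the float path uses the floor-division formula a - floor(a / b) * b. In both cases the result takes the sign of the divisor, matching torch.remainder.

    # Illustrative sketch only; not part of the diff above.
    import math

    def remainder_float(a: float, b: float) -> float:
        # Float path: a - floor(a / b) * b,
        # i.e. op.Sub(self, op.Mul(rounded_quotient, other))
        return a - math.floor(a / b) * b

    def remainder_int(a: int, b: int) -> int:
        # Integer path: ONNX Mod with fmod=0 follows Python's %,
        # so the result takes the sign of the divisor.
        return a % b

    print(remainder_float(5.5, -2.0))  # -0.5: sign follows the divisor
    print(remainder_int(5, -2))        # -1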