Skip to content

Commit da967e3

Browse files
authored
[torchlib] Fix and implement overloads for aten::remainder (microsoft#2727)
Previously the Scalar_Tensor overload will fail because the first arg will be a scalar which does not have dtype. Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent a3883a6 commit da967e3

1 file changed

Lines changed: 25 additions & 5 deletions

File tree

  • onnxscript/function_libs/torch_lib/ops

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7657,11 +7657,8 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
76577657
raise NotImplementedError()
76587658

76597659

7660-
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), trace_only=True)
7661-
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
7662-
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7663-
7664-
if self.dtype.is_integer():
7660+
def _aten_remainder(self: TTensor, other: TTensor, integer: bool) -> TTensor:
7661+
if integer:
76657662
return op.Mod(self, other)
76667663

76677664
# TODO(justinchuby): Improve fp16 precision by following the logic in
@@ -7673,6 +7670,29 @@ def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
76737670
return op.Sub(self, op.Mul(rounded_quotient, other))
76747671

76757672

7673+
@torch_op("aten::remainder.Tensor", trace_only=True)
7674+
def aten_remainder(self: TTensor, other: TTensor) -> TTensor:
7675+
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""
7676+
7677+
return _aten_remainder(self, other, integer=self.dtype.is_integer())
7678+
7679+
7680+
@torch_op("aten::remainder.Scalar", trace_only=True)
7681+
def aten_remainder_scalar(self: TTensor, other: float) -> TTensor:
7682+
"""remainder.Scalar(Tensor self, Scalar other) -> Tensor"""
7683+
7684+
other_tensor = ir.tensor(other, dtype=self.dtype)
7685+
return _aten_remainder(self, other_tensor, integer=self.dtype.is_integer())
7686+
7687+
7688+
@torch_op("aten::remainder.Scalar_Tensor", trace_only=True)
7689+
def aten_remainder_scalar_tensor(self: float, other: TTensor) -> TTensor:
7690+
"""remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor"""
7691+
7692+
self_tensor = ir.tensor(self, dtype=other.dtype)
7693+
return _aten_remainder(self_tensor, other, integer=other.dtype.is_integer())
7694+
7695+
76767696
@torch_op("_operator::mod", trace_only=True)
76777697
def operator_mod(self: TTensor, other: TTensor) -> TTensor:
76787698
# Modulus operator % on SymInt

0 commit comments

Comments
 (0)