Skip to content

Commit 7ec5d25

Browse files
authored
[torchlib] prims::sum (#2778)
Fix pytorch/pytorch#173074 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 08ef68a commit 7ec5d25

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

onnxscript/function_libs/torch_lib/ops/prims.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,12 +772,18 @@ def prims_sub(self: TTensor, other: TTensor) -> TTensor:
772772
return op.Sub(self, other)
773773

774774

775+
@torch_op("prims::sum", trace_only=True)
775776
def prims_sum(
776777
inp: TensorType, dims: Optional[Sequence[int]], output_dtype: Optional[int] = None
777778
) -> TensorType:
778779
"""sum(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor"""
779780

780-
raise NotImplementedError()
781+
result = op.ReduceSum(inp, dims, keepdims=False)
782+
783+
if output_dtype is not None and output_dtype != -1:
784+
result = op.Cast(result, to=output_dtype)
785+
786+
return result
781787

782788

783789
def prims_svd(A: TensorType, full_matrices: bool) -> tuple[TensorType, TensorType, TensorType]:

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,18 @@ def forward(self, x, index, update):
902902
)
903903
_testing.assert_onnx_program(onnx_program)
904904

905+
def test_std_mean(self):
906+
"""Test torch.std_mean which will be decomposed into prims.sum."""
907+
908+
class Model(torch.nn.Module):
909+
def forward(self, x):
910+
return torch.std_mean(x)
911+
912+
onnx_program = torch.onnx.export(
913+
Model(), (torch.rand(10, 10, 10),), dynamo=True, verbose=False
914+
)
915+
_testing.assert_onnx_program(onnx_program)
916+
905917

906918
if __name__ == "__main__":
907919
unittest.main()

0 commit comments

Comments
 (0)