
[ONNX] Use topk to export max(dim,keepdim) to onnx #76344

@dashesy

Description

🐛 Describe the bug

Model 1:

import torch
from torch import nn

class MinTest(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        top_probs, top = x.max(1, keepdim=True)
        return top_probs, top

(screenshot: ONNX graph exported from Model 1)

Model 2:

class MinTest(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        top_probs, top = x.topk(1, dim=1)
        return top_probs, top

(screenshot: ONNX graph exported from Model 2)

Export:

m = MinTest().eval()

x = torch.rand(1, 146, 768)

torch.onnx.export(
    m, x, "/trt/output/min.onnx",
    verbose=True,
    input_names=["x"],
    output_names=["top_probs", "top"],
    opset_version=14,
)

There are two issues with how `max(1, keepdim=True)` is exported:

  1. It is inefficient: the exported graph computes the max values and the argmax indices separately, when the two could be combined into a single op.
  2. The onnxruntime CUDA EP does an extra copy to memory for ArgMax (tracked in a separate onnxruntime issue).
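As a quick sanity check of the proposed fix, a single `topk(1, dim)` call returns the same values and indices as `max(dim, keepdim=True)`, so the exporter could emit one TopK node in place of the two-op pattern (a minimal sketch using the same input shape as the repro above):

```python
import torch

x = torch.rand(1, 146, 768)

# max(dim, keepdim=True) returns both the max values and their indices ...
max_vals, max_idx = x.max(1, keepdim=True)

# ... and topk(1, dim) produces the same two tensors in one call,
# which maps to a single ONNX TopK node.
topk_vals, topk_idx = x.topk(1, dim=1)

assert torch.equal(max_vals, topk_vals)
assert torch.equal(max_idx, topk_idx)
assert max_vals.shape == (1, 1, 768)
```

With continuous random inputs ties are effectively impossible, so the indices match exactly as well as the values.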

Versions

Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.21.0
Libc version: glibc-2.10

Python version: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:21) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1074-azure-x86_64-with-debian-bullseye-sid
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration: GPU 0: Tesla V100-PCIE-16GB
Nvidia driver version: 470.103.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.10.2
[pip3] torchlars==0.1.2
[pip3] torchvision==0.11.3
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h8d4b97c_729 conda-forge
[conda] mkl-service 2.4.0 py37h402132d_0 conda-forge
[conda] mkl_fft 1.3.1 py37h3e078e5_1 conda-forge
[conda] mkl_random 1.2.2 py37h219a48f_0 conda-forge
[conda] numpy 1.21.6 pypi_0 pypi
[conda] pytorch 1.10.2 py3.7_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchlars 0.1.2 pypi_0 pypi
[conda] torchvision 0.11.3 py37_cu113 pytorch

cc @justinchuby

Labels

OSS contribution wanted (PR from open source contributors welcome to solve this issue) · module: onnx (Related to torch.onnx) · onnx-triaged (triaged by ONNX team) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
