🐛 Describe the bug
Model 1:
import torch
import torch.nn as nn

class MinTest(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Max values and their indices along dim 1
        top_probs, top = x.max(1, keepdim=True)
        return top_probs, top

Model 2:
class MinTest(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        top_probs, top = x.topk(1, dim=1)
        return top_probs, top

Export:
m = MinTest().eval()
x = torch.rand(1, 146, 768)
torch.onnx.export(
    m,
    x,
    "/trt/output/min.onnx",
    verbose=True,
    input_names=["x"],
    output_names=["top_probs", "top"],
    opset_version=14,
)

There are two issues with exporting `max(1, keepdim=True)`:
- It is inefficient: the export calls both max and argmax as separate operations when they could be combined.
- The onnxruntime CUDA EP does an extra copy to memory for argmax (this is a separate issue on the onnxruntime side).

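For anyone comparing the two models, here is a minimal sketch that checks in eager PyTorch that `topk(1, dim=1)` returns the same values as `max(1, keepdim=True)` on the input shape used above. The helper names `max_with_indices` and `topk_with_indices` are made up for illustration. The check says nothing about how the exporter lowers either call. Indices are validated by gathering rather than by direct comparison, since the two calls might pick different positions when values tie.

```python
import torch


def max_with_indices(x):
    # Same call as Model 1: max values and indices along dim 1
    return x.max(1, keepdim=True)


def topk_with_indices(x):
    # Same call as Model 2: top-1 values and indices along dim 1
    return x.topk(1, dim=1)


torch.manual_seed(0)
x = torch.rand(1, 146, 768)

max_vals, max_idx = max_with_indices(x)
topk_vals, topk_idx = topk_with_indices(x)

# The values should be identical element for element
values_match = torch.equal(max_vals, topk_vals)

# Each returned index should point at the maximum value of its column
indices_valid = torch.equal(x.gather(1, topk_idx), max_vals)

print(values_match, indices_valid)
```

If both checks pass, the two modules can be swapped in eager mode, and any difference observed after export comes from the exported graph rather than from the model itself.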
Versions
Collecting environment information...
PyTorch version: 1.10.2
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.21.0
Libc version: glibc-2.10
Python version: 3.7.12 | packaged by conda-forge | (default, Oct 26 2021, 06:08:21) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1074-azure-x86_64-with-debian-bullseye-sid
Is CUDA available: True
CUDA runtime version: 11.4.48
GPU models and configuration: GPU 0: Tesla V100-PCIE-16GB
Nvidia driver version: 470.103.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.21.6
[pip3] torch==1.10.2
[pip3] torchlars==0.1.2
[pip3] torchvision==0.11.3
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h8d4b97c_729 conda-forge
[conda] mkl-service 2.4.0 py37h402132d_0 conda-forge
[conda] mkl_fft 1.3.1 py37h3e078e5_1 conda-forge
[conda] mkl_random 1.2.2 py37h219a48f_0 conda-forge
[conda] numpy 1.21.6 pypi_0 pypi
[conda] pytorch 1.10.2 py3.7_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchlars 0.1.2 pypi_0 pypi
[conda] torchvision 0.11.3 py37_cu113 pytorch
cc @justinchuby