Commit 8176959
[mxfp8 torch._scaled_grouped_mm] fix meta registration for 3d tensor (pytorch#162765)
The meta registration checks for torch._scaled_grouped_mm have a bug for 3d "B" tensors: the scale for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but the check currently enforces a 3d shape of (G, blocked_K, blocked_N).
See Blas.cpp for correct validation logic [here](https://github.com/pytorch/pytorch/blob/8e217a9f6dc81e3d12697b04c3e611d82d9d866a/aten/src/ATen/native/cuda/Blas.cpp#L1622).
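The shape rule described above can be sketched in plain Python. This is only an illustration, not the actual meta registration code: the helper names `round_up`, `blocked_dims`, and `check_b_scale_shape` are hypothetical, and the padding rule in `blocked_dims` (32-element blocks along K, padded to multiples of 4 and 128) is an assumption about the mxfp8 blocked layout that should be verified against Blas.cpp.

```python
from typing import Sequence, Tuple


def round_up(x: int, multiple: int) -> int:
    """Round x up to the nearest multiple of `multiple`."""
    return ((x + multiple - 1) // multiple) * multiple


def blocked_dims(K: int, N: int, block_size: int = 32) -> Tuple[int, int]:
    # ASSUMPTION (not stated in the commit message): one scale per
    # `block_size` elements along K, with the scale grid padded to a
    # multiple of 4 along K and 128 along N.
    num_k_blocks = -(-K // block_size)  # ceiling division
    return round_up(num_k_blocks, 4), round_up(N, 128)


def check_b_scale_shape(
    scale_shape: Sequence[int], G: int, blocked_K: int, blocked_N: int
) -> None:
    """Accept only the 2d scale shape (G, blocked_K * blocked_N)."""
    expected = (G, blocked_K * blocked_N)
    if tuple(scale_shape) != expected:
        raise ValueError(
            f"expected scale shape {expected} for a 3d B tensor, "
            f"got {tuple(scale_shape)}"
        )


# Usage: a B tensor of shape (G, K, N) = (4, 256, 512).
bK, bN = blocked_dims(K=256, N=512)
check_b_scale_shape((4, bK * bN), G=4, blocked_K=bK, blocked_N=bN)  # passes
```

Under this sketch, a 3d scale of shape (G, blocked_K, blocked_N), which the old check expected, is rejected, while the flattened 2d shape passes.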
Pull Request resolved: pytorch#162765
Approved by: https://github.com/ngimel
1 parent 5d833cc commit 8176959
1 file changed
Lines changed: 5 additions & 5 deletions
[Diff hunk, lines 7547-7564 of the changed file: lines 7550-7551, 7555, and 7560-7561 replaced; line contents not available.]