
Commit 8176959

danielvegamyhred authored and sashidh committed
[mxfp8 torch._scaled_grouped_mm] fix meta registration for 3d tensor (pytorch#162765)
The meta registration checks for torch._scaled_grouped_mm have a bug for 3d "B" tensors: the scale for such a tensor should be 2d with shape (G, blocked_K * blocked_N), but the check currently enforces a 3d shape of (G, blocked_K, blocked_N). See Blas.cpp for the correct validation logic [here](https://github.com/pytorch/pytorch/blob/8e217a9f6dc81e3d12697b04c3e611d82d9d866a/aten/src/ATen/native/cuda/Blas.cpp#L1622).

Pull Request resolved: pytorch#162765
Approved by: https://github.com/ngimel
1 parent 5d833cc commit 8176959

1 file changed

torch/_meta_registrations.py (5 additions & 5 deletions)
```diff
@@ -7547,18 +7547,18 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
         # scale sizes at compile time.
         if is_mxfp8:
             torch._check(
-                mat.ndim == scale.ndim,
-                lambda: f"For MXFP8, scale should have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
+                scale.ndim == mat.ndim - 1,
+                lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
             )
             # TODO: This logic only holds for RHS tensor in 2d-3d case.
             # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
-            G, K, N = scale.shape
+            G, K, N = mat.shape
             block_size = 32
             blocked_K = round_up(K / block_size, 4)
             blocked_N = round_up(N, 128)
             torch._check(
-                mat.shape[-2] == blocked_K and mat.shape[-1] == blocked_N,
-                lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K},{blocked_N}), but got {scale.shape}",  # noqa: B950
+                scale.shape[0] == G and scale.shape[1] == blocked_K * blocked_N,
+                lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}",  # noqa: B950
             )
         else:
             torch._check(
```
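The scale-shape rule enforced by the corrected check can be sketched in plain Python. This is a minimal illustration, not PyTorch source: `expected_mxfp8_scale_shape` is a hypothetical helper, and `round_up` mirrors the helper referenced in the diff.

```python
import math

def round_up(x, multiple):
    # Round x up to the nearest multiple (mirrors the round_up helper in the diff).
    return math.ceil(x / multiple) * multiple

def expected_mxfp8_scale_shape(G, K, N, block_size=32):
    # Hypothetical helper: for a 3d "B" tensor of shape (G, K, N), the meta
    # check expects a 2d scale of shape (G, blocked_K * blocked_N), where K is
    # divided into 32-element blocks (count padded to a multiple of 4) and N is
    # padded to a multiple of 128.
    blocked_K = round_up(K / block_size, 4)
    blocked_N = round_up(N, 128)
    return (G, blocked_K * blocked_N)

# e.g. G=4 groups, K=256, N=512:
# blocked_K = round_up(256 / 32, 4) = 8, blocked_N = round_up(512, 128) = 512
print(expected_mxfp8_scale_shape(4, 256, 512))  # -> (4, 4096)
```

This makes the shape of the bug visible: the old check treated the scale as 3d, (G, blocked_K, blocked_N), while the scale is actually stored with the last two blocked dims flattened into one.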
