Commit 4eb69d9

zoranzhao authored and meta-codesync[bot] committed
Standalone test for CuTeDSL bmm SM90 kernel (#1057)
Summary:
Pull Request resolved: #1057

Reviewed By: sevenEng, jijunyan

Differential Revision: D95895188

fbshipit-source-id: ec68adb704033ba9db369551efbda50511b910ce
1 parent b429fbe commit 4eb69d9

1 file changed

Lines changed: 355 additions & 0 deletions

File tree

examples/test_cutedsl_bmm_sm90.py

@@ -0,0 +1,355 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Direct test harness for BmmSm90Kernel (cutedsl_bmm_sm90.py).

Directly invokes the CuTeDSL SM90 BMM kernel via cute.compile + execute,
bypassing the full AITemplate compilation pipeline. Validates results
against a PyTorch reference across all layout/variant combinations.

The kernel expects batch-last tensor ordering: A(M,K,B), B(N,K,B), C(M,N,B).
This test creates tensors in PyTorch batch-first format, computes the
reference, then permutes to batch-last for the CuTe kernel.

Requires SM90+ (H100 / Hopper).

Run with:
    buck run fbcode//aitemplate/AITemplate/examples:test_cutedsl_bmm_sm90
"""

import sys

import cuda.bindings.driver as cuda
import cutlass.cute as cute
import torch
from aitemplate.backend.cuda.gemm_universal.cutedsl_bmm_sm90 import BmmSm90Kernel
from cutlass.cute.runtime import from_dlpack


# =============================================================================
# Helpers
# =============================================================================


def make_cute_tensor(t):
    """Convert a PyTorch CUDA tensor to a CuTe tensor with dynamic modes.

    Marks the innermost (stride-1) mode as dynamic. For compact tensors,
    this single call makes all dependent strides dynamic as well.

    Skipped when any dimension has size 1 (e.g. B=1 batch), because
    mark_compact_shape_dynamic cannot verify compact stride ordering
    for size-1 dimensions in permuted views.
    """
    ct = from_dlpack(t, assumed_align=16)
    if all(s > 1 for s in t.shape):
        innermost_mode = t.dim_order()[0]
        ct = ct.mark_compact_shape_dynamic(
            mode=innermost_mode,
            stride_order=t.dim_order(),
            divisibility=1,
        )
    return ct
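

# NOTE (illustrative; our reading of the CuTe DSL compile model, not anything
# the kernel itself documents): when the size-1 fallback above leaves every
# mode static, the layout is baked in at cute.compile time, so such tensors
# get a shape-specialized kernel rather than a shape-generic one.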


def to_batch_last_a(t, a_row_major):
    """Permute A from batch-first to batch-last (M, K, B).

    Row-major A: (B, M, K) -> (M, K, B) via permute(1, 2, 0)
    Col-major A: (B, K, M) -> (M, K, B) via permute(2, 1, 0)
    """
    return t.permute(1, 2, 0) if a_row_major else t.permute(2, 1, 0)


def to_batch_last_b(t, b_row_major):
    """Permute B from batch-first to batch-last (N, K, B).

    Row-major B: (B, K, N) -> (N, K, B) via permute(2, 1, 0)
    Col-major B: (B, N, K) -> (N, K, B) via permute(1, 2, 0)
    """
    return t.permute(2, 1, 0) if b_row_major else t.permute(1, 2, 0)


def to_batch_last_c(t):
    """Permute C/D from batch-first (B, M, N) to batch-last (M, N, B)."""
    return t.permute(1, 2, 0)
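

# For intuition (illustrative shapes, not part of the kernel contract): these
# permutes are zero-copy views. A contiguous row-major A of shape
# (B=2, M=256, K=128) with strides (32768, 128, 1) becomes a (256, 128, 2)
# batch-last view with strides (128, 1, 32768); kernel writes through the
# batch-last C view therefore land directly in the batch-first tensor.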


def get_cu_stream():
    """Get CUDA driver stream from current PyTorch stream."""
    return cuda.CUstream(torch.cuda.current_stream().cuda_stream)
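

# Background note: torch.cuda.current_stream().cuda_stream exposes the raw
# stream handle as a Python int; wrapping it in cuda.bindings.driver.CUstream
# lets the compiled CuTe kernel launch on the same stream PyTorch is using.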


# =============================================================================
# Layout configs: (name, a_row_major, b_row_major, A_shape_fn, B_shape_fn, ref_fn)
# =============================================================================


def _make_configs():
    """Build layout test configs.

    Each config: (name, a_row_major, b_row_major,
                  A_shape(B,M,N,K), B_shape(B,M,N,K), ref_fn(a,b))
    """
    return [
        (
            "rrr",
            True,
            True,
            lambda B, M, N, K: (B, M, K),
            lambda B, M, N, K: (B, K, N),
            lambda a, b: torch.bmm(a, b),
        ),
        (
            "ccr",
            False,
            False,
            lambda B, M, N, K: (B, K, M),
            lambda B, M, N, K: (B, N, K),
            lambda a, b: torch.bmm(a.transpose(-2, -1), b.transpose(-2, -1)),
        ),
        (
            "rcr",
            True,
            False,
            lambda B, M, N, K: (B, M, K),
            lambda B, M, N, K: (B, N, K),
            lambda a, b: torch.bmm(a, b.transpose(-2, -1)),
        ),
    ]
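

# The three-letter names follow the usual GEMM layout shorthand: the letters
# give the layouts of A, B, and C in order (r = row-major, c = column-major),
# e.g. "rcr" is row-major A x column-major B -> row-major C. Only the
# C-row-major variants are exercised here.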


# Shape configs: (name, B, M, N, K)
_SHAPES = [
    ("aligned", 2, 256, 512, 128),
    ("medium", 4, 512, 256, 256),
    ("large_batch", 16, 128, 128, 64),
    ("small", 1, 128, 128, 64),
]


# =============================================================================
# Core test runner
# =============================================================================


def run_test(
    name,
    a_row_major,
    b_row_major,
    has_d,
    B,
    M,
    N,
    K,
    a_shape,
    b_shape,
    ref_fn,
    atol=1e-2,
    rtol=1e-2,
):
    """Run a single BmmSm90Kernel test case."""
    add_str = "_add" if has_d else ""
    test_id = f"bmm_{name}{add_str} B={B} M={M} N={N} K={K}"

    # Create kernel
    kernel = BmmSm90Kernel(
        tile_m=128,
        tile_n=128,
        a_row_major=a_row_major,
        b_row_major=b_row_major,
        has_d=has_d,
    )

    # Create PyTorch tensors (batch-first, standard PyTorch convention)
    a_pt = torch.randn(*a_shape, device="cuda", dtype=torch.float16)
    b_pt = torch.randn(*b_shape, device="cuda", dtype=torch.float16)
    c_pt = torch.zeros(B, M, N, device="cuda", dtype=torch.float16)
    d_pt = (
        torch.randn(B, M, N, device="cuda", dtype=torch.float16)
        if has_d
        else torch.zeros(B, M, N, device="cuda", dtype=torch.float16)
    )

    # PyTorch reference (batch-first)
    y_ref = ref_fn(a_pt, b_pt)
    if has_d:
        y_ref = y_ref + d_pt

    # Permute to batch-last for the kernel: A(M,K,B), B(N,K,B), C/D(M,N,B).
    # These are views sharing memory with the batch-first tensors.
    a_bl = to_batch_last_a(a_pt, a_row_major)
    b_bl = to_batch_last_b(b_pt, b_row_major)
    c_bl = to_batch_last_c(c_pt)
    d_bl = to_batch_last_c(d_pt)

    # Convert to CuTe tensors
    a_cute = make_cute_tensor(a_bl)
    b_cute = make_cute_tensor(b_bl)
    c_cute = make_cute_tensor(c_bl)
    d_cute = make_cute_tensor(d_bl)

    cu_stream = get_cu_stream()

    # JIT compile
    compiled = cute.compile(
        kernel,
        a_cute,
        b_cute,
        c_cute,
        d_cute,
        B,
        M,
        N,
        K,
        cu_stream,
    )

    # Execute
    compiled(
        a_cute,
        b_cute,
        c_cute,
        d_cute,
        B,
        M,
        N,
        K,
        cu_stream,
    )
    torch.cuda.synchronize()

    # Validate — c_pt (batch-first) shares memory with c_bl (batch-last),
    # so it already has the kernel output in batch-first layout.
    max_diff = (c_pt - y_ref).abs().max().item()
    passed = torch.allclose(c_pt, y_ref, atol=atol, rtol=rtol)

    status = "PASS" if passed else "FAIL"
    print(f"  [{status}] {test_id} (max_diff={max_diff:.6f})")
    return passed
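

# A minimal direct call, for reference (mirrors the "rrr" / "aligned" case
# that main() below drives; the shapes are just an example):
#
#   run_test("rrr", a_row_major=True, b_row_major=True, has_d=False,
#            B=2, M=256, N=512, K=128,
#            a_shape=(2, 256, 128), b_shape=(2, 128, 512),
#            ref_fn=torch.bmm)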


# =============================================================================
# Main
# =============================================================================


def main():
    print("=" * 70)
    print("BmmSm90Kernel Direct Test Harness")
    print("=" * 70)

    if not torch.cuda.is_available():
        print("ERROR: CUDA GPU required")
        sys.exit(1)

    cc_major, cc_minor = torch.cuda.get_device_capability(0)
    gpu_arch = cc_major * 10 + cc_minor
    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name} (SM{gpu_arch})")

    if gpu_arch < 90:
        print(f"ERROR: SM90+ required for Hopper TMA/WGMMA, got SM{gpu_arch}")
        sys.exit(1)

    configs = _make_configs()
    total = 0
    passed = 0
    failed_tests = []

    # Test plain BMM (has_d=False) for all layouts and shapes
    for (
        layout_name,
        a_row_major,
        b_row_major,
        a_shape_fn,
        b_shape_fn,
        ref_fn,
    ) in configs:
        print(f"\n--- bmm_{layout_name} (plain) ---")
        for shape_name, B, M, N, K in _SHAPES:
            a_shape = a_shape_fn(B, M, N, K)
            b_shape = b_shape_fn(B, M, N, K)
            total += 1
            ok = run_test(
                layout_name,
                a_row_major,
                b_row_major,
                has_d=False,
                B=B,
                M=M,
                N=N,
                K=K,
                a_shape=a_shape,
                b_shape=b_shape,
                ref_fn=ref_fn,
            )
            if ok:
                passed += 1
            else:
                failed_tests.append(
                    f"bmm_{layout_name} {shape_name} B={B} M={M} N={N} K={K}"
                )

    # Test BMM + residual add (has_d=True) for all layouts and shapes
    for (
        layout_name,
        a_row_major,
        b_row_major,
        a_shape_fn,
        b_shape_fn,
        ref_fn,
    ) in configs:
        print(f"\n--- bmm_{layout_name}_add (residual) ---")
        for shape_name, B, M, N, K in _SHAPES:
            a_shape = a_shape_fn(B, M, N, K)
            b_shape = b_shape_fn(B, M, N, K)
            total += 1
            ok = run_test(
                layout_name,
                a_row_major,
                b_row_major,
                has_d=True,
                B=B,
                M=M,
                N=N,
                K=K,
                a_shape=a_shape,
                b_shape=b_shape,
                ref_fn=ref_fn,
            )
            if ok:
                passed += 1
            else:
                failed_tests.append(
                    f"bmm_{layout_name}_add {shape_name} B={B} M={M} N={N} K={K}"
                )

    # Summary
    print("\n" + "=" * 70)
    print(f"Results: {passed}/{total} passed")
    if failed_tests:
        print("\nFailed tests:")
        for t in failed_tests:
            print(f"  - {t}")
    else:
        print("All tests passed!")
    print("=" * 70)

    sys.exit(0 if passed == total else 1)


if __name__ == "__main__":
    main()
