|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# |
| 15 | +""" |
| 16 | +Direct test harness for BmmSm90Kernel (cutedsl_bmm_sm90.py). |
| 17 | +
|
| 18 | +Directly invokes the CuTeDSL SM90 BMM kernel via cute.compile + execute, |
| 19 | +bypassing the full AITemplate compilation pipeline. Validates results |
| 20 | +against PyTorch reference across all layout/variant combinations. |
| 21 | +
|
| 22 | +The kernel expects batch-last tensor ordering: A(M,K,B), B(N,K,B), C(M,N,B). |
| 23 | +This test creates tensors in PyTorch batch-first format, computes the |
| 24 | +reference, then permutes to batch-last for the CuTe kernel. |
| 25 | +
|
| 26 | +Requires SM90+ (H100 / Hopper). |
| 27 | +
|
| 28 | +Run with: |
| 29 | + buck run fbcode//aitemplate/AITemplate/examples:test_cutedsl_bmm_sm90 |
| 30 | +""" |
| 31 | + |
| 32 | +import sys |
| 33 | + |
| 34 | +import cuda.bindings.driver as cuda |
| 35 | +import cutlass.cute as cute |
| 36 | +import torch |
| 37 | +from aitemplate.backend.cuda.gemm_universal.cutedsl_bmm_sm90 import BmmSm90Kernel |
| 38 | +from cutlass.cute.runtime import from_dlpack |
| 39 | + |
| 40 | + |
| 41 | +# ============================================================================= |
| 42 | +# Helpers |
| 43 | +# ============================================================================= |
| 44 | + |
| 45 | + |
| 46 | +def make_cute_tensor(t): |
| 47 | + """Convert a PyTorch CUDA tensor to a CuTe tensor with dynamic modes. |
| 48 | +
|
| 49 | + Marks the innermost (stride-1) mode as dynamic. For compact tensors, |
| 50 | + this single call makes all dependent strides dynamic as well. |
| 51 | +
|
| 52 | + Skipped when any dimension has size 1 (e.g. B=1 batch), because |
| 53 | + mark_compact_shape_dynamic cannot verify compact stride ordering |
| 54 | + for size-1 dimensions in permuted views. |
| 55 | + """ |
| 56 | + ct = from_dlpack(t, assumed_align=16) |
| 57 | + if all(s > 1 for s in t.shape): |
| 58 | + innermost_mode = t.dim_order()[0] |
| 59 | + ct = ct.mark_compact_shape_dynamic( |
| 60 | + mode=innermost_mode, |
| 61 | + stride_order=t.dim_order(), |
| 62 | + divisibility=1, |
| 63 | + ) |
| 64 | + return ct |
| 65 | + |
| 66 | + |
| 67 | +def to_batch_last_a(t, a_row_major): |
| 68 | + """Permute A from batch-first to batch-last (M, K, B). |
| 69 | +
|
| 70 | + Row-major A: (B, M, K) -> (M, K, B) via permute(1, 2, 0) |
| 71 | + Col-major A: (B, K, M) -> (M, K, B) via permute(2, 1, 0) |
| 72 | + """ |
| 73 | + return t.permute(1, 2, 0) if a_row_major else t.permute(2, 1, 0) |
| 74 | + |
| 75 | + |
| 76 | +def to_batch_last_b(t, b_row_major): |
| 77 | + """Permute B from batch-first to batch-last (N, K, B). |
| 78 | +
|
| 79 | + Row-major B: (B, K, N) -> (N, K, B) via permute(2, 1, 0) |
| 80 | + Col-major B: (B, N, K) -> (N, K, B) via permute(1, 2, 0) |
| 81 | + """ |
| 82 | + return t.permute(2, 1, 0) if b_row_major else t.permute(1, 2, 0) |
| 83 | + |
| 84 | + |
| 85 | +def to_batch_last_c(t): |
| 86 | + """Permute C/D from batch-first (B, M, N) to batch-last (M, N, B).""" |
| 87 | + return t.permute(1, 2, 0) |
| 88 | + |
| 89 | + |
| 90 | +def get_cu_stream(): |
| 91 | + """Get CUDA driver stream from current PyTorch stream.""" |
| 92 | + return cuda.CUstream(torch.cuda.current_stream().cuda_stream) |
| 93 | + |
| 94 | + |
| 95 | +# ============================================================================= |
| 96 | +# Layout configs: (name, a_row_major, b_row_major, A_shape_fn, B_shape_fn, ref_fn) |
| 97 | +# ============================================================================= |
| 98 | + |
| 99 | + |
| 100 | +def _make_configs(): |
| 101 | + """Build layout test configs. |
| 102 | +
|
| 103 | + Each config: (name, a_row_major, b_row_major, |
| 104 | + A_shape(B,M,N,K), B_shape(B,M,N,K), ref_fn(a,b)) |
| 105 | + """ |
| 106 | + return [ |
| 107 | + ( |
| 108 | + "rrr", |
| 109 | + True, |
| 110 | + True, |
| 111 | + lambda B, M, N, K: (B, M, K), |
| 112 | + lambda B, M, N, K: (B, K, N), |
| 113 | + lambda a, b: torch.bmm(a, b), |
| 114 | + ), |
| 115 | + ( |
| 116 | + "ccr", |
| 117 | + False, |
| 118 | + False, |
| 119 | + lambda B, M, N, K: (B, K, M), |
| 120 | + lambda B, M, N, K: (B, N, K), |
| 121 | + lambda a, b: torch.bmm(a.transpose(-2, -1), b.transpose(-2, -1)), |
| 122 | + ), |
| 123 | + ( |
| 124 | + "rcr", |
| 125 | + True, |
| 126 | + False, |
| 127 | + lambda B, M, N, K: (B, M, K), |
| 128 | + lambda B, M, N, K: (B, N, K), |
| 129 | + lambda a, b: torch.bmm(a, b.transpose(-2, -1)), |
| 130 | + ), |
| 131 | + ] |
| 132 | + |
| 133 | + |
| 134 | +# Shape configs: (name, B, M, N, K) |
| 135 | +_SHAPES = [ |
| 136 | + ("aligned", 2, 256, 512, 128), |
| 137 | + ("medium", 4, 512, 256, 256), |
| 138 | + ("large_batch", 16, 128, 128, 64), |
| 139 | + ("small", 1, 128, 128, 64), |
| 140 | +] |
| 141 | + |
| 142 | + |
| 143 | +# ============================================================================= |
| 144 | +# Core test runner |
| 145 | +# ============================================================================= |
| 146 | + |
| 147 | + |
| 148 | +def run_test( |
| 149 | + name, |
| 150 | + a_row_major, |
| 151 | + b_row_major, |
| 152 | + has_d, |
| 153 | + B, |
| 154 | + M, |
| 155 | + N, |
| 156 | + K, |
| 157 | + a_shape, |
| 158 | + b_shape, |
| 159 | + ref_fn, |
| 160 | + atol=1e-2, |
| 161 | + rtol=1e-2, |
| 162 | +): |
| 163 | + """Run a single BmmSm90Kernel test case.""" |
| 164 | + add_str = "_add" if has_d else "" |
| 165 | + test_id = f"bmm_{name}{add_str} B={B} M={M} N={N} K={K}" |
| 166 | + |
| 167 | + # Create kernel |
| 168 | + kernel = BmmSm90Kernel( |
| 169 | + tile_m=128, |
| 170 | + tile_n=128, |
| 171 | + a_row_major=a_row_major, |
| 172 | + b_row_major=b_row_major, |
| 173 | + has_d=has_d, |
| 174 | + ) |
| 175 | + |
| 176 | + # Create PyTorch tensors (batch-first, standard PyTorch convention) |
| 177 | + a_pt = torch.randn(*a_shape, device="cuda", dtype=torch.float16) |
| 178 | + b_pt = torch.randn(*b_shape, device="cuda", dtype=torch.float16) |
| 179 | + c_pt = torch.zeros(B, M, N, device="cuda", dtype=torch.float16) |
| 180 | + d_pt = ( |
| 181 | + torch.randn(B, M, N, device="cuda", dtype=torch.float16) |
| 182 | + if has_d |
| 183 | + else torch.zeros(B, M, N, device="cuda", dtype=torch.float16) |
| 184 | + ) |
| 185 | + |
| 186 | + # PyTorch reference (batch-first) |
| 187 | + y_ref = ref_fn(a_pt, b_pt) |
| 188 | + if has_d: |
| 189 | + y_ref = y_ref + d_pt |
| 190 | + |
| 191 | + # Permute to batch-last for the kernel: A(M,K,B), B(N,K,B), C/D(M,N,B). |
| 192 | + # These are views sharing memory with the batch-first tensors. |
| 193 | + a_bl = to_batch_last_a(a_pt, a_row_major) |
| 194 | + b_bl = to_batch_last_b(b_pt, b_row_major) |
| 195 | + c_bl = to_batch_last_c(c_pt) |
| 196 | + d_bl = to_batch_last_c(d_pt) |
| 197 | + |
| 198 | + # Convert to CuTe tensors |
| 199 | + a_cute = make_cute_tensor(a_bl) |
| 200 | + b_cute = make_cute_tensor(b_bl) |
| 201 | + c_cute = make_cute_tensor(c_bl) |
| 202 | + d_cute = make_cute_tensor(d_bl) |
| 203 | + |
| 204 | + cu_stream = get_cu_stream() |
| 205 | + |
| 206 | + # JIT compile |
| 207 | + compiled = cute.compile( |
| 208 | + kernel, |
| 209 | + a_cute, |
| 210 | + b_cute, |
| 211 | + c_cute, |
| 212 | + d_cute, |
| 213 | + B, |
| 214 | + M, |
| 215 | + N, |
| 216 | + K, |
| 217 | + cu_stream, |
| 218 | + ) |
| 219 | + |
| 220 | + # Execute |
| 221 | + compiled( |
| 222 | + a_cute, |
| 223 | + b_cute, |
| 224 | + c_cute, |
| 225 | + d_cute, |
| 226 | + B, |
| 227 | + M, |
| 228 | + N, |
| 229 | + K, |
| 230 | + cu_stream, |
| 231 | + ) |
| 232 | + torch.cuda.synchronize() |
| 233 | + |
| 234 | + # Validate — c_pt (batch-first) shares memory with c_bl (batch-last), |
| 235 | + # so it already has the kernel output in batch-first layout. |
| 236 | + max_diff = (c_pt - y_ref).abs().max().item() |
| 237 | + passed = torch.allclose(c_pt, y_ref, atol=atol, rtol=rtol) |
| 238 | + |
| 239 | + status = "PASS" if passed else "FAIL" |
| 240 | + print(f" [{status}] {test_id} (max_diff={max_diff:.6f})") |
| 241 | + return passed |
| 242 | + |
| 243 | + |
| 244 | +# ============================================================================= |
| 245 | +# Main |
| 246 | +# ============================================================================= |
| 247 | + |
| 248 | + |
| 249 | +def main(): |
| 250 | + print("=" * 70) |
| 251 | + print("BmmSm90Kernel Direct Test Harness") |
| 252 | + print("=" * 70) |
| 253 | + |
| 254 | + if not torch.cuda.is_available(): |
| 255 | + print("ERROR: CUDA GPU required") |
| 256 | + sys.exit(1) |
| 257 | + |
| 258 | + cc_major, cc_minor = torch.cuda.get_device_capability(0) |
| 259 | + gpu_arch = cc_major * 10 + cc_minor |
| 260 | + gpu_name = torch.cuda.get_device_name(0) |
| 261 | + print(f"GPU: {gpu_name} (SM{gpu_arch})") |
| 262 | + |
| 263 | + if gpu_arch < 90: |
| 264 | + print(f"ERROR: SM90+ required for Hopper TMA/WGMMA, got SM{gpu_arch}") |
| 265 | + sys.exit(1) |
| 266 | + |
| 267 | + configs = _make_configs() |
| 268 | + total = 0 |
| 269 | + passed = 0 |
| 270 | + failed_tests = [] |
| 271 | + |
| 272 | + # Test plain BMM (has_d=False) for all layouts and shapes |
| 273 | + for ( |
| 274 | + layout_name, |
| 275 | + a_row_major, |
| 276 | + b_row_major, |
| 277 | + a_shape_fn, |
| 278 | + b_shape_fn, |
| 279 | + ref_fn, |
| 280 | + ) in configs: |
| 281 | + print(f"\n--- bmm_{layout_name} (plain) ---") |
| 282 | + for shape_name, B, M, N, K in _SHAPES: |
| 283 | + a_shape = a_shape_fn(B, M, N, K) |
| 284 | + b_shape = b_shape_fn(B, M, N, K) |
| 285 | + total += 1 |
| 286 | + ok = run_test( |
| 287 | + layout_name, |
| 288 | + a_row_major, |
| 289 | + b_row_major, |
| 290 | + has_d=False, |
| 291 | + B=B, |
| 292 | + M=M, |
| 293 | + N=N, |
| 294 | + K=K, |
| 295 | + a_shape=a_shape, |
| 296 | + b_shape=b_shape, |
| 297 | + ref_fn=ref_fn, |
| 298 | + ) |
| 299 | + if ok: |
| 300 | + passed += 1 |
| 301 | + else: |
| 302 | + failed_tests.append( |
| 303 | + f"bmm_{layout_name} {shape_name} B={B} M={M} N={N} K={K}" |
| 304 | + ) |
| 305 | + |
| 306 | + # Test BMM + residual add (has_d=True) for all layouts and shapes |
| 307 | + for ( |
| 308 | + layout_name, |
| 309 | + a_row_major, |
| 310 | + b_row_major, |
| 311 | + a_shape_fn, |
| 312 | + b_shape_fn, |
| 313 | + ref_fn, |
| 314 | + ) in configs: |
| 315 | + print(f"\n--- bmm_{layout_name}_add (residual) ---") |
| 316 | + for shape_name, B, M, N, K in _SHAPES: |
| 317 | + a_shape = a_shape_fn(B, M, N, K) |
| 318 | + b_shape = b_shape_fn(B, M, N, K) |
| 319 | + total += 1 |
| 320 | + ok = run_test( |
| 321 | + layout_name, |
| 322 | + a_row_major, |
| 323 | + b_row_major, |
| 324 | + has_d=True, |
| 325 | + B=B, |
| 326 | + M=M, |
| 327 | + N=N, |
| 328 | + K=K, |
| 329 | + a_shape=a_shape, |
| 330 | + b_shape=b_shape, |
| 331 | + ref_fn=ref_fn, |
| 332 | + ) |
| 333 | + if ok: |
| 334 | + passed += 1 |
| 335 | + else: |
| 336 | + failed_tests.append( |
| 337 | + f"bmm_{layout_name}_add {shape_name} B={B} M={M} N={N} K={K}" |
| 338 | + ) |
| 339 | + |
| 340 | + # Summary |
| 341 | + print("\n" + "=" * 70) |
| 342 | + print(f"Results: {passed}/{total} passed") |
| 343 | + if failed_tests: |
| 344 | + print(f"\nFailed tests:") |
| 345 | + for t in failed_tests: |
| 346 | + print(f" - {t}") |
| 347 | + else: |
| 348 | + print("All tests passed!") |
| 349 | + print("=" * 70) |
| 350 | + |
| 351 | + sys.exit(0 if passed == total else 1) |
| 352 | + |
| 353 | + |
| 354 | +if __name__ == "__main__": |
| 355 | + main() |
0 commit comments