Skip to content

Commit c8aad37

Browse files
zoranzhaometa-codesync[bot]
authored andcommitted
add gemm_rcr CuTeDSL backend (#1060)
Summary: Pull Request resolved: #1060 Reviewed By: jijunyan Differential Revision: D96823319 fbshipit-source-id: c309eab7bf42db00128c1d313ca627ab684ac3e2
1 parent 4eb69d9 commit c8aad37

6 files changed

Lines changed: 1471 additions & 0 deletions

File tree

examples/test_cutedsl_gemm_rcr.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""
16+
CuTeDSL gemm_rcr (no bias) via AIT compile_model test
17+
======================================================
18+
19+
Validates the CuTeDSL backend for gemm_rcr by compiling an AIT graph
20+
and checking numerical correctness against PyTorch.
21+
22+
Operation: Y[M, N] = X[M, K] @ W[N, K]^T
23+
Equivalent: torch.nn.functional.linear(X, W, bias=None)
24+
25+
Run with:
26+
buck run fbcode//aitemplate/AITemplate/examples:test_cutedsl_gemm_rcr
27+
buck run fbcode//aitemplate/AITemplate/examples:test_cutedsl_gemm_rcr -- --both
28+
"""
29+
30+
import argparse
31+
import logging
32+
33+
import torch
34+
from aitemplate.compiler import compile_model, ops
35+
from aitemplate.frontend import Tensor
36+
from aitemplate.testing.detect_target import FBCUDA
37+
38+
39+
def _get_target(**kwargs):
40+
"""Create AIT CUDA target, auto-detecting GPU architecture."""
41+
cc_major, cc_minor = torch.cuda.get_device_capability(0)
42+
gpu_arch = str(cc_major * 10 + cc_minor)
43+
44+
if int(gpu_arch) < 80:
45+
raise RuntimeError(
46+
f"gemm_rcr CuTeDSL requires SM80+ (A100/H100). Current GPU: SM{gpu_arch}"
47+
)
48+
49+
print(f" Detected GPU architecture: SM{gpu_arch}")
50+
return FBCUDA(arch=gpu_arch, **kwargs)
51+
52+
53+
def build_gemm_rcr_graph(M, N, K, dtype="float16"):
54+
"""Build AIT graph for gemm_rcr: Y[M,N] = X[M,K] @ W[N,K]^T."""
55+
X = Tensor(shape=[M, K], dtype=dtype, name="X", is_input=True)
56+
W = Tensor(shape=[N, K], dtype=dtype, name="W", is_input=True)
57+
58+
Y = ops.gemm_rcr()(X, W)
59+
60+
Y._attrs["is_output"] = True
61+
Y._attrs["name"] = "Y"
62+
63+
return Y
64+
65+
66+
def run_test(M, N, K, use_cutedsl=False):
67+
"""Compile and run gemm_rcr through AIT compile_model."""
68+
backend_name = "CuTeDSL" if use_cutedsl else "CUTLASS C++"
69+
print(f"\n --- gemm_rcr ({backend_name}) M={M}, N={N}, K={K} ---")
70+
71+
# PyTorch reference
72+
x_pt = torch.randn(M, K, device="cuda", dtype=torch.float16)
73+
w_pt = torch.randn(N, K, device="cuda", dtype=torch.float16)
74+
y_pt = torch.nn.functional.linear(x_pt, w_pt, bias=None)
75+
76+
# Build AIT graph
77+
target = _get_target(use_fp16_acc=False, use_cutedsl_gemm=use_cutedsl)
78+
logging.getLogger("aitemplate").setLevel(logging.DEBUG)
79+
80+
with target:
81+
Y = build_gemm_rcr_graph(M, N, K)
82+
83+
# Compile and run
84+
workdir_suffix = "cutedsl" if use_cutedsl else "cutlass"
85+
print(f" Compiling with {backend_name} backend...")
86+
with compile_model(
87+
Y, target, "./tmp", f"gemm_rcr_{workdir_suffix}_{M}_{N}_{K}"
88+
) as module:
89+
y_ait = torch.empty_like(y_pt)
90+
module.run_with_tensors(
91+
{"X": x_pt, "W": w_pt},
92+
{"Y": y_ait},
93+
)
94+
95+
# Validate
96+
close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
97+
max_diff = (y_ait - y_pt).abs().max().item()
98+
assert close, f"Results mismatch! Max diff: {max_diff}"
99+
print(f" Results match PyTorch: max diff = {max_diff:.6f}")
100+
101+
return True
102+
103+
104+
def main():
105+
parser = argparse.ArgumentParser(
106+
description="CuTeDSL gemm_rcr (no bias) via AIT compile_model test"
107+
)
108+
parser.add_argument(
109+
"--use-cutedsl",
110+
action="store_true",
111+
default=False,
112+
help="Use CuTeDSL backend instead of CUTLASS C++ templates",
113+
)
114+
parser.add_argument(
115+
"--both",
116+
action="store_true",
117+
default=False,
118+
help="Run with both CUTLASS C++ and CuTeDSL backends",
119+
)
120+
args = parser.parse_args()
121+
122+
print("=" * 60)
123+
print("CuTeDSL gemm_rcr (no bias) Test")
124+
print("=" * 60)
125+
print("Operation: Y[M,N] = X[M,K] @ W[N,K]^T")
126+
127+
test_shapes = [
128+
(256, 512, 128),
129+
(128, 256, 64),
130+
(1, 1024, 512),
131+
]
132+
133+
for M, N, K in test_shapes:
134+
if args.both:
135+
run_test(M, N, K, use_cutedsl=False)
136+
run_test(M, N, K, use_cutedsl=True)
137+
else:
138+
run_test(M, N, K, use_cutedsl=args.use_cutedsl or True)
139+
140+
print("\n" + "=" * 60)
141+
print("All tests passed!")
142+
print("=" * 60)
143+
144+
145+
if __name__ == "__main__":
146+
main()

python/aitemplate/backend/cuda/gemm_universal/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
gemm_rcr_bias_sigmoid,
3232
gemm_rcr_bias_swish,
3333
gemm_rcr_bias_tanh,
34+
gemm_rcr_cutedsl,
3435
gemm_rcr_fast_gelu,
3536
gemm_rcr_permute,
3637
gemm_rcr_permute_elup1,

0 commit comments

Comments
 (0)