Skip to content

Commit 510b757

Browse files
zoranzhaometa-codesync[bot]
authored and committed
Add a pattern matching for b2b mm (#1049)
Summary: Pull Request resolved: #1049 as title, I am just trying to play around. Reviewed By: jijunyan Differential Revision: D94574598 fbshipit-source-id: 9ec423a0ef80cb8105c693117a907468c4c7f66b
1 parent 2c4136c commit 510b757

5 files changed

Lines changed: 575 additions & 6 deletions

File tree

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""
16+
AITemplate classic_b2b_bmm Pattern Matching Example
17+
====================================================
18+
Demonstrates the fuse_b2b_bmm compiler pass that automatically fuses a
19+
decomposed PyTorch-style attention pattern into classic_b2b_bmm.
20+
21+
Instead of directly using ops.classic_b2b_bmm, the AIT graph is built from
22+
individual ops (bmm_rcr, elementwise MUL/ADD/SIGMOID, bmm_rrr) that mirror
23+
the PyTorch implementation. The compiler's fuse_b2b_bmm pass then
24+
pattern-matches and replaces them with the fused classic_b2b_bmm kernel.
25+
26+
Pattern matched:
27+
score = bmm_rcr(Q, K) # Q @ K^T
28+
score = score * alpha0 # scale
29+
score = score + bias # add bias
30+
score = sigmoid(score) # activation
31+
score = score * alpha1 # scale (optional)
32+
output = bmm_rrr(score, V) # score @ V
33+
=>
34+
output = classic_b2b_bmm(Q, K, V, bias)
35+
36+
Requirements:
37+
- CUDA SM80+ (A100, H100, etc.)
38+
- N0, N1 <= 512 (sequence length limitation)
39+
40+
Run with:
41+
buck run fbcode//aitemplate/AITemplate/examples:classic_b2b_bmm_example
42+
"""
43+
44+
import logging
45+
46+
import torch
47+
from aitemplate.compiler import compile_model, ops
48+
from aitemplate.compiler.ops.common.epilogue import FuncEnum
49+
from aitemplate.frontend import Tensor
50+
from aitemplate.testing.detect_target import FBCUDA
51+
52+
53+
def _get_target(**kwargs):
    """Create an AIT CUDA target, auto-detecting the GPU architecture.

    Args:
        **kwargs: Forwarded verbatim to the ``FBCUDA`` target constructor
            (e.g. ``use_fp16_acc``).

    Returns:
        An ``FBCUDA`` target configured for the detected compute capability.

    Raises:
        RuntimeError: If the GPU's compute capability is below SM80,
            which classic_b2b_bmm requires.
    """
    cc_major, cc_minor = torch.cuda.get_device_capability(0)
    # Keep the capability as an int for the comparison; stringify once,
    # only for display and for the FBCUDA arch argument (avoids the
    # original str() -> int() round-trip).
    compute_capability = cc_major * 10 + cc_minor
    gpu_arch = str(compute_capability)

    if compute_capability < 80:
        raise RuntimeError(
            f"classic_b2b_bmm requires SM80+ (A100/H100). Current GPU: SM{gpu_arch}"
        )

    print(f"Detected GPU architecture: SM{gpu_arch}")
    return FBCUDA(arch=gpu_arch, **kwargs)
66+
67+
# =============================================================================
68+
# PyTorch Reference Model
69+
# =============================================================================
70+
71+
72+
class PTB2bBmm(torch.nn.Module):
    """Eager PyTorch reference for the fused b2b_bmm computation.

    No learnable parameters; the forward pass evaluates
    ``alpha1 * sigmoid(alpha0 * (Q @ K^T) + bias) @ V``.
    """

    def __init__(self, head_dim: int):
        super().__init__()
        # Standard attention scaling: 1 / sqrt(head_dim).
        self.alpha0 = 1.0 / (head_dim**0.5)
        self.alpha1 = 1.0

    def forward(self, q, k, v, bias):
        # First batched matmul: raw attention scores Q @ K^T.
        scores = torch.matmul(q, k.transpose(-2, -1))
        # Scale, shift by the additive bias, then squash through sigmoid.
        gated = torch.sigmoid(self.alpha0 * scores + bias)
        # Second (optional) scale, then the second batched matmul with V.
        return torch.matmul(self.alpha1 * gated, v)
88+
89+
90+
# =============================================================================
91+
# AITemplate Graph Builder (decomposed ops, NOT using ops.classic_b2b_bmm)
92+
# =============================================================================
93+
94+
95+
def build_decomposed_b2b_bmm_graph(batch, seq_len, head_dim, dtype="float16"):
    """Build the AIT graph from primitive ops mirroring the PyTorch model.

    Deliberately avoids ops.classic_b2b_bmm. The graph is the decomposed
    chain
        bmm_rcr -> MUL(alpha0) -> ADD(bias) -> SIGMOID -> MUL(alpha1) -> bmm_rrr
    which the fuse_b2b_bmm compiler pass recognizes and collapses into a
    single fused classic_b2b_bmm op.
    """
    alpha0 = 1.0 / (head_dim**0.5)
    alpha1 = 1.0

    # Q, K, V share the same [batch, seq_len, head_dim] input shape.
    Q, K, V = (
        Tensor(
            shape=[batch, seq_len, head_dim], dtype=dtype, name=nm, is_input=True
        )
        for nm in ("Q", "K", "V")
    )
    Bias = Tensor(
        shape=[batch, seq_len, seq_len], dtype=dtype, name="Bias", is_input=True
    )

    # Q @ K^T -- bmm_rcr reads its second operand column-major, i.e. K^T.
    attn = ops.bmm_rcr()(Q, K)
    # Scale by 1/sqrt(head_dim).
    attn = ops.elementwise(FuncEnum.MUL)(attn, alpha0)
    # Additive attention bias.
    attn = ops.elementwise(FuncEnum.ADD)(attn, Bias)
    # Sigmoid activation.
    attn = ops.elementwise(FuncEnum.SIGMOID)(attn)
    # Second (optional) scale.
    attn = ops.elementwise(FuncEnum.MUL)(attn, alpha1)
    # attn @ V -- bmm_rrr: both operands row-major.
    Y = ops.bmm_rrr()(attn, V)

    Y._attrs["is_output"] = True
    Y._attrs["name"] = "Y"

    return Y
137+
138+
139+
# =============================================================================
140+
# Test
141+
# =============================================================================
142+
143+
144+
def run_pattern_matching_example():
    """Test: Decomposed ops auto-fused into classic_b2b_bmm by compiler pass.

    Builds an AIT graph from primitive ops (bmm_rcr, elementwise
    MUL/ADD/SIGMOID, bmm_rrr) and verifies that the fuse_b2b_bmm pass fuses
    them into a single classic_b2b_bmm kernel, producing results matching
    PyTorch.
    """
    print("\n" + "=" * 60)
    print("Pattern Matching Test: decomposed ops -> classic_b2b_bmm")
    print("=" * 60)

    batch, seq_len, head_dim = 4, 128, 64
    dtype = "float16"

    # PyTorch reference model in eval mode, fp16 on the GPU.
    ref_model = PTB2bBmm(head_dim).cuda().half()
    ref_model.eval()

    def _rand(*shape):
        # Helper: random fp16 CUDA tensor of the given shape.
        return torch.randn(*shape, device="cuda", dtype=torch.float16)

    q_pt = _rand(batch, seq_len, head_dim)
    k_pt = _rand(batch, seq_len, head_dim)
    v_pt = _rand(batch, seq_len, head_dim)
    bias_pt = _rand(batch, seq_len, seq_len)
    y_pt = ref_model(q_pt, k_pt, v_pt, bias_pt)

    # Build the AIT graph from decomposed ops (NOT ops.classic_b2b_bmm).
    target = _get_target(use_fp16_acc=False)
    # DEBUG logging makes the fuse_b2b_bmm pass visible during compilation.
    logging.getLogger("aitemplate").setLevel(logging.DEBUG)

    with target:
        Y = build_decomposed_b2b_bmm_graph(batch, seq_len, head_dim, dtype)

    # Compile; the fuse_b2b_bmm pass fuses the decomposed graph.
    print("\nCompiling... (fuse_b2b_bmm pass will pattern-match and fuse)")
    with compile_model(Y, target, "./tmp", "pattern_matched_b2b_bmm") as module:
        y_ait = torch.empty_like(y_pt)
        module.run_with_tensors(
            {"Q": q_pt, "K": k_pt, "V": v_pt, "Bias": bias_pt},
            {"Y": y_ait},
        )

    # Verify the fused kernel matches the PyTorch reference.
    max_diff = (y_ait - y_pt).abs().max().item()
    close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
    assert close, f"Results mismatch! Max diff: {max_diff}"
    print(f"\nResults match PyTorch: {close} (max diff: {max_diff:.6f})")
189+
190+
191+
def main():
    """Entry point: print a banner, run the fusion example, report success."""
    banner = "=" * 60
    print(banner)
    print("AITemplate classic_b2b_bmm Pattern Matching Example")
    print(banner)
    print("\nDemonstrates automatic fusion of decomposed attention ops")
    print("into classic_b2b_bmm via the fuse_b2b_bmm compiler pass.")

    run_pattern_matching_example()

    print("\n" + banner)
    print("All tests passed!")
    print(banner)


if __name__ == "__main__":
    main()

python/aitemplate/backend/cuda/b2b_bmm/classic_b2b_bmm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def classic_b2b_bmm_gen_function(func_attrs: Dict[str, Any]) -> str:
326326
),
327327
alpha0=str(func_attrs["alpha0"]),
328328
alpha1=str(func_attrs["alpha1"]),
329-
alpha1_divide_by_seq_len=(
330-
"true" if func_attrs["alpha1_divide_by_seq_len"] else "false"
331-
),
329+
alpha1_divide_by_seq_len=func_attrs["alpha1_divide_by_seq_len"],
332330
epilogue_math=epilogue_math,
333331
bias_stride_n=bias_stride_n,
334332
bias_stride_mn=bias_stride_mn,

python/aitemplate/backend/cuda/b2b_bmm/grouped_classic_b2b_bmm.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,7 @@ def classic_b2b_bmm_gen_function(func_attrs: Dict[str, Any]) -> str:
326326
),
327327
alpha0=str(func_attrs["alpha0"]),
328328
alpha1=str(func_attrs["alpha1"]),
329-
alpha1_divide_by_seq_len=(
330-
"true" if func_attrs["alpha1_divide_by_seq_len"] else "false"
331-
),
329+
alpha1_divide_by_seq_len=func_attrs["alpha1_divide_by_seq_len"],
332330
epilogue_math=epilogue_math,
333331
bias_stride_n=bias_stride_n,
334332
bias_stride_mn=bias_stride_mn,

0 commit comments

Comments
 (0)