# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
| 15 | +""" |
| 16 | +AITemplate classic_b2b_bmm Pattern Matching Example |
| 17 | +==================================================== |
| 18 | +Demonstrates the fuse_b2b_bmm compiler pass that automatically fuses a |
| 19 | +decomposed PyTorch-style attention pattern into classic_b2b_bmm. |
| 20 | +
|
| 21 | +Instead of directly using ops.classic_b2b_bmm, the AIT graph is built from |
| 22 | +individual ops (bmm_rcr, elementwise MUL/ADD/SIGMOID, bmm_rrr) that mirror |
| 23 | +the PyTorch implementation. The compiler's fuse_b2b_bmm pass then |
| 24 | +pattern-matches and replaces them with the fused classic_b2b_bmm kernel. |
| 25 | +
|
| 26 | +Pattern matched: |
| 27 | + score = bmm_rcr(Q, K) # Q @ K^T |
| 28 | + score = score * alpha0 # scale |
| 29 | + score = score + bias # add bias |
| 30 | + score = sigmoid(score) # activation |
| 31 | + score = score * alpha1 # scale (optional) |
| 32 | + output = bmm_rrr(score, V) # score @ V |
| 33 | + => |
| 34 | + output = classic_b2b_bmm(Q, K, V, bias) |
| 35 | +
|
| 36 | +Requirements: |
| 37 | + - CUDA SM80+ (A100, H100, etc.) |
| 38 | + - N0, N1 <= 512 (sequence length limitation) |
| 39 | +
|
| 40 | +Run with: |
| 41 | + buck run fbcode//aitemplate/AITemplate/examples:classic_b2b_bmm_example |
| 42 | +""" |

import logging

import torch
from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
from aitemplate.testing.detect_target import FBCUDA

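# For contrast, the fused op can also be instantiated directly instead of being
# recovered by the fuse_b2b_bmm pass. The sketch below is illustrative only:
# the kwargs follow the AIT b2b_bmm op family from memory and should be checked
# against aitemplate/compiler/ops/b2b_bmm before use.
#
#   Y = ops.classic_b2b_bmm(
#       causal_type=CausalType.NO_CAUSAL,  # assumed enum from the b2b_bmm ops
#       epilogue_math_name="sigmoid",      # assumed epilogue spelling
#       alpha0=1.0 / head_dim**0.5,
#       alpha1=1.0,
#   )(Q, K, V, Bias)
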
def _get_target(**kwargs):
    """Create an AIT CUDA target, auto-detecting the GPU architecture."""
    cc_major, cc_minor = torch.cuda.get_device_capability(0)
    sm = cc_major * 10 + cc_minor

    if sm < 80:
        raise RuntimeError(
            f"classic_b2b_bmm requires SM80+ (A100/H100). Current GPU: SM{sm}"
        )

    print(f"Detected GPU architecture: SM{sm}")
    return FBCUDA(arch=str(sm), **kwargs)

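# Note: in open-source AITemplate, aitemplate.testing.detect_target.detect_target()
# provides similar architecture auto-detection; FBCUDA is the fbcode-internal
# CUDA target class used by this example.
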
# =============================================================================
# PyTorch Reference Model
# =============================================================================


class PTB2bBmm(torch.nn.Module):
    """PyTorch reference for the b2b_bmm computation (no learnable params).

    Computes: output = alpha1 * sigmoid(alpha0 * (Q @ K^T) + bias) @ V
    """

    def __init__(self, head_dim: int):
        super().__init__()
        self.alpha0 = 1.0 / (head_dim**0.5)  # e.g. head_dim=64 -> alpha0=0.125
        self.alpha1 = 1.0

    def forward(self, q, k, v, bias):
        # q, k, v: [batch, seq_len, head_dim]; bias: [batch, seq_len, seq_len]
        attn = self.alpha0 * (q @ k.transpose(-2, -1)) + bias
        attn = torch.sigmoid(attn)
        attn = self.alpha1 * attn
        return attn @ v


# =============================================================================
# AITemplate Graph Builder (decomposed ops, NOT using ops.classic_b2b_bmm)
# =============================================================================


def build_decomposed_b2b_bmm_graph(batch, seq_len, head_dim, dtype="float16"):
    """Build the AIT graph from decomposed ops that mirror the PyTorch model.

    This does NOT use ops.classic_b2b_bmm directly. Instead, it builds the
    equivalent graph from primitive ops:
        bmm_rcr -> MUL(alpha0) -> ADD(bias) -> SIGMOID -> MUL(alpha1) -> bmm_rrr

    The fuse_b2b_bmm compiler pass will pattern-match this subgraph and
    replace it with a fused classic_b2b_bmm op.
    """
    alpha0 = 1.0 / (head_dim**0.5)
    alpha1 = 1.0

    Q = Tensor(shape=[batch, seq_len, head_dim], dtype=dtype, name="Q", is_input=True)
    K = Tensor(shape=[batch, seq_len, head_dim], dtype=dtype, name="K", is_input=True)
    V = Tensor(shape=[batch, seq_len, head_dim], dtype=dtype, name="V", is_input=True)
    Bias = Tensor(
        shape=[batch, seq_len, seq_len], dtype=dtype, name="Bias", is_input=True
    )

    # Step 1: score = Q @ K^T (bmm_rcr treats the second operand as
    # column-major, which implements the transpose)
    score = ops.bmm_rcr()(Q, K)

    # Step 2: score = score * alpha0 (scale by 1/sqrt(head_dim))
    score = ops.elementwise(FuncEnum.MUL)(score, alpha0)

    # Step 3: score = score + bias
    score = ops.elementwise(FuncEnum.ADD)(score, Bias)

    # Step 4: score = sigmoid(score)
    score = ops.elementwise(FuncEnum.SIGMOID)(score)

    # Step 5: score = score * alpha1
    score = ops.elementwise(FuncEnum.MUL)(score, alpha1)

    # Step 6: output = score @ V (bmm_rrr: both operands row-major)
    Y = ops.bmm_rrr()(score, V)

    Y._attrs["is_output"] = True
    Y._attrs["name"] = "Y"

    return Y


# =============================================================================
# Test
# =============================================================================


def run_pattern_matching_example():
    """Test: decomposed ops are auto-fused into classic_b2b_bmm by the compiler.

    Builds an AIT graph from primitive ops (bmm_rcr, elementwise MUL/ADD/SIGMOID,
    bmm_rrr) and verifies that the fuse_b2b_bmm pass fuses them into a single
    classic_b2b_bmm kernel, producing results that match PyTorch.
    """
    print("\n" + "=" * 60)
    print("Pattern Matching Test: decomposed ops -> classic_b2b_bmm")
    print("=" * 60)

    batch, seq_len, head_dim = 4, 128, 64
    dtype = "float16"

    # Create and run the PyTorch reference
    pt_model = PTB2bBmm(head_dim).cuda().half()
    pt_model.eval()

    q_pt = torch.randn(batch, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k_pt = torch.randn(batch, seq_len, head_dim, device="cuda", dtype=torch.float16)
    v_pt = torch.randn(batch, seq_len, head_dim, device="cuda", dtype=torch.float16)
    bias_pt = torch.randn(batch, seq_len, seq_len, device="cuda", dtype=torch.float16)
    y_pt = pt_model(q_pt, k_pt, v_pt, bias_pt)

    # Build the AIT graph from decomposed ops (NOT ops.classic_b2b_bmm)
    target = _get_target(use_fp16_acc=False)
    logging.getLogger("aitemplate").setLevel(logging.DEBUG)

    with target:
        Y = build_decomposed_b2b_bmm_graph(batch, seq_len, head_dim, dtype)

    # Compile; the fuse_b2b_bmm pass fuses the decomposed graph during this step.
    print("\nCompiling... (fuse_b2b_bmm pass will pattern-match and fuse)")
    module = compile_model(Y, target, "./tmp", "pattern_matched_b2b_bmm")
    y_ait = torch.empty_like(y_pt)
    module.run_with_tensors(
        {"Q": q_pt, "K": k_pt, "V": v_pt, "Bias": bias_pt},
        {"Y": y_ait},
    )

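    # Optional: confirm the fusion actually fired. Best-effort check that
    # assumes the compiled Model exposes ``debug_sorted_graph`` (the final
    # optimized graph), as AITemplate's fusion unit tests do.
    final_graph = getattr(module, "debug_sorted_graph", None) or []
    fused = any(
        op._attrs["op"] == "classic_b2b_bmm"
        for tensor in final_graph
        for op in tensor.src_ops()
    )
    print(f"classic_b2b_bmm present in final graph: {fused}")
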
    # Verify correctness against the PyTorch reference
    close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
    max_diff = (y_ait - y_pt).abs().max().item()
    assert close, f"Results mismatch! Max diff: {max_diff}"
    print(f"\nResults match PyTorch: {close} (max diff: {max_diff:.6f})")
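
    # Optional: rough timing of the fused kernel. benchmark_with_tensors is
    # part of the AIT Model runtime; the (mean_ms, std, outputs) return layout
    # is assumed from AITemplate's bundled examples.
    mean_ms, _, _ = module.benchmark_with_tensors(
        {"Q": q_pt, "K": k_pt, "V": v_pt, "Bias": bias_pt},
        {"Y": y_ait},
        count=100,
    )
    print(f"Fused kernel: {mean_ms:.3f} ms/iter")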


def main():
    print("=" * 60)
    print("AITemplate classic_b2b_bmm Pattern Matching Example")
    print("=" * 60)
    print("\nDemonstrates automatic fusion of decomposed attention ops")
    print("into classic_b2b_bmm via the fuse_b2b_bmm compiler pass.")

    run_pattern_matching_example()

    print("\n" + "=" * 60)
    print("All tests passed!")
    print("=" * 60)


if __name__ == "__main__":
    main()