|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# |
| 15 | +""" |
| 16 | +AITemplate Toy Example |
| 17 | +====================== |
| 18 | +Two patterns: |
| 19 | + 1. Raw operator graph: elementwise tanh(X + 3) |
| 20 | + 2. nn.Module style: Linear + GeLU + residual + LayerNorm |
| 21 | +
| 22 | +Run with: |
| 23 | + buck run fbcode//aitemplate/AITemplate/examples:toy_example |
| 24 | +""" |
| 25 | + |
| 26 | +import logging |
| 27 | + |
| 28 | +import torch |
| 29 | +from aitemplate.compiler import compile_model, ops |
| 30 | +from aitemplate.frontend import nn, Tensor |
| 31 | +from aitemplate.testing.detect_target import FBCUDA |
| 32 | + |
| 33 | + |
def _get_target(**kwargs):
    """Build an AIT CUDA target for the GPU actually present on this host.

    On virtual hosts /etc/fbwhoami reports MODEL_NAME=VIRTUAL, which makes
    detect_target() fall back to SM80 even when H100 (SM90) GPUs are
    present, so the real SM version is queried via torch.cuda instead.
    """
    major, minor = torch.cuda.get_device_capability(0)
    # e.g. (9, 0) -> "90" for H100 / SM90.
    return FBCUDA(arch=f"{major * 10 + minor}", **kwargs)
| 44 | + |
| 45 | + |
class PTSimpleModel(torch.nn.Module):
    """PyTorch reference model: Linear -> GeLU -> Linear -> residual -> LayerNorm."""

    def __init__(self, hidden, eps: float = 1e-5):
        super().__init__()
        # Attribute names must stay aligned with AITSimpleModel so that
        # map_pt_params can match parameters by name.
        self.dense1 = torch.nn.Linear(hidden, 4 * hidden)
        self.act1 = torch.nn.functional.gelu
        self.dense2 = torch.nn.Linear(4 * hidden, hidden)
        self.layernorm = torch.nn.LayerNorm(hidden, eps=eps)

    def forward(self, input):
        out = self.dense2(self.act1(self.dense1(input)))
        # Residual connection, then normalize.
        return self.layernorm(out + input)
| 63 | + |
| 64 | + |
class AITSimpleModel(nn.Module):
    """AITemplate twin of PTSimpleModel.

    The first Linear uses the "fast_gelu" specialization, so GEMM + bias +
    GeLU run as one fused kernel and no separate activation op is needed.
    """

    def __init__(self, hidden, eps: float = 1e-5):
        super().__init__()
        self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu")
        self.dense2 = nn.Linear(4 * hidden, hidden)
        self.layernorm = nn.LayerNorm(hidden, eps=eps)

    def forward(self, input):
        residual = input
        out = self.dense2(self.dense1(input))
        # Residual connection, then normalize — mirrors PTSimpleModel.forward.
        return self.layernorm(out + residual)
| 80 | + |
| 81 | + |
def map_pt_params(ait_model, pt_model):
    """Map PyTorch parameters onto the AIT model's constant names.

    AIT flattens parameter names by replacing "." with "_" (e.g.
    "dense1.weight" -> "dense1_weight"); the returned dict maps those
    flattened names to the matching PyTorch tensors so they can be passed
    to compile_model(..., constants=...).

    Args:
        ait_model: AIT module; its parameters are named via
            name_parameter_tensor() before matching.
        pt_model: torch.nn.Module whose named_parameters() supply weights.

    Returns:
        dict mapping flattened AIT parameter name -> PyTorch tensor.

    Raises:
        KeyError: if an AIT parameter has no counterpart in pt_model.
    """
    ait_model.name_parameter_tensor()
    pt_params = dict(pt_model.named_parameters())
    mapped = {}
    for name, _ in ait_model.named_parameters():
        if name not in pt_params:
            # Explicit raise instead of a bare assert: asserts are stripped
            # under `python -O`, which would silently map a missing weight.
            raise KeyError(
                f"AIT parameter {name!r} has no match in the PyTorch model"
            )
        mapped[name.replace(".", "_")] = pt_params[name]
    return mapped
| 91 | + |
| 92 | + |
def run_elementwise_example():
    """Example 1: Raw elementwise ops — Y = tanh(X + 3)"""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 1: Elementwise ops (Y = tanh(X + 3))")
    print(banner)

    # 1. Build the AIT graph directly from operators.
    X = Tensor(shape=[1024, 256], name="X", dtype="float16", is_input=True)
    Y = ops.tanh(X + 3.0)
    Y._attrs["name"] = "Y"
    Y._attrs["is_output"] = True

    # 2. Compile with verbose AIT logging enabled.
    logging.getLogger("aitemplate").setLevel(logging.DEBUG)
    module = compile_model(Y, _get_target(), "./tmp", "toy_tanh_add")

    # 3. Run the compiled module on random fp16 input.
    x_pt = torch.randn(1024, 256).cuda().half()
    y_ait = torch.empty(1024, 256).cuda().half()
    module.run_with_tensors({"X": x_pt}, {"Y": y_ait})

    # 4. Compare with the eager PyTorch result.
    y_pt = torch.tanh(x_pt + 3.0)
    close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
    assert close, "Elementwise example: results do not match PyTorch!"
    print(f"Results match PyTorch: {close}")
| 121 | + |
def run_nn_module_example():
    """Example 2: nn.Module with weight mapping."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: nn.Module (Linear + GeLU + residual + LayerNorm)")
    print(banner)

    batch_size, hidden = 1024, 512

    # 1. Reference run in eager PyTorch.
    pt_model = PTSimpleModel(hidden).cuda().half()
    pt_model.eval()
    x = torch.randn(batch_size, hidden).cuda().half()
    y_pt = pt_model(x)

    # 2. Build the equivalent AIT graph.
    ait_model = AITSimpleModel(hidden)
    X = Tensor(
        shape=[batch_size, hidden],
        name="X",
        dtype="float16",
        is_input=True,
    )
    Y = ait_model(X)
    Y._attrs["name"] = "Y"
    Y._attrs["is_output"] = True

    # 3. Bake the PyTorch weights in as constants and compile.
    weights = map_pt_params(ait_model, pt_model)
    logging.getLogger("aitemplate").setLevel(logging.DEBUG)
    with compile_model(
        Y, _get_target(), "./tmp", "toy_simple_model", constants=weights
    ) as module:
        # 4. Run the compiled module.
        y_ait = torch.empty(batch_size, hidden).cuda().half()
        module.run_with_tensors({"X": x}, {"Y": y_ait})

        # 5. Verify against the eager result.
        close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
        assert close, "nn.Module example: results do not match PyTorch!"
        print(f"Results match PyTorch: {close}")
| 163 | + |
| 164 | + |
def main():
    """Run both toy examples; each asserts its own correctness."""
    for example in (run_elementwise_example, run_nn_module_example):
        example()
    print("\nAll examples passed!")
| 169 | + |
| 170 | + |
| 171 | +if __name__ == "__main__": |
| 172 | + main() |
0 commit comments