Skip to content

Commit 2c4136c

Browse files
zoranzhaometa-codesync[bot]
authored and committed
Make toy example runnable on local H100/A100 dev server (#1048)
Summary: Pull Request resolved: #1048 as title, you never know whether the diff is written by me, or by claude code ;-( Reviewed By: jijunyan Differential Revision: D94554047 fbshipit-source-id: 1935dab76b3faf3a12fc9914ad43aaebea265006
1 parent 2676ddf commit 2c4136c

2 files changed

Lines changed: 188 additions & 0 deletions

File tree

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""
16+
AITemplate Toy Example
17+
======================
18+
Two patterns:
19+
1. Raw operator graph: elementwise tanh(X + 3)
20+
2. nn.Module style: Linear + GeLU + residual + LayerNorm
21+
22+
Run with:
23+
buck run fbcode//aitemplate/AITemplate/examples:toy_example
24+
"""
25+
26+
import logging
27+
28+
import torch
29+
from aitemplate.compiler import compile_model, ops
30+
from aitemplate.frontend import nn, Tensor
31+
from aitemplate.testing.detect_target import FBCUDA
32+
33+
34+
def _get_target(**kwargs):
    """Build an FBCUDA target for the GPU that is actually installed.

    On virtual hosts /etc/fbwhoami reports MODEL_NAME=VIRTUAL, which makes
    detect_target() fall back to SM80 even when H100 (SM90) GPUs are present.
    Query torch.cuda for the real compute capability instead and pass it
    through explicitly.
    """
    major, minor = torch.cuda.get_device_capability(0)
    return FBCUDA(arch=str(10 * major + minor), **kwargs)
44+
45+
46+
class PTSimpleModel(torch.nn.Module):
    """PyTorch reference model: Linear -> GeLU -> Linear -> residual -> LayerNorm."""

    def __init__(self, hidden, eps: float = 1e-5):
        super().__init__()
        self.dense1 = torch.nn.Linear(hidden, 4 * hidden)
        self.act1 = torch.nn.functional.gelu
        self.dense2 = torch.nn.Linear(4 * hidden, hidden)
        self.layernorm = torch.nn.LayerNorm(hidden, eps=eps)

    def forward(self, input):
        # Feed-forward block with a residual connection, normalized last.
        out = self.act1(self.dense1(input))
        out = self.dense2(out)
        return self.layernorm(out + input)
63+
64+
65+
class AITSimpleModel(nn.Module):
    """AITemplate equivalent of PTSimpleModel.

    The first Linear uses the "fast_gelu" specialization, which fuses the
    GEMM, bias add, and GeLU activation into a single kernel.
    """

    def __init__(self, hidden, eps: float = 1e-5):
        super().__init__()
        self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu")
        self.dense2 = nn.Linear(4 * hidden, hidden)
        self.layernorm = nn.LayerNorm(hidden, eps=eps)

    def forward(self, input):
        # dense1 already applies GeLU via its specialization.
        out = self.dense2(self.dense1(input))
        return self.layernorm(out + input)
80+
81+
82+
def map_pt_params(ait_model, pt_model):
    """Map PyTorch parameters onto AITemplate constant names.

    AIT mangles parameter names by replacing "." with "_"; build a dict
    keyed by the mangled AIT name whose values are the matching PyTorch
    parameter tensors, suitable for `compile_model(..., constants=...)`.

    Args:
        ait_model: AIT module; its parameters are named in place.
        pt_model: torch.nn.Module providing the weights.

    Returns:
        dict mapping mangled AIT parameter name -> PyTorch parameter.

    Raises:
        KeyError: if an AIT parameter has no PyTorch counterpart.
    """
    ait_model.name_parameter_tensor()
    pt_params = dict(pt_model.named_parameters())
    mapped = {}
    for name, _ in ait_model.named_parameters():
        if name not in pt_params:
            # Explicit raise instead of `assert`: asserts are stripped
            # under `python -O`, which would silently skip validation.
            raise KeyError(f"AIT parameter {name!r} not found in PyTorch model")
        mapped[name.replace(".", "_")] = pt_params[name]
    return mapped
91+
92+
93+
def run_elementwise_example():
    """Example 1: Raw elementwise ops — Y = tanh(X + 3)"""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 1: Elementwise ops (Y = tanh(X + 3))")
    print(banner)

    # Build the AIT graph: a single fused elementwise expression.
    X = Tensor(shape=[1024, 256], name="X", dtype="float16", is_input=True)
    Y = ops.tanh(X + 3.0)
    Y._attrs["is_output"] = True
    Y._attrs["name"] = "Y"

    # Compile for the local GPU with verbose AIT logging.
    target = _get_target()
    logging.getLogger("aitemplate").setLevel(logging.DEBUG)
    module = compile_model(Y, target, "./tmp", "toy_tanh_add")

    # Run inference on random fp16 data.
    x_pt = torch.randn(1024, 256).cuda().half()
    y_ait = torch.empty(1024, 256).cuda().half()
    module.run_with_tensors({"X": x_pt}, {"Y": y_ait})

    # Compare against eager PyTorch.
    y_pt = torch.tanh(x_pt + 3.0)
    close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
    assert close, "Elementwise example: results do not match PyTorch!"
    print(f"Results match PyTorch: {close}")
120+
121+
122+
def run_nn_module_example():
    """Example 2: nn.Module with weight mapping."""
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: nn.Module (Linear + GeLU + residual + LayerNorm)")
    print(banner)

    batch_size, hidden = 1024, 512

    # Reference run in eager PyTorch.
    pt_model = PTSimpleModel(hidden).cuda().half()
    pt_model.eval()
    x = torch.randn(batch_size, hidden).cuda().half()
    y_pt = pt_model(x)

    # Build the equivalent AIT graph.
    ait_model = AITSimpleModel(hidden)
    X = Tensor(
        shape=[batch_size, hidden],
        name="X",
        dtype="float16",
        is_input=True,
    )
    Y = ait_model(X)
    Y._attrs["is_output"] = True
    Y._attrs["name"] = "Y"

    # Fold the PyTorch weights in as constants and compile.
    weights = map_pt_params(ait_model, pt_model)
    target = _get_target()
    logging.getLogger("aitemplate").setLevel(logging.DEBUG)
    with compile_model(
        Y, target, "./tmp", "toy_simple_model", constants=weights
    ) as module:
        # Run inference through the compiled module.
        y_ait = torch.empty(batch_size, hidden).cuda().half()
        module.run_with_tensors({"X": x}, {"Y": y_ait})

        # Verify against the eager result.
        close = torch.allclose(y_ait, y_pt, atol=1e-2, rtol=1e-2)
        assert close, "nn.Module example: results do not match PyTorch!"
        print(f"Results match PyTorch: {close}")
163+
164+
165+
def main():
    """Run both toy examples end to end."""
    for example in (run_elementwise_example, run_nn_module_example):
        example()
    print("\nAll examples passed!")
169+
170+
171+
if __name__ == "__main__":
172+
main()

python/aitemplate/backend/cuda/target_def.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,22 @@ def __init__(self, arch="80", remote_cache_bytes=None, **kwargs):
325325
cutlass_lib_path = parutil.get_dir_path(
326326
"aitemplate/AITemplate/python/aitemplate/utils/mk_cutlass_lib"
327327
)
328+
# Ensure the cutlass_lib resource directory has __init__.py.
329+
# Buck may strip it when packaging genrule output as a resource.
330+
cutlass_lib_init = os.path.join(cutlass_lib_path, "cutlass_lib", "__init__.py")
331+
if not os.path.exists(cutlass_lib_init) and os.path.isdir(
332+
os.path.join(cutlass_lib_path, "cutlass_lib")
333+
):
334+
with open(cutlass_lib_init, "w") as f:
335+
f.write(
336+
"from . import library\n"
337+
"from . import generator\n"
338+
"from . import manifest\n"
339+
"from . import conv3d_operation\n"
340+
"from . import gemm_operation\n"
341+
"from . import conv2d_operation\n"
342+
"from . import extra_operation\n"
343+
)
328344
sys.path.append(cutlass_lib_path)
329345

330346
if not FBCUDA.nvcc_option_json:

0 commit comments

Comments
 (0)