Skip to content

Commit a2543dc

Browse files
Add compile debug example script
1 parent 779bd75 commit a2543dc

1 file changed

Lines changed: 61 additions & 0 deletions

File tree

examples/compile_debug.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import logging
2+
3+
import torch
4+
import torch._dynamo
5+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
6+
7+
torch._logging.set_logs(
8+
dynamo=logging.INFO,
9+
graph_breaks=True,
10+
recompiles=True,
11+
recompiles_verbose=True,
12+
compiled_autograd_verbose=True,
13+
)
14+
15+
torch._dynamo.config.suppress_errors = False
16+
17+
18+
torch.set_float32_matmul_precision("high")
19+
20+
quantization_config = BitsAndBytesConfig(
21+
load_in_4bit=True,
22+
bnb_4bit_compute_dtype=torch.bfloat16,
23+
bnb_4bit_quant_type="nf4",
24+
bnb_4bit_use_double_quant=True,
25+
)
26+
27+
# torch._dynamo.config.capture_dynamic_output_shape_ops = True
28+
29+
# model_id = "google/gemma-2-2b-it"
30+
model_id = "Qwen/Qwen2.5-7B"
31+
32+
tokenizer = AutoTokenizer.from_pretrained(model_id)
33+
model = AutoModelForCausalLM.from_pretrained(
34+
model_id,
35+
quantization_config=quantization_config,
36+
device_map="auto",
37+
torch_dtype=torch.bfloat16,
38+
)
39+
40+
input_text = "Write me a poem about Machine Learning."
41+
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
42+
43+
compile_options = {
44+
# "epilogue_fusion": True,
45+
# "shape_padding": True,
46+
# "trace.enabled" : True,
47+
# "triton.cudagraphs" : False,
48+
}
49+
50+
# warmup
51+
outputs = model.generate(**input_ids, max_new_tokens=32)
52+
print(tokenizer.decode(outputs[0]))
53+
54+
# compile
55+
56+
model.forward = torch.compile(model.forward, dynamic=True, fullgraph=True, options=compile_options)
57+
58+
# model = torch.compile(model, dynamic=True, fullgraph=True, options=compile_options)
59+
60+
outputs = model.generate(**input_ids, max_new_tokens=32)
61+
print(tokenizer.decode(outputs[0]))

0 commit comments

Comments
 (0)