
Commit ac52da0: add disable cuda graph unit test for eagle 2 (sgl-project#3412)

1 parent: eeb7cdb

1 file changed: test/srt/test_eagle_infer.py (62 additions, 44 deletions)
@@ -30,51 +30,69 @@ def test_eagle_accuracy(self):
         ref_output = ref_engine.generate(prompt, sampling_params)["text"]
         ref_engine.shutdown()
 
-        # Launch EAGLE engine
-        engine = sgl.Engine(
-            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-            speculative_algorithm="EAGLE",
-            speculative_num_steps=5,
-            speculative_eagle_topk=8,
-            speculative_num_draft_tokens=64,
-            mem_fraction_static=0.7,
-        )
-
-        # Case 1: Test the output of EAGLE engine is the same as normal engine
-        out1 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out1=}, {ref_output=}")
-        self.assertEqual(out1, ref_output)
-
-        # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-        prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-        sampling_params = {
-            "temperature": 0,
-            "max_new_tokens": 1024,
-            "skip_special_tokens": False,
-        }
-
-        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        out2 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out2=}")
-        tokens = tokenizer.encode(out2, truncation=False)
-        assert tokenizer.eos_token_id not in tokens
-
-        # Case 3: Batched prompts
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
+        # Test cases with different configurations
+        configs = [
+            # Original config
+            {
+                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "speculative_algorithm": "EAGLE",
+                "speculative_num_steps": 5,
+                "speculative_eagle_topk": 8,
+                "speculative_num_draft_tokens": 64,
+                "mem_fraction_static": 0.7,
+            },
+            # Config with CUDA graph disabled
+            {
+                "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+                "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+                "speculative_algorithm": "EAGLE",
+                "speculative_num_steps": 5,
+                "speculative_eagle_topk": 8,
+                "speculative_num_draft_tokens": 64,
+                "mem_fraction_static": 0.7,
+                "disable_cuda_graph": True,
+            },
         ]
-        sampling_params = {"temperature": 0, "max_new_tokens": 30}
-        outputs = engine.generate(prompts, sampling_params)
-        for prompt, output in zip(prompts, outputs):
-            print("===============================")
-            print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
-
-        # Shutdown the engine
-        engine.shutdown()
+
+        for config in configs:
+            # Launch EAGLE engine
+            engine = sgl.Engine(**config)
+
+            # Case 1: Test the output of EAGLE engine is the same as normal engine
+            out1 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out1=}, {ref_output=}")
+            self.assertEqual(out1, ref_output)
+
+            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
+            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params = {
+                "temperature": 0,
+                "max_new_tokens": 1024,
+                "skip_special_tokens": False,
+            }
+
+            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+            out2 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out2=}")
+            tokens = tokenizer.encode(out2, truncation=False)
+            assert tokenizer.eos_token_id not in tokens
+
+            # Case 3: Batched prompts
+            prompts = [
+                "Hello, my name is",
+                "The president of the United States is",
+                "The capital of France is",
+                "The future of AI is",
+            ]
+            sampling_params = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params)
+            for prompt, output in zip(prompts, outputs):
+                print("===============================")
+                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+            # Shutdown the engine
+            engine.shutdown()
 
 
         prompts = [
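
For reference, below is a minimal standalone sketch of what the new second configuration exercises: launching the EAGLE engine with CUDA graph capture disabled and generating once. The model paths are placeholders, not the values used in CI (the test itself uses the DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST and DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST constants); all other parameters are copied from the config dict in the diff.

    import sglang as sgl

    # Placeholder model paths -- the test uses the repo's
    # DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST / DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST constants.
    engine = sgl.Engine(
        model_path="meta-llama/Llama-2-7b-chat-hf",
        speculative_draft_model_path="lmsys/sglang-EAGLE-llama2-chat-7B",
        speculative_algorithm="EAGLE",
        speculative_num_steps=5,          # draft steps per speculation round
        speculative_eagle_topk=8,         # top-k branching of the draft tree
        speculative_num_draft_tokens=64,  # draft tokens verified per round
        mem_fraction_static=0.7,
        disable_cuda_graph=True,          # the new case: eager mode, no graph capture/replay
    )

    out = engine.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 30})
    print(out["text"])
    engine.shutdown()

Running both configs through the same loop keeps the assertions identical across the CUDA-graph and eager paths, so a divergence in output would point at the graph-capture path. Assuming the file keeps the repo's usual unittest entry point, the new case runs together with the rest via python3 test/srt/test_eagle_infer.py.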
