@@ -30,51 +30,72 @@ def test_eagle_accuracy(self):
         ref_output = ref_engine.generate(prompt, sampling_params)["text"]
         ref_engine.shutdown()
 
-        # Launch EAGLE engine
-        engine = sgl.Engine(
-            model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
-            speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
-            speculative_algorithm="EAGLE",
-            speculative_num_steps=5,
-            speculative_eagle_topk=8,
-            speculative_num_draft_tokens=64,
-            mem_fraction_static=0.7,
-        )
-
-        # Case 1: Test the output of EAGLE engine is the same as normal engine
-        out1 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out1=}, {ref_output=}")
-        self.assertEqual(out1, ref_output)
-
-        # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-        prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-        sampling_params = {
-            "temperature": 0,
-            "max_new_tokens": 1024,
-            "skip_special_tokens": False,
-        }
-
-        tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        out2 = engine.generate(prompt, sampling_params)["text"]
-        print(f"{out2=}")
-        tokens = tokenizer.encode(out2, truncation=False)
-        assert tokenizer.eos_token_id not in tokens
-
-        # Case 3: Batched prompts
-        prompts = [
-            "Hello, my name is",
-            "The president of the United States is",
-            "The capital of France is",
-            "The future of AI is",
33+ # Test cases with different configurations
34+ configs = [
35+ # Original config
36+ {
37+ "model_path" : DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST ,
38+ "speculative_draft_model_path" : DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST ,
39+ "speculative_algorithm" : "EAGLE" ,
40+ "speculative_num_steps" : 5 ,
41+ "speculative_eagle_topk" : 8 ,
42+ "speculative_num_draft_tokens" : 64 ,
43+ "mem_fraction_static" : 0.7 ,
44+ },
45+ # Config with CUDA graph disabled
46+ {
47+ "model_path" : DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST ,
48+ "speculative_draft_model_path" : DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST ,
49+ "speculative_algorithm" : "EAGLE" ,
50+ "speculative_num_steps" : 5 ,
51+ "speculative_eagle_topk" : 8 ,
52+ "speculative_num_draft_tokens" : 64 ,
53+ "mem_fraction_static" : 0.7 ,
54+ "disable_cuda_graph" : True ,
55+ },
         ]
-        sampling_params = {"temperature": 0, "max_new_tokens": 30}
-        outputs = engine.generate(prompts, sampling_params)
-        for prompt, output in zip(prompts, outputs):
-            print("===============================")
-            print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
-
-        # Shutdown the engine
-        engine.shutdown()
+
+        # Cases 2 and 3 rebind prompt/sampling_params, so keep the originals for Case 1
+        ref_prompt, ref_sampling_params = prompt, sampling_params
+
+        for config in configs:
+            # Launch EAGLE engine
+            engine = sgl.Engine(**config)
+
+            # Case 1: Test the output of EAGLE engine is the same as normal engine
+            out1 = engine.generate(ref_prompt, ref_sampling_params)["text"]
+            print(f"{out1=}, {ref_output=}")
+            self.assertEqual(out1, ref_output)
+
+            # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
+            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params = {
+                "temperature": 0,
+                "max_new_tokens": 1024,
+                "skip_special_tokens": False,
+            }
+
+            tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+            out2 = engine.generate(prompt, sampling_params)["text"]
+            print(f"{out2=}")
+            tokens = tokenizer.encode(out2, truncation=False)
+            assert tokenizer.eos_token_id not in tokens
+
+            # Case 3: Batched prompts
+            prompts = [
+                "Hello, my name is",
+                "The president of the United States is",
+                "The capital of France is",
+                "The future of AI is",
+            ]
+            sampling_params = {"temperature": 0, "max_new_tokens": 30}
+            outputs = engine.generate(prompts, sampling_params)
+            for prompt, output in zip(prompts, outputs):
+                print("===============================")
+                print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+            # Shutdown the engine
+            engine.shutdown()
 
 
 prompts = [
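Note: the two entries in `configs` differ only in `disable_cuda_graph`. A possible tidier construction, sketched below under the assumption that `sgl.Engine` accepts exactly the keyword arguments shown in the diff, declares the shared speculative-decoding parameters once and derives the second config from the first:

    # Sketch only: build the CUDA-graph-disabled variant from a shared base
    # config, so the EAGLE parameters are declared in a single place.
    base_config = {
        "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
        "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        "speculative_algorithm": "EAGLE",
        "speculative_num_steps": 5,
        "speculative_eagle_topk": 8,
        "speculative_num_draft_tokens": 64,
        "mem_fraction_static": 0.7,
    }

    configs = [
        base_config,                                  # default: CUDA graph enabled
        {**base_config, "disable_cuda_graph": True},  # same settings, CUDA graph off
    ]

The `{**base_config, ...}` merge copies the base dict and overrides or adds keys, so later additions to the base automatically apply to both configurations.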