-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathgenerate_rcd.py
More file actions
531 lines (485 loc) · 20.4 KB
/
generate_rcd.py
File metadata and controls
531 lines (485 loc) · 20.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
# CUDA_VISIBLE_DEVICES=0 python generate_rcd.py --model_dir yuezhouhu/RCD-SDAR-4B-b64-Thinking --ref_model_dir yuezhouhu/SeqD-SDAR-1.7B-b64-Thinking --trust_remote_code --block_length 64 --denoising_steps 64 --temperature 0 --dtype bfloat16 --confidence_threshold 0.85
import argparse
import torch
from torch.nn import functional as F
from transformers.cache_utils import DynamicCache
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
def top_k_logits(logits, k):
    """Keep only the ``k`` largest logits along the last dimension.

    Entries strictly below the k-th largest value are replaced with ``-inf``;
    entries tied with the k-th value are kept.

    Args:
        logits: Tensor whose last dimension is the vocabulary.
        k: Number of logits to keep. ``k <= 0`` disables filtering. A ``k``
            larger than the vocabulary size is clamped to it (i.e. nothing is
            filtered) instead of letting ``torch.topk`` raise.

    Returns:
        A tensor shaped like ``logits``.
    """
    if k <= 0:
        return logits
    k = min(k, logits.size(-1))
    values, _ = torch.topk(logits, k)
    min_values = values[..., -1, None]
    return torch.where(
        logits < min_values, torch.full_like(logits, float("-inf")), logits
    )
def top_p_logits(logits, p):
    """Apply nucleus (top-p) filtering along the last dimension.

    Tokens are ranked by logit. A token is dropped (set to ``-inf``) when the
    cumulative probability of all higher-ranked tokens already exceeds ``p``.
    The single most likely token is therefore always kept, as is the token
    whose inclusion first pushes the cumulative mass past ``p``.

    Args:
        logits: Tensor whose last dimension is the vocabulary.
        p: Cumulative-probability threshold.

    Returns:
        A tensor shaped like ``logits`` with filtered entries set to ``-inf``.
    """
    ranked_logits, ranked_positions = torch.sort(logits, descending=True)
    running_mass = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

    # Shift the "over the threshold" test right by one slot, so a token is
    # judged by the mass of the tokens ranked *above* it; rank 0 never drops.
    drop_ranked = torch.zeros_like(running_mass, dtype=torch.bool)
    drop_ranked[..., 1:] = running_mass[..., :-1] > p

    # Map the drop decisions from sorted order back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter(
        -1, ranked_positions, drop_ranked
    )
    return logits.masked_fill(drop, float("-inf"))
def sample_with_temperature_topk_topp(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Draw one token per position and report its unprocessed probability.

    With ``temperature == 0`` the argmax is taken. Otherwise the logits are
    divided by ``temperature``, optionally filtered with ``top_k_logits`` /
    ``top_p_logits``, and a token is drawn with ``torch.multinomial``.

    The returned probability always comes from the softmax of the original
    (unscaled, unfiltered) logits, not from the sampling distribution.

    Args:
        logits: Tensor ``[..., vocab]``; all leading dimensions are kept.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        top_k: Top-k filter size (``0`` disables it).
        top_p: Nucleus threshold (``1.0`` disables it).

    Returns:
        ``(tokens, token_probs)``, both shaped like ``logits.shape[:-1]``.
    """
    *lead_shape, vocab = logits.shape
    flat = logits.reshape(-1, vocab)
    raw_probs = F.softmax(flat, dim=-1)

    if temperature == 0:
        picked = flat.argmax(-1, keepdim=True)
    else:
        scaled = flat if temperature == 1.0 else flat / temperature
        if top_k > 0:
            scaled = top_k_logits(scaled, top_k)
        if top_p < 1.0:
            scaled = top_p_logits(scaled, top_p)
        picked = torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)

    picked_prob = raw_probs.gather(-1, picked)
    return picked.view(*lead_shape), picked_prob.view(*lead_shape)
def get_num_transfer_tokens(block_length, steps):
    """Spread ``block_length`` tokens over ``steps`` denoising steps.

    Every step receives ``block_length // steps`` tokens. The first
    ``block_length % steps`` steps each receive one extra, so the schedule
    sums to ``block_length``.

    Returns:
        An int64 tensor of length ``steps``.
    """
    per_step, extra = divmod(block_length, steps)
    schedule = torch.full((steps,), per_step, dtype=torch.int64)
    schedule[:extra] += 1
    return schedule
def _compress_draft_probs(draft_logits, mask_index):
    """Softmax the draft logits at the currently masked positions only.

    Args:
        draft_logits: Logits whose last dimension is the vocabulary; all
            leading dimensions are flattened so they line up with
            ``mask_index.reshape(-1)``.
        mask_index: Boolean tensor marking positions that are still masked.

    Returns:
        ``(compressed_p, indices)``. ``compressed_p`` holds the probabilities
        of the masked positions (softmax computed in float32, then cast back to
        the logits' dtype). ``indices`` holds the flat indices of those positions.
    """
    flat_logits = draft_logits.reshape(-1, draft_logits.size(-1))
    indices = mask_index.reshape(-1).nonzero(as_tuple=True)[0]
    compressed_logits = flat_logits[indices]
    compressed_p = compressed_logits.float().softmax(dim=-1).to(compressed_logits.dtype)
    return compressed_p, indices


def _select_transfer_index(
    remasking_strategy,
    active_rows,
    mask_index,
    x0_p,
    logits,
    n_transfer,
    confidence_threshold,
    eb_threshold,
):
    """Decide which masked positions to commit at one denoising step.

    Args:
        remasking_strategy: "sequential", "low_confidence_static",
            "low_confidence_dynamic" or "entropy_bounded".
        active_rows: Batch rows that still have masked positions and have not
            stopped. Other rows are never modified.
        mask_index: Boolean ``[batch, block]`` tensor of still-masked positions.
        x0_p: Probability of each sampled token, same shape as ``mask_index``.
        logits: Current-step logits ``[batch, block, vocab]``; only used by
            "entropy_bounded".
        n_transfer: Scheduled number of tokens to commit at this step. Not
            used by "entropy_bounded".
        confidence_threshold: Probability threshold for "low_confidence_dynamic".
        eb_threshold: Cumulative-entropy budget for "entropy_bounded".

    Returns:
        A boolean tensor shaped like ``mask_index``. It is True only at
        masked positions of active rows.

    Raises:
        ValueError: If ``remasking_strategy`` is not recognised.
    """
    transfer_index = torch.zeros_like(mask_index, dtype=torch.bool)
    if remasking_strategy == "sequential":
        # Commit the next n_transfer positions starting at the leftmost mask.
        for j in active_rows:
            first = mask_index[j].nonzero(as_tuple=True)[0].min().item()
            transfer_index[j, first: first + n_transfer] = True
    elif remasking_strategy in ("low_confidence_static", "low_confidence_dynamic"):
        confidence = torch.where(mask_index, x0_p, -torch.inf)
        for j in active_rows:
            # Never request more positions than are still masked. The surplus
            # would be -inf (already decoded / prompt) positions, which would
            # otherwise be overwritten with fresh samples.
            k = min(n_transfer, int(mask_index[j].sum().item()))
            if remasking_strategy == "low_confidence_dynamic":
                high_conf_mask = confidence[j] > confidence_threshold
                if high_conf_mask.sum() >= k:
                    transfer_index[j] = high_conf_mask
                    continue
            _, idx = torch.topk(confidence[j], k)
            transfer_index[j, idx] = True
    elif remasking_strategy == "entropy_bounded":
        # Per-position entropy of the full predictive distribution.
        # entr(0) == 0, so zero-probability tokens contribute no NaNs.
        probs = F.softmax(logits.float(), dim=-1)
        entropies = torch.special.entr(probs).sum(dim=-1)
        entropies = torch.where(mask_index, entropies, torch.inf)
        ent_sorted, order = torch.sort(entropies, dim=1, descending=False)
        cumsum = torch.cumsum(ent_sorted, dim=1)
        threshold = cumsum.new_tensor([eb_threshold])
        for j in active_rows:
            # Take the lowest-entropy positions while their cumulative entropy
            # stays below the budget; always at least one, never more than
            # the number of masked positions.
            k = torch.searchsorted(cumsum[j], threshold, right=False).item()
            k = max(1, min(k, int(mask_index[j].sum().item())))
            transfer_index[j, order[j, :k]] = True
    else:
        raise ValueError(f"Unknown remasking strategy: {remasking_strategy}")
    # Guarantee that only still-masked positions are ever overwritten.
    return transfer_index & mask_index


@torch.no_grad()
def block_diffusion_generate_loop(
    model,
    draft_model,
    latent_strategy,
    alpha,
    pnorm,
    loop,
    tokenizer,
    prompt,
    mask_id,
    gen_length=128,
    block_length=8,
    denoising_steps=8,
    temperature=1.0,
    top_k=0,
    top_p=1.0,
    remasking_strategy="low_confidence_dynamic",
    confidence_threshold=0.85,
    eb_threshold=None,
    stopping_criteria_idx=None,
    device="cuda",
):
    """Generate text block by block with a main model seeded by a draft model.

    The sequence (prompt + generation, rounded up to whole blocks) starts out
    filled with ``mask_id``. Whole blocks shared by every prompt are
    prefilled into both models' KV caches. Each remaining block is then
    denoised iteratively:

    1. On a block's first step, the draft model's probabilities at masked
       positions are passed to the main model as ``p``. On later steps, the
       main model's previous logits are used instead.
    2. A token is sampled per position.
    3. A subset of masked positions is committed according to
       ``remasking_strategy``.

    Once a block has no masks left (or every row has stopped), both models
    are re-run on it with ``store_kv=True``.

    The extra model keyword arguments (``mask``, ``latent_strategy``,
    ``alpha``, ``p``, ``store_kv``) are consumed by the models' own
    (remote) code. Their exact semantics are not visible here.

    Args:
        model: Main model; must return an object with ``.logits``.
        draft_model: Draft/reference model used to seed each block.
        latent_strategy: Forwarded to ``model`` unchanged.
        alpha: Forwarded to ``model``. When ``pnorm`` is true it is scaled by
            (number of masked positions / ``block_length``).
        pnorm: Enables the scaling of ``alpha`` described above.
        loop: Unused here; kept for interface compatibility.
        tokenizer: Used only to print the decoded text after every block.
        prompt: Mapping with ``"input_ids"`` ``[batch, prompt_len]``. Padding
            must use ``mask_id``, because prompt lengths are recovered by
            counting non-mask tokens.
        mask_id: Token id used for masked positions (and prompt padding).
        gen_length: Maximum number of tokens to generate per row.
        block_length: Tokens per diffusion block.
        denoising_steps: Scheduled steps per block. "entropy_bounded" may use
            up to ``max(denoising_steps, block_length)`` steps so that every
            block is fully decoded.
        temperature, top_k, top_p: Sampling controls.
        remasking_strategy: See ``_select_transfer_index``.
        confidence_threshold: Threshold for "low_confidence_dynamic".
        eb_threshold: Entropy budget; required for "entropy_bounded".
        stopping_criteria_idx: Token id(s) that end generation for a row, or
            None to disable stop-token checks.
        device: Device for the working tensors.

    Returns:
        Dict with keys ``"generated_tokens"`` (per-row ``[1, len]`` tensors,
        truncated after the earliest stop token, inclusive), ``"total_steps"``,
        ``"actual_gen_lens"`` and ``"avg_tokens_per_step"``.

    Raises:
        ValueError: If "entropy_bounded" is requested without
            ``eb_threshold``, or if the strategy is unknown.
    """
    if remasking_strategy == "entropy_bounded" and eb_threshold is None:
        raise ValueError(
            "eb_threshold is required when remasking_strategy='entropy_bounded'"
        )
    model.eval()
    draft_model.eval()
    input_ids = prompt["input_ids"]
    batch_size, prompt_length = input_ids.shape
    # Padding uses mask_id, so non-mask tokens give each row's true length.
    all_prompt_lengths = input_ids.ne(mask_id).sum(dim=1)
    shortest_prompt_length = all_prompt_lengths.min().item()
    past_key_values = DynamicCache()
    draft_past_key_values = DynamicCache()

    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length
    # Block-causal attention: each block sees itself and all earlier blocks.
    block_mask = torch.tril(torch.ones(num_blocks, num_blocks, device=device))
    block_diffusion_attention_mask = (
        block_mask.repeat_interleave(block_length, dim=0)
        .repeat_interleave(block_length, dim=1)
        .unsqueeze(0)
    )
    position_ids = torch.arange(total_length, device=device).unsqueeze(0)
    x = torch.full((batch_size, total_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_length] = input_ids

    # Only whole blocks common to every row's prompt can be prefilled.
    prefill_blocks = shortest_prompt_length // block_length
    prefill_length = prefill_blocks * block_length

    # Prefill stage
    if prefill_length > 0:
        cur_x = x[:, :prefill_length]
        cur_attn_mask = block_diffusion_attention_mask[:, :prefill_length, :prefill_length]
        cur_position_ids = position_ids[:, :prefill_length]
        model(
            cur_x,
            attention_mask=cur_attn_mask,
            position_ids=cur_position_ids,
            past_key_values=past_key_values,
            use_cache=True,
            store_kv=True,
        )
        draft_model(
            cur_x,
            attention_mask=cur_attn_mask,
            position_ids=cur_position_ids,
            past_key_values=draft_past_key_values,
            use_cache=True,
            store_kv=True,
        )

    num_transfer_tokens = get_num_transfer_tokens(block_length, denoising_steps)
    # entropy_bounded commits >= 1 token per step but has no fixed schedule;
    # allowing block_length steps guarantees every block gets fully decoded.
    max_denoise_steps = (
        max(denoising_steps, block_length)
        if remasking_strategy == "entropy_bounded"
        else denoising_steps
    )
    if stopping_criteria_idx is not None:
        stop_ids = torch.as_tensor(
            stopping_criteria_idx, dtype=torch.long, device=x.device
        ).reshape(-1)

    # Decode stage
    total_steps = [0] * batch_size
    stop = [False] * batch_size
    for num_block in range(prefill_blocks, num_blocks):
        block_start = num_block * block_length
        block_end = block_start + block_length
        cur_x = x[:, block_start:block_end].clone()
        cur_attn_mask = block_diffusion_attention_mask[:, block_start:block_end, :block_end]
        cur_position_ids = position_ids[:, block_start:block_end]
        block_steps = [0] * batch_size
        draft_logits = None
        for step in range(max_denoise_steps + 1):
            mask_index = cur_x == mask_id
            if all(mask_index[j].sum() == 0 or stop[j] for j in range(batch_size)):
                # Block finished: run both models with store_kv=True on the
                # final block contents (presumably writing them to the caches).
                draft_logits = draft_model(
                    cur_x,
                    attention_mask=cur_attn_mask,
                    position_ids=cur_position_ids,
                    past_key_values=draft_past_key_values,
                    use_cache=True,
                    store_kv=True,
                ).logits
                compressed_p, indices = _compress_draft_probs(draft_logits, mask_index)
                model(
                    cur_x,
                    mask=mask_index,
                    latent_strategy=latent_strategy,
                    # NOTE(review): mask_index.sum() counts masks across the
                    # whole batch, not per row.
                    alpha=alpha * mask_index.sum() / block_length if pnorm else alpha,
                    p=(compressed_p, indices),
                    attention_mask=cur_attn_mask,
                    position_ids=cur_position_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                    store_kv=True,
                )
                # The KV-store pass is counted as a step for rows whose
                # generation region overlaps this block.
                for j in range(batch_size):
                    if all_prompt_lengths[j] < block_end and not stop[j]:
                        block_steps[j] += 1
                break

            # Denoising: the draft model seeds the first step of each block;
            # later steps reuse the main model's previous logits.
            if draft_logits is None:
                draft_logits = draft_model(
                    cur_x,
                    attention_mask=cur_attn_mask,
                    position_ids=cur_position_ids,
                    past_key_values=draft_past_key_values,
                    use_cache=True,
                    store_kv=False,
                ).logits
            compressed_p, indices = _compress_draft_probs(draft_logits, mask_index)
            logits = model(
                cur_x,
                mask=mask_index,
                latent_strategy=latent_strategy,
                alpha=alpha * mask_index.sum() / block_length if pnorm else alpha,
                p=(compressed_p, indices),
                attention_mask=cur_attn_mask,
                position_ids=cur_position_ids,
                past_key_values=past_key_values,
                use_cache=True,
                store_kv=False,
            ).logits
            draft_logits = logits.clone()

            # Sampling
            x0, x0_p = sample_with_temperature_topk_topp(
                logits, temperature=temperature, top_k=top_k, top_p=top_p
            )

            # Remasking: decide which sampled tokens to commit.
            active_rows = [
                j for j in range(batch_size) if mask_index[j].any() and not stop[j]
            ]
            for j in active_rows:
                block_steps[j] += 1
            n_transfer = int(num_transfer_tokens[step]) if step < denoising_steps else 0
            transfer_index = _select_transfer_index(
                remasking_strategy,
                active_rows,
                mask_index,
                x0_p,
                logits,
                n_transfer,
                confidence_threshold,
                eb_threshold,
            )
            cur_x[transfer_index] = x0[transfer_index]

        for j, block_step in enumerate(block_steps):
            total_steps[j] += block_step
        x[:, block_start:block_end] = cur_x
        print(tokenizer.batch_decode(x[:, :block_end].tolist()))

        if stopping_criteria_idx is not None:
            can_stop = True
            for j in range(batch_size):
                if torch.isin(x[j, all_prompt_lengths[j]:], stop_ids).any():
                    stop[j] = True
                elif block_end >= all_prompt_lengths[j] + gen_length:
                    stop[j] = True
                else:
                    can_stop = False
            if can_stop:
                break

    generated_tokens = []
    actual_gen_lens = []
    avg_tokens_per_step = []
    for j in range(batch_size):
        sample_generated_tokens = x[j:j + 1, all_prompt_lengths[j]:]
        if stopping_criteria_idx is not None:
            # Truncate after the earliest stop token of any kind (inclusive).
            stop_positions = torch.isin(sample_generated_tokens[0], stop_ids).nonzero(
                as_tuple=True
            )[0]
            if stop_positions.numel() > 0:
                first_stop_pos = stop_positions[0].item() + 1
                sample_generated_tokens = sample_generated_tokens[:, :first_stop_pos]
        gen_len = sample_generated_tokens.shape[1]
        generated_tokens.append(sample_generated_tokens)
        actual_gen_lens.append(gen_len)
        avg_tokens_per_step.append(gen_len / total_steps[j] if total_steps[j] > 0 else 0.0)
    return {
        "generated_tokens": generated_tokens,
        "total_steps": total_steps,
        "actual_gen_lens": actual_gen_lens,
        "avg_tokens_per_step": avg_tokens_per_step,
    }
def parse_args():
    """Build the command-line parser for RCD generation and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (SystemExit) in either of these cases:
    ``--remasking_strategy entropy_bounded`` is given without
    ``--eb_threshold``; or ``--block_length`` / ``--denoising_steps`` is not
    positive.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        type=str,
        required=True,
        help="Path to the pretrained model directory",
    )
    parser.add_argument(
        "--ref_model_dir",
        type=str,
        required=True,
        help="Path to the pretrained draft/reference model directory",
    )
    parser.add_argument("--trust_remote_code", action="store_true")
    parser.add_argument(
        "--mask_id", type=int, default=None, help="Mask token id for Diffusion"
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=4096,
        help="Maximum prompt length in tokens",
    )
    parser.add_argument(
        "--gen_length",
        type=int,
        default=16384,
        help="Maximum generation length in tokens",
    )
    parser.add_argument(
        "--block_length",
        type=int,
        default=4,
        help="Number of tokens in each diffusion block",
    )
    parser.add_argument(
        "--denoising_steps",
        type=int,
        default=4,
        help="Number of denoising steps (iterations)",
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_k", type=int, default=0, help="Top-K sampling (0 to disable)"
    )
    parser.add_argument(
        "--top_p", type=float, default=1.0, help="Top-P sampling probability threshold"
    )
    parser.add_argument(
        "--remasking_strategy",
        type=str,
        default="low_confidence_dynamic",
        choices=[
            "low_confidence_dynamic",
            "low_confidence_static",
            "sequential",
            "entropy_bounded",
        ],
        help="Strategy for remasking tokens",
    )
    parser.add_argument(
        "--confidence_threshold",
        type=float,
        default=0.85,
        help="Confidence threshold for low-confidence remasking",
    )
    parser.add_argument(
        "--eb_threshold",
        type=float,
        default=None,
        help="entropy threshold for entropy bounded sampling",
    )
    parser.add_argument(
        "--stopping_criteria_idx",
        type=int,
        nargs="+",
        default=None,
        help="List of token IDs that stop generation (e.g. eos_token_id)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float16", "bfloat16"],
    )
    args = parser.parse_args()
    # Cross-argument validation must run after parsing (args did not exist
    # where the previous, commented-out checks were placed).
    if args.remasking_strategy == "entropy_bounded" and args.eb_threshold is None:
        parser.error(
            "--eb_threshold is required when --remasking_strategy=entropy_bounded"
        )
    if args.block_length <= 0 or args.denoising_steps <= 0:
        parser.error("--block_length and --denoising_steps must be positive")
    return args
if __name__ == "__main__":
    args = parse_args()

    print("Loading model...")
    # Honor --trust_remote_code rather than always executing repository code.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_dir,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=args.dtype,
        device_map=args.device,
    )
    # The reference model is handed to the generator as its draft model.
    ref_model = AutoModelForCausalLM.from_pretrained(
        args.ref_model_dir,
        trust_remote_code=args.trust_remote_code,
        torch_dtype=args.dtype,
        device_map=args.device,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_dir,
        trust_remote_code=args.trust_remote_code,
    )
    # Pad with the mask token: the generator recovers each prompt's length by
    # counting non-mask tokens.
    tokenizer.pad_token = tokenizer.mask_token
    if args.mask_id is None:
        # Look the id up directly; encoding the string could prepend a BOS id.
        args.mask_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    print(f"Mask id: {args.mask_id}")

    if args.stopping_criteria_idx is None:
        gen_cfg = GenerationConfig.from_pretrained(
            args.model_dir,
        )
        args.stopping_criteria_idx = gen_cfg.eos_token_id
    print(f"Stopping criteria index: {args.stopping_criteria_idx}")
    if isinstance(args.stopping_criteria_idx, int):
        args.stopping_criteria_idx = [
            args.stopping_criteria_idx,
        ]
    # eos_token_id may be None in the generation config.
    args.stop_words = (
        tokenizer.convert_ids_to_tokens(args.stopping_criteria_idx)
        if args.stopping_criteria_idx is not None
        else []
    )
    print(f"Your Arguments: {args}")

    origin_prompt = [
        # [dict(role="user",
        # content=
        # """
        # Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
        #
        # Define
        # \[p = \sum_{k = 1}^\infty \frac{1}{k^2} \quad \text{and} \quad q = \sum_{k = 1}^\infty \frac{1}{k^3}.\]Find a way to write
        # \[\sum_{j = 1}^\infty \sum_{k = 1}^\infty \frac{1}{(j + k)^3}\]in terms of $p$ and $q.$
        # """.strip()
        # )],
        [dict(role="user",
              content=r"Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ minutes spent in the coffee shop. When she walks $s+2$ kilometers per hour, the walk takes her 2 hours and 24 minutes, including $t$ minutes spent in the coffee shop. Suppose Aya walks at $s+\frac{1}{2}$ kilometers per hour. Find the number of minutes the walk takes her, including the $t$ minutes spent in the coffee shop."
              )],
    ]
    messages = tokenizer.apply_chat_template(
        origin_prompt, add_generation_prompt=True, tokenize=False
    )
    tokenize_kwargs = dict(
        return_tensors="pt",
        padding=True,
        truncation=True,
        add_special_tokens=False,
        max_length=args.prompt_length,
    )
    # Calling the tokenizer directly replaces the deprecated batch_encode_plus.
    tokens = tokenizer(messages, **tokenize_kwargs)
    tokens = {k: v.to(model.device) for k, v in tokens.items()}

    output_ids = block_diffusion_generate_loop(
        model,
        ref_model,
        latent_strategy="normalized-entropy-interpolation",
        alpha=1,
        pnorm=False,
        loop=True,
        tokenizer=tokenizer,
        prompt=tokens,
        mask_id=args.mask_id,
        gen_length=args.gen_length,
        block_length=args.block_length,
        denoising_steps=args.denoising_steps,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        remasking_strategy=args.remasking_strategy,
        confidence_threshold=args.confidence_threshold,
        eb_threshold=args.eb_threshold,
        stopping_criteria_idx=args.stopping_criteria_idx,
        # Keep the generator's working tensors on the model's device rather
        # than the hardcoded "cuda" default.
        device=model.device,
    )
    for sample_tokens in output_ids["generated_tokens"]:
        output_text = tokenizer.decode(sample_tokens[0], skip_special_tokens=False)
        cleaned_text = output_text.replace(tokenizer.mask_token, "")
        print(cleaned_text)