-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathmodel_eval.py
More file actions
4042 lines (3481 loc) · 200 KB
/
model_eval.py
File metadata and controls
4042 lines (3481 loc) · 200 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import ast
from copy import deepcopy
import json
import math
import os
import pickle
import random
import shutil
import string
import time
from datetime import datetime
from pathlib import Path
from typing import Optional
import pandas as pd
from constants import UNIDISC_DIR
from data_defs import InterleavedBatch
import einops
import numpy as np
from unidisc.utils.simple_llm import get_llm
from unidisc.utils.viz_utils import augment_image_with_random_object_coco, create_text_image
import torch
import torch.utils.checkpoint
from accelerate.utils import gather, gather_object
from image_utils import Im
from jaxtyping import Bool, Float, Integer
from PIL import Image
from tensordict import TensorDict, tensorclass
from torch import Tensor
from tqdm import tqdm
from collections import defaultdict
import torch.nn.functional as F
import utils
import wandb
from decoupled_utils import (barrier, dprint, get_num_gpus, get_rank, get_world_size,
gprint, is_main_process, print_memory_summary,
rprint, save_memory_profile, show_memory_usage, try_except, sanitize_filename)
from unidisc.tokenizers.chameleon_tokenizers import (decode_ids_batched,
get_chameleon_images)
from unidisc.tokenizers.image_tokenizers import decode_latents, get_image_batch
from unidisc.utils.throughput_monitor import get_available_flops
from model_utils import (_sample_categorical, empty_device_cache, get_chameleon_txt_indices, get_interleaved_block_mask, log,
remap_image_torch, replace_nan_dict,
wrapped_batch_decode)
from torch import nn
from model_utils import get_block_mask, MauveScore, Entropy
def get_anole_data(self, model, processor, prompt, image, dtype, device):
    """Encode a prompt/image pair for the Anole (Chameleon) model with image tokens inlined.

    The processor output holds image placeholder ids in ``input_ids`` plus ``pixel_values``.
    The placeholders are overwritten, in order, with the discrete image tokens returned by
    ``model.model.get_image_tokens``, and ``pixel_values`` is dropped, so the result is a pure
    token sequence.

    Args:
        self: Unused; present so the function can be bound as a method.
        model: Model exposing ``model.model.get_image_tokens`` and
            ``model.model.vocabulary_mapping.image_token_id``.
        processor: Matching processor, called with ``text=prompt`` and ``images=[image]``.
        prompt: Prompt text containing image placeholders (callers also pass a list of prompts).
        image: Image handed to the processor (wrapped in a one-element list).
        dtype: dtype forwarded to ``.to`` on the processor output.
        device: Device forwarded to ``.to`` on the processor output.

    Returns:
        The processor output with ``input_ids`` rewritten and ``pixel_values`` removed.
    """
    encoded = processor(text=prompt, images=[image], padding=True, return_tensors="pt")
    encoded = encoded.to(device=device, dtype=dtype)
    inner = model.model
    token_ids = encoded["input_ids"]
    vq_tokens = inner.get_image_tokens(encoded["pixel_values"])
    placeholder_mask = token_ids == inner.vocabulary_mapping.image_token_id
    vq_tokens = vq_tokens.to(token_ids.device, token_ids.dtype)
    encoded["input_ids"] = token_ids.masked_scatter(placeholder_mask, vq_tokens)
    encoded.pop("pixel_values")
    return encoded
def calculate_chameleon_perplexity(self, model, processor, prompts, images, dtype=torch.bfloat16, return_all=False, standalone=False):
    """Score prompt/image pairs with the Chameleon (Anole) model.

    Every pair is run through the model twice:

    * ``standalone=False``: as ``"<prompt> <image>"`` (text first) and ``"<image> <prompt>"``
      (image first). The image/text losses are the mean token NLL over the image span
      (start marker through end marker) and over the remaining tokens; all four statistics
      are averaged over the two passes.
    * ``standalone=True``: as the bare prompt (text only) and as ``"<image>"`` (image only).
      ``txt_loss``/``img_loss`` are the respective pass losses; perplexity and total loss are
      summed over the two passes.

    Args:
        model: Chameleon model. If ``model`` or ``processor`` is None, both are taken from
            ``self.chameleon_model`` / ``self.chameleon_processor``; when no model is cached
            there, Anole-7b is loaded onto ``self.device`` and cached on ``self``.
        processor: Matching Chameleon processor (see ``model``).
        prompts: Prompt strings.
        images: Images, one per prompt.
        dtype: dtype forwarded to ``self.get_anole_data``.
        return_all: If True, return ``(ppl, loss, img_loss, txt_loss)`` tuples.
        standalone: Score text and image separately instead of jointly (see above).

    Returns:
        One entry per pair: the perplexity, or the 4-tuple when ``return_all`` is set.

    Raises:
        AssertionError: If ``prompts`` and ``images`` differ in length.
    """
    device = self.device
    if model is None or processor is None:
        model = getattr(self, "chameleon_model", None)
        processor = getattr(self, "chameleon_processor", None)
        if model is None:
            from transformers import ChameleonForConditionalGeneration, ChameleonProcessor
            # Place the model on the same device the inputs are moved to below.
            self.chameleon_model = ChameleonForConditionalGeneration.from_pretrained("leloy/Anole-7b-v0.1-hf", torch_dtype=torch.bfloat16).to(device)
            self.chameleon_processor = ChameleonProcessor.from_pretrained("leloy/Anole-7b-v0.1-hf")
            model = self.chameleon_model
            processor = self.chameleon_processor
    assert len(prompts) == len(images), "Number of prompts and images must match"

    # Marker ids are prompt-independent, so look them up once, from the processor actually in
    # use (an explicitly passed processor need not be cached on self). Index 1 presumably skips
    # a BOS token the tokenizer prepends.
    img_start_tok_id = processor.tokenizer(processor.image_start_token)["input_ids"][1]
    img_end_tok_id = processor.tokenizer(processor.image_end_token)["input_ids"][1]

    def _image_span_mask(input_ids):
        # True from the image-start marker through the image-end marker (inclusive), for either
        # prompt ordering. A single cumulative sum from one marker would also cover any text that
        # follows the image (and, for image-first prompts, only that text).
        opened = torch.cumsum(input_ids == img_start_tok_id, dim=1)
        closed = torch.cumsum(input_ids == img_end_tok_id, dim=1)
        return (opened > closed) | (input_ids == img_end_tok_id)

    def _masked_mean(values, mask):
        # Mean over the selected positions only; the clamp avoids 0/0 when a modality is absent.
        weights = mask.to(torch.float32)
        return ((values.float() * weights).sum() / weights.sum().clamp_min(1)).item()

    perplexities = []
    with torch.no_grad():  # scoring only: don't keep activations for a backward pass
        for prompt, image in zip(prompts, images):
            if standalone:
                passes = (prompt, "<image>")
            else:
                passes = (f"{prompt} <image>", f"<image> {prompt}")
            tot_ppl = 0.0
            tot_loss = 0.0
            img_loss = 0.0
            txt_loss = 0.0
            for i, _prompt in enumerate(passes):
                inputs = self.get_anole_data(model, processor, _prompt, image, dtype, device)
                input_ids = inputs["input_ids"].to(device)
                output = model(
                    input_ids=input_ids,
                    attention_mask=inputs["attention_mask"].to(device),
                    labels=input_ids,
                )
                loss = output.loss
                tot_ppl += torch.exp(loss).item()
                tot_loss += loss.item()
                if standalone:
                    # Pass 0 is the text-only prompt, pass 1 the image-only prompt.
                    if i == 0:
                        txt_loss += loss.item()
                    else:
                        img_loss += loss.item()
                    continue
                # Per-token NLL: logits at position j predict token j + 1, hence the shifts.
                logits = output.logits.transpose(-1, -2)
                nlls = F.cross_entropy(logits[..., :-1].to(device), input_ids[..., 1:], reduction="none")
                target_is_img = _image_span_mask(input_ids)[:, 1:]
                img_loss += _masked_mean(nlls, target_is_img)
                txt_loss += _masked_mean(nlls, ~target_is_img)
            if not standalone:
                tot_ppl /= 2
                tot_loss /= 2
                img_loss /= 2
                txt_loss /= 2
            if return_all:
                perplexities.append((tot_ppl, tot_loss, img_loss, txt_loss))
            else:
                perplexities.append(tot_ppl)
            print(f"Total PPL: {tot_ppl} | Total Loss: {tot_loss} | Img Loss: {img_loss} | Txt Loss: {txt_loss}")
    return perplexities
def get_every_n_evals(self, n):
    """Decide whether an action scheduled every ``n`` eval rounds should run now.

    Args:
        n: Period in eval rounds; ``-1`` disables the action.

    Returns:
        Truthy when ``n != -1`` and either the run is in ``"eval"`` mode, or ``n > 0`` and
        ``self.num_evals`` is a multiple of ``n`` and is past the first eval round (the first
        round also counts when ``config.eval.log_on_start`` is set).
    """
    if self.config.mode == "eval":
        due = True
    else:
        started = self.num_evals > 0 or getattr(self.config.eval, "log_on_start", False)
        due = started and n > 0 and self.num_evals % n == 0
    return due and n != -1
@try_except(write_error_to_file=True)
def on_validation_epoch_start(self):
    """Prepare model state, metrics and output directories at the start of a validation epoch.

    Side effects on ``self``:
    * stores the live weights and loads the EMA weights (non-custom EMA only);
    * puts the backbone in eval mode and resets validation metrics;
    * sets ``fid_eval`` and either builds ``inception_metrics`` (inline FID) or creates the
      FID generation / ground-truth directories (``fid_gen_dir`` / ``fid_gt_dir``);
    * optionally creates the image-to-text MAUVE directories;
    * resets ``saved_tokens`` and records ``validation_start_time`` and ``gpu_memory_reserved``.
    """
    rprint("on_validation_epoch_start")

    def _eval_scratch_dir():
        # /dev/shm/<user>/<run parent>/<run name>/<num_evals>_<global_step>.
        # $USER is often unset in containers / batch jobs; ``Path / None`` would raise TypeError
        # and abort the rest of this setup, so fall back to a fixed name.
        user = os.getenv("USER") or "unknown_user"
        run_dir = Path(self.config.output_dir)
        return Path("/dev/shm") / user / run_dir.parent.stem / run_dir.stem / f"{self.num_evals}_{self.global_step}"

    # EMA (exponential moving average of the weights): stash the live parameters, then evaluate
    # with the EMA copy.
    if self.ema is not None and not self.config.trainer.use_custom_ema:
        rprint(" [WARNING] USING EMA IN on_validation_epoch_start - THIS MIGHT RESET LOADED WEIGHTS ".center(100, "!"))
        self.ema.store(self.get_params())
        self.ema.copy_to(self.get_params())
    self.backbone.eval()
    self.reset_validation_metrics()
    if getattr(self.config.trainer, "disable_torchmetrics", False) is False:
        # The reset above must leave the NLL accumulator empty.
        assert self.valid_metrics.nll.mean_value == 0
        assert self.valid_metrics.nll.weight == 0
    if self.non_embedding_params < 1e9:
        self.print_hashes()

    if (
        self.image_model
        and getattr(self.config.model, "image_model_fid_eval", False)
        and self.get_every_n_evals(getattr(self.config.eval, "log_every_n_fid", 10))
    ):
        self.fid_eval = True
        if self.config.eval.fid_mode == "inline":
            # Inline mode: an in-process metrics object, no output directories.
            from vqgan.inception_metrics import MultiInceptionMetrics
            self.inception_metrics = MultiInceptionMetrics(
                reset_real_features=False,
                compute_unconditional_metrics=True,
                compute_conditional_metrics=False,
                compute_conditional_metrics_per_class=False,
                num_classes=1000,
                num_inception_chunks=10,
                manifold_k=3,
            )
            if self.config.mode == "eval":
                self.computed_tokens = []
        else:
            force_fid_output_dir = getattr(self.config.eval, "force_fid_output_dir", None)
            if force_fid_output_dir is None:
                fid_save_path = _eval_scratch_dir() / "fid_gen"
            else:
                fid_save_path = Path(force_fid_output_dir) / "fid_gen"
            fid_save_path.mkdir(parents=True, exist_ok=True)
            # Ground-truth images go to a sibling directory ("fid_gen" -> "fid_gt").
            fid_gt_path = fid_save_path.parent / (fid_save_path.name.replace("gen", "gt"))
            fid_gt_path.mkdir(parents=True, exist_ok=True)
            self.fid_gen_dir = fid_save_path
            self.fid_gt_dir = fid_gt_path
            rprint(f"FID eval output dir: {self.fid_gen_dir}, FID GT dir: {self.fid_gt_dir}")
        rprint(f"Setting FID eval for epoch {self.num_evals}")
    else:
        self.fid_eval = False
        if self.image_model and getattr(self.config.model, "image_model_fid_eval", False):
            rprint(f"Not setting FID eval: num_evals: {self.num_evals} % {getattr(self.config.eval, 'log_every_n_fid', 10)}")

    if self.config.eval.compute_img_to_txt_mauve_clip:
        img_to_txt_mauve_save_path = _eval_scratch_dir() / "img_to_txt_mauve_gen"
        img_to_txt_mauve_save_path.mkdir(parents=True, exist_ok=True)
        img_to_txt_mauve_gt_path = img_to_txt_mauve_save_path.parent / (img_to_txt_mauve_save_path.name.replace("gen", "gt"))
        img_to_txt_mauve_gt_path.mkdir(parents=True, exist_ok=True)
        self.img_to_txt_mauve_gen_dir = img_to_txt_mauve_save_path
        self.img_to_txt_mauve_gt_dir = img_to_txt_mauve_gt_path
        rprint(f"Img to txt mauve eval gen dir: {self.img_to_txt_mauve_gen_dir}, gt dir: {self.img_to_txt_mauve_gt_dir}")

    self.saved_tokens = defaultdict(list)
    self.validation_start_time = time.time()
    if getattr(self.config.trainer, "attach_oom_observer_eval", False):
        from torchtnt.utils.oom import attach_oom_observer
        attach_oom_observer(output_dir=str(self.config.output_dir), trace_max_entries=1000000)
        rprint(f"Attached OOM observer to {self.config.output_dir}")
    self.gpu_memory_reserved = torch.cuda.memory_reserved()
def sample(self, return_input_ids=False, **kwargs):
    """Draw samples with ``self._sample`` and decode them.

    Args:
        return_input_ids: If True, also return the raw token tensors.
        **kwargs: Forwarded to ``self._sample``; ``text_only`` (default False) is extracted and
            passed explicitly.

    Returns:
        ``(txt_pred, img_pred)``, or ``(txt_pred, img_pred, txt_tokens, img_tokens)`` when
        ``return_input_ids`` is set. Each prediction is None when ``_sample`` returned no tokens
        for that modality.

    Raises:
        AssertionError: If the trainer is configured for continuous images.
    """
    is_continuous = self.config.trainer.image_mode == "continuous"
    text_only = kwargs.pop("text_only", False)
    assert not is_continuous
    txt_tokens, img_tokens = self._sample(text_only=text_only, **kwargs)

    img_pred = None if img_tokens is None else decode_latents(self.config, self.get_vae(), img_tokens)
    txt_pred = (
        None
        if txt_tokens is None
        else wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)
    )

    outputs = (txt_pred, img_pred)
    if return_input_ids:
        outputs += (txt_tokens, img_tokens)
    return outputs
@torch.no_grad()
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Fill the masked positions of ``batch`` by sampling and attach the predictions.

    Every position not marked in ``batch["x0_unmask"]`` must already hold ``self.mask_index``.

    Adds ``txt_pred``, ``img_pred``, ``txt_tokens`` and ``img_tokens`` to the batch.
    ``img_tokens`` is shifted by ``self.text_vocab_size`` (so image ids sit after the text
    vocabulary). It is left as None when the sampler produced no image tokens; adding an
    offset to None would raise TypeError.

    Returns:
        The (updated) batch.
    """
    batch = self.update_batch(batch)
    assert (batch["input_ids"][~batch["x0_unmask"]] == self.mask_index).all()
    txt_pred, img_pred, txt_tokens, img_tokens = self.sample(x0=batch["input_ids"], x0_unmask=batch["x0_unmask"], return_input_ids=True)
    if img_tokens is not None:
        img_tokens = img_tokens + self.text_vocab_size
    batch.update(dict(txt_pred=txt_pred, img_pred=img_pred, txt_tokens=txt_tokens, img_tokens=img_tokens))
    return batch
@torch.no_grad()
def zero_shot_eval_step(self, batch, batch_idx):
    """Run one zero-shot evaluation step, dispatching on ``self.config.data.train``.

    * ``"nlphuji/flickr30k"``: samples with ``self._sample`` and scores the decoded captions
      against the ground-truth captions with ``self.compute_cider``.
    * ``"facebook/winoground"``: scores the four caption/image combinations of each example
      in 'image', 'text' and 'group' modes and updates ``win_{text,image,group}_accuracy``.
    * anything else (Datacomp-style retrieval): replaces captions and/or images across the batch
      and checks whether batch index 0 (the untouched pair) gets the lowest score; updates
      ``datacomp_img_acc`` / ``datacomp_txt_acc``.

    All ``get_similarity*`` helpers return per-sample negative-log-likelihood-style scores,
    so LOWER means a better match (hence ``largest=False`` and ``<`` comparisons below).
    """
    batch = self.zero_shot_update_batch(batch)
    dataset_name = self.config.data.train
    def get_similarity(x0, batch, num_timesteps=None, txt_cond=True, return_unweighed=False, do_unconditional=False):
        """Diffusion matching score per sample of ``x0`` (lower = better match).

        Averages, over ``num_timesteps`` noise levels, the NLL of ``x0`` under ``self.forward``
        applied to a noised copy (``self.q_xt``). Unless ``do_unconditional``, the conditioning
        modality (text if ``txt_cond`` else image) is kept clean in the input and zeroed out of
        the score. Each step is weighted by ``dsigma / expm1(sigma)`` unless
        ``return_unweighed``. Pad tokens are zeroed out; the per-sample sum is divided by the
        count of ALL non-pad tokens (conditioning tokens included).

        Returns:
            Tensor of shape (B,).
        """
        # NOTE - this function assume [txt, img] order with self.config.model.txt_length + self.config.model.img_length
        # given a batch of img+text, get the similarity score
        return_unweighed = return_unweighed or getattr(self.config.eval, "return_unweighed_sim", False)
        class_log_probs = []
        unweighed_class_log_probs = []
        num_timesteps = num_timesteps or self.config.sampling.steps
        effective_batch_size = batch['modality'].shape[0]
        empty_device_cache()
        # Default schedule: num_timesteps evenly spaced times strictly inside (0, 1).
        times = torch.linspace(0, 1, steps=num_timesteps + 2)[1:-1].to(self.device).to(torch.float32)
        if getattr(self.config.eval, "use_random_timesteps_same_batch", False):
            # Sorted random times shared by every sample.
            times = torch.rand(num_timesteps, device=x0.device)
            times = torch.sort(times)[0]
        if getattr(self.config.eval, "use_random_timesteps_diff_batch", False):
            # get a (B, num_timesteps) random timesteps
            times = torch.rand(effective_batch_size, num_timesteps, device=x0.device)
            times = torch.sort(times)[0]
        print(f'Times: {times}')
        do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
        # unweighed/weighed, randomized but different over batch, randomized but same over batch,
        # cond_mask is True on the conditioning modality's positions.
        cond_mask = torch.full_like(x0, False, device=x0.device).bool()
        if txt_cond:
            cond_mask[:, :self.config.model.txt_length] = True
        else:
            # img conditioned
            cond_mask[:, self.config.model.txt_length:] = True
        # All-mask sequence; used to blank out the conditioning for the unconditional CFG branch.
        full_mask = torch.full_like(x0, self.mask_index, device=x0.device)
        pad_mask = x0 == self.tokenizer.pad_token_id
        # NOTE(review): with use_random_timesteps_diff_batch, times.shape[0] is the batch size, not the step count.
        rprint(f'Getting similarity with {times.shape[0]} timesteps, {effective_batch_size} samples, {do_unconditional} unconditional, {self.parameterization} parameterization, {self.config.eval.cfg} cfg, {num_timesteps} num_timesteps, {txt_cond} txt_cond')
        # for t in times:
        #     # t = self._sample_t(1, x0.device).expand(effective_batch_size)
        #     breakpoint()
        #     if getattr(self.config.eval, "`use_random_timesteps_diff_batch`", False):
        #         t = t.expand(effective_batch_size)
        #     else:
        #         t = t.expand(1)
        for i in range(num_timesteps):
            empty_device_cache()
            if getattr(self.config.eval, "use_random_timesteps_diff_batch", False):
                t = times[:, i]
            else:
                t = times[i]
            # Broadcast to one time per sample (a no-op when t is already (B,)).
            t = t.expand(effective_batch_size)
            sigma, dsigma = self.noise(t)
            # print(sigma, t)
            unet_conditioning = None # sigma[:, None] -> This causes CUDA OOM
            # Presumably the per-token masking probability consumed by q_xt — confirm against self.noise.
            move_chance = 1 - torch.exp(-sigma[:, None])
            xt, ignore_batch_mask_for_metrics, joint_ar_nar_mask, _, __ = self.q_xt(x0, move_chance, return_ignore_batch_mask_for_metrics=True, batch=batch)
            if not do_unconditional:
                # Restore the clean conditioning tokens into the noised sequence.
                cond = torch.where(cond_mask, x0, xt)
                if self.config.eval.cfg is not None:
                    # Classifier-free guidance: same noised target, conditioning replaced by mask tokens.
                    # NOTE(review): `cfg` is neither imported nor defined in this part of the file — confirm it is defined elsewhere in the module.
                    uncond = torch.where(cond_mask, full_mask, xt)
                    cond_output = self.forward(
                        cond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    uncond_output = self.forward(
                        uncond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    cat_output = torch.stack([cond_output, uncond_output])
                    logits = cfg(self.config, t, cat_output).squeeze(0)
                    model_output = self._subs_parameterization(logits, xt=xt, batch=batch, modality=batch['modality'])
                else:
                    # return logits false so already done with subs parameterization
                    model_output = self.forward(
                        cond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality']
                    )
            else:
                if self.config.eval.cfg is not None:
                    uncond = torch.where(cond_mask, full_mask, xt)
                    cond_output = self.forward(
                        xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    uncond_output = self.forward(
                        uncond, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality'], return_logits=True
                    )
                    cat_output = torch.stack([cond_output, uncond_output])
                    logits = cfg(self.config, t, cat_output).squeeze(0)
                    model_output = self._subs_parameterization(logits, xt=xt, batch=batch, modality=batch['modality'])
                else:
                    # return logits false so already done with subs parameterization
                    model_output = self.forward(
                        xt, unet_conditioning, return_additional_loss=True, batch=batch, x_img_emb=None, joint_ar_nar_mask=joint_ar_nar_mask, modality=batch['modality']
                    )
            # print(f'Time: {t[0]}')
            # Model output at the ground-truth token (presumably log-probabilities after subs parameterization).
            log_p_theta = torch.gather(input=model_output, dim=-1, index=x0[:, :, None]).squeeze(-1)
            # print(f'Log P Theta before pad remove: {-log_p_theta.mean()} | {(log_p_theta == 0).sum()}')
            zeros = torch.zeros_like(log_p_theta)
            log_p_theta = torch.where(pad_mask, zeros, log_p_theta)
            # zero out the loss on conditioned part
            if not do_unconditional:
                log_p_theta = torch.where(cond_mask, zeros, log_p_theta)
            # print(f'Log P Theta after pad remove: {-log_p_theta.mean()} | {(log_p_theta == 0).sum()}')
            # Per-sample time weighting dsigma / (exp(sigma) - 1), shape (B, 1).
            std_weighting = (dsigma / torch.expm1(sigma))[:, None]
            unweighed_log_p_theta = -log_p_theta
            loss = -log_p_theta * std_weighting
            log_probs = loss.sum(dim=-1) / (~pad_mask).sum(dim=-1)
            unweighed_log_probs = unweighed_log_p_theta.sum(dim=-1) / (~pad_mask).sum(dim=-1)
            # print(f'Weighed loss: {log_probs.mean()} | Log P Theta: {-log_p_theta.mean()} | Std Weighting: {std_weighting.mean()}')
            class_log_probs.append(log_probs)
            unweighed_class_log_probs.append(unweighed_log_probs)
        overall_time_log_probs = torch.stack(class_log_probs) # (num_time, B)
        unweighed_overall_time_log_probs = torch.stack(unweighed_class_log_probs) # (num_time, B)
        if return_unweighed:
            return unweighed_overall_time_log_probs.mean(dim=0) # (B)
        return overall_time_log_probs.mean(dim=0) # (B)
    def get_similarity_ar(x0, batch, txt_cond=True, do_unconditional=False, **kwargs):
        """Autoregressive matching score per sample of ``x0`` (lower = better match).

        One forward pass; per-token NLL averaged over non-pad targets. With ``img_first=True``
        (kwarg), the [txt, img] layout of ``x0`` and ``batch['modality']`` is rotated to
        [img, txt] first. Returns the whole-sequence NLL if ``do_unconditional``, otherwise the
        image-part NLL when ``txt_cond`` and the text-part NLL when not.
        """
        # get likelihood for each token and then average
        img_first = kwargs.get("img_first", False)
        if img_first:
            x0 = torch.cat([x0[:, self.config.model.txt_length:], x0[:, :self.config.model.txt_length]], dim=1)
            mod = batch['modality']
            mod = torch.cat([mod[:, self.config.model.txt_length:], mod[:, :self.config.model.txt_length]], dim=1)
        else:
            mod = batch['modality']
        empty_device_cache()
        do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
        if getattr(self.config.eval, "cfg", None):
            rprint('NOT SETTING CFG for AR')
        # if getattr(self.config.eval, "cfg", None):
        #     cat_mod_input_ids = torch.cat([x0, torch.where(batch['modality'] == 1, self.mask_index, x0)], dim=0)
        #     _modality = torch.cat([batch['modality'], batch['modality']], dim=0)
        #     cat_p_x0 = self.forward(
        #         cat_mod_input_ids,
        #         sigma=None,
        #         batch=dict(modality=_modality), modality=_modality
        #     )
        #     logit_c, logit_u = cat_p_x0.chunk(2, dim=0)
        #     _w = getattr(self.config.eval, "cfg", None)
        #     model_output = (1 + _w) * logit_c - _w * logit_u
        # else:
        model_output = self.forward(x=x0, sigma=None, modality=mod)
        # Targets are shifted by one: output position j is scored against token j + 1.
        x0 = x0[:, 1:]
        # attention_mask = batch['attention_mask'][0][None, :].repeat(x0.shape[0], 1)[:, 1:]
        attention_mask = x0 != self.tokenizer.pad_token_id
        log_p_theta = model_output.gather(-1, x0[:, :, None])[:, :, 0]
        # Slices over the shifted targets (hence the "- 1").
        if img_first:
            txt_sl = slice(self.config.model.img_length-1, None)
            img_sl = slice(None, self.config.model.img_length-1)
        else:
            txt_sl = slice(None, self.config.model.txt_length - 1)
            img_sl = slice(self.config.model.txt_length - 1, None)
        nll = (-log_p_theta * attention_mask).sum(dim=-1) / attention_mask.sum(dim=-1)
        txt_nll = (-log_p_theta[:, txt_sl] * attention_mask[:, txt_sl]).sum(dim=-1) / attention_mask[:, txt_sl].sum(dim=-1)
        img_nll = (-log_p_theta[:, img_sl] * attention_mask[:, img_sl]).sum(dim=-1) / attention_mask[:, img_sl].sum(dim=-1)
        if do_unconditional:
            return nll
        return img_nll if txt_cond else txt_nll
    def get_similarity_chameleon(zipp, batch, txt_cond=True, do_unconditional=False, prompts=None, images=None, **kwargs):
        """Chameleon/Anole matching score (lower = better match).

        Scores either the single ``(prompt, image)`` pair in ``zipp`` (the image placed before or
        after the prompt according to ``img_first``) or, when ``prompts``/``images`` are given,
        those as-is. Returns the mean token NLL of the whole sequence if ``do_unconditional``;
        otherwise the mean of the ``mod_mask``-selected (``txt_cond``) or unselected NLLs.

        NOTE(review): for ``img_first`` the mask starts at the image-END marker, so it selects the
        text after the image rather than the image. The masked means also average over all
        positions (unselected ones count as 0). Every call site in this function passes or forces
        ``do_unconditional=True``, so those branches are currently unused.
        """
        # get likelihood for each token and then average
        empty_device_cache()
        img_first = kwargs.get("img_first", False)
        # Index 1 presumably skips a BOS token prepended by the tokenizer.
        img_start_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_start_token)['input_ids'][1]
        img_end_tok_id = self.chameleon_processor.tokenizer(self.chameleon_processor.image_end_token)['input_ids'][1]
        do_unconditional = do_unconditional or getattr(self.config.eval, "do_unconditional", False)
        if not prompts and not images:
            prompt, image = zipp
            if img_first:
                _prompt = f"<image> {prompt}"
            else:
                _prompt = f"{prompt} <image>"
            inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor, _prompt, image, dtype=self.dtype, device=self.device)
        else:
            inputs = self.get_anole_data(self.chameleon_model, self.chameleon_processor, prompts, images, dtype=self.dtype, device=self.device)
        # mod mask which is one for image tokens from the indx we see img_start_tok_id to img_end_tok_id
        if img_first:
            mod_mask = torch.cumsum(inputs['input_ids'] == img_end_tok_id, dim=1).bool()
        else:
            mod_mask = torch.cumsum(inputs['input_ids'] == img_start_tok_id, dim=1).bool()
        mod_mask = mod_mask.cumsum(dim=1) > 1
        output = self.chameleon_model(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            labels=inputs['input_ids'].to(self.device)
        )
        loss = output.loss
        logits = output.logits
        logits = logits.transpose(-1, -2)
        sample_chunk = inputs["input_ids"]
        # Per-token NLL: logits at position j are scored against token j + 1.
        nlls = F.cross_entropy(logits[..., :-1].to(self.device), sample_chunk[..., 1:].to(self.device), reduction="none")
        mod_mask = mod_mask[:, 1:]
        # img nll is where mod_mask == 1
        zeros = torch.zeros_like(nlls)
        img_nll = torch.where(mod_mask, nlls, zeros)
        txt_nll = torch.where(~mod_mask, nlls, zeros)
        if do_unconditional:
            return nlls.mean(dim=-1)
        return img_nll.mean(dim=-1) if txt_cond else txt_nll.mean(dim=-1)
    if dataset_name == "nlphuji/flickr30k":
        # Sample text and image tokens, then score the decoded captions against the ground truth with CIDEr.
        txt_tokens, img_tokens = self._sample(
            text_only=False,
            x0=batch["input_ids"],
            x0_unmask=batch["attention_mask"],
            modality=batch["modality"],
        )
        # NOTE(review): img_samples is not used below; the text slice starts at img_length,
        # presumably because the sequence is laid out [img, txt] here — confirm.
        img_samples = decode_latents(self.config, self.get_vae(), img_tokens[:, :self.config.model.img_length])
        txt_samples = wrapped_batch_decode(self.tokenizer, txt_tokens[:, self.config.model.img_length:], clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        gt_text_samples = wrapped_batch_decode(self.tokenizer, batch['gt_input_ids'][:, :self.config.model.txt_length], skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        self.compute_cider(txt_samples, gt_text_samples)
    elif dataset_name == "facebook/winoground":
        # breakpoint()
        # if batch_idx <= 15:
        #     return
        # a<i>_<j>: presumably caption i paired with image j (matches the s<i>_i<j> keys below).
        a0_0 = batch["input_ids_0_0"] # a
        a0_1 = batch["input_ids_0_1"] # d
        a1_0 = batch["input_ids_1_0"] # b
        a1_1 = batch["input_ids_1_1"] # c
        text_correct_count = 0
        image_correct_count = 0
        group_correct_count = 0
        wino_chameleon = getattr(self.config.eval, "wino_chameleon", False)
        s0_0, s0_1, s1_0, s1_1 = None, None, None, None
        modes = ['image', 'text', 'group']
        if wino_chameleon:
            # Decode both captions and both images; zipp holds (caption i, image j) in the order 00, 01, 10, 11.
            txt0 = wrapped_batch_decode(tokens=batch['caption_0_input_ids'], tokenizer=self.tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)[0]
            txt1 = wrapped_batch_decode(tokens=batch['caption_1_input_ids'], tokenizer=self.tokenizer, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)[0]
            img0 = Im(batch['img_0']).pil
            img1 = Im(batch['img_1']).pil
            prompts = [txt0, txt0, txt1, txt1]
            images = [img0, img1, img0, img1]
            zipp = list(zip(prompts, images))
        # note - signs are reversed since we have loss, so want to minimize instead of maximize
        def text_correct(result):
            # For each image, its own caption scores lower than the other caption.
            return torch.logical_and(result["s0_i0"] < result["s1_i0"], result["s1_i1"] < result["s0_i1"])
        def image_correct(result):
            # For each caption, its own image scores lower than the other image.
            return torch.logical_and(result["s0_i0"] < result["s0_i1"], result["s1_i1"] < result["s1_i0"])
        def group_correct(result):
            return torch.logical_and(image_correct(result), text_correct(result))
        results_cond = {}
        for mode in modes:
            # 'image': condition on text; 'text': condition on image (image placed first for AR/Chameleon);
            # 'group': unconditional score.
            do_unconditional = (mode == 'group')
            txt_cond = not (mode == 'text')
            img_first = mode == 'text'
            if wino_chameleon:
                do_unconditional = True
                s0_0 = get_similarity_chameleon(zipp[0], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s0_1 = get_similarity_chameleon(zipp[1], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_0 = get_similarity_chameleon(zipp[2], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_1 = get_similarity_chameleon(zipp[3], batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
            elif self.parameterization == "ar":
                # NOTE(review): txt_cond=False is hard-coded. In 'image' mode (img_first=False) this scores the
                # leading text tokens, which a causal model computes before seeing the image, so s0_i0 and
                # s0_i1 would coincide unless eval.do_unconditional is set — confirm intended.
                s0_0 = get_similarity_ar(a0_0, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s0_1 = get_similarity_ar(a0_1, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_0 = get_similarity_ar(a1_0, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
                s1_1 = get_similarity_ar(a1_1, batch, txt_cond=False, do_unconditional=do_unconditional, img_first=img_first)
            else:
                s0_0 = get_similarity(a0_0, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
                s0_1 = get_similarity(a0_1, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
                s1_0 = get_similarity(a1_0, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
                s1_1 = get_similarity(a1_1, batch, txt_cond=txt_cond, do_unconditional=do_unconditional)
            result = {
                "s0_i0": s0_0,
                "s0_i1": s0_1,
                "s1_i0": s1_0,
                "s1_i1": s1_1,
            }
            if mode == 'text':
                results_cond['text'] = text_correct(result)
                text_correct_count += text_correct(result).sum().item()
            elif mode == 'image':
                results_cond['image'] = image_correct(result)
                image_correct_count += image_correct(result).sum().item()
            elif mode == 'group':
                if getattr(self.config.eval, "wino_group_conditional", False):
                    # Group accuracy from the conditional text/image results gathered in the earlier modes.
                    rprint('[Winoground] Using conditional group accuracy')
                    group_correct_count = (torch.logical_and(results_cond['text'], results_cond['image'])).sum().item()
                else:
                    rprint('[Winoground] Using unconditional group accuracy')
                    group_correct_count += group_correct(result).sum().item()
        # Per-batch accuracies; the metric objects keep the running averages.
        bsz = a0_0.shape[0]
        txt_acc = text_correct_count / bsz
        img_acc = image_correct_count / bsz
        group_acc = group_correct_count / bsz
        self.win_text_accuracy.update(txt_acc)
        self.win_image_accuracy.update(img_acc)
        self.win_group_accuracy.update(group_acc)
        running_avg_txt = self.win_text_accuracy.compute()
        running_avg_img = self.win_image_accuracy.compute()
        running_avg_group = self.win_group_accuracy.compute()
        rprint(f"[{batch_idx}] Winoground Text Accuracy: {txt_acc} ({running_avg_txt}), Image Accuracy: {img_acc} ({running_avg_img}), Group Accuracy: {group_acc} ({running_avg_group})")
    else:
        # def randomize_batch - input is a batch. for the batch['input_ids'] which contains self.config.model.txt_length txt tokens + self.config.model.img_length img tokens which are PAIRED
        # we want to randomly swap the img/txt tokens between each other
        x0 = batch['input_ids']
        img_first = getattr(self.config.model, "img_first", False)
        only_one_correct = getattr(self.config.eval, "only_one_correct", False)
        wino_chameleon = getattr(self.config.eval, "wino_chameleon", False)
        # todo check attn mask for text retrieval
        x0_txt = x0.clone()
        x0_img = x0.clone()
        if only_one_correct:
            # Keep row 0 as the only matched pair: the second segment of rows 1..B-1 is rotated by
            # one row, so every other row gets a mismatched segment.
            # for each sample from 1st batch onwards, shuffle the img/txt tokens, as in map randomly
            x0c = x0.clone()
            if img_first:
                second_half = x0c[1:, self.config.model.img_length:]
            else:
                second_half = x0c[1:, self.config.model.txt_length:]
            # shuffle second half
            # second_half = second_half[torch.randperm(second_half.size(0))]
            second_half = torch.cat([second_half[1:], second_half[0].unsqueeze(0)], dim=0)
            # replace img tokens with txt tokens
            if img_first:
                x0c[1:, self.config.model.img_length:] = second_half
            else:
                x0c[1:, self.config.model.txt_length:] = second_half
            if wino_chameleon:
                if img_first:
                    img_tokens = x0c[:, :self.config.model.img_length]
                    txt_tokens = x0c[:, self.config.model.img_length:]
                else:
                    txt_tokens = x0c[:, :self.config.model.txt_length]
                    img_tokens = x0c[:, self.config.model.txt_length:]
                dec_txt = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
                dec_imgs = decode_latents(self.config, self.get_vae(), img_tokens - self.text_vocab_size)
                dec_imgs = [Im(img).pil for img in dec_imgs]
                if img_first:
                    # append '<image>' to beginning of each txt sample
                    dec_txt = ['<image> ' + txt for txt in dec_txt]
                else:
                    dec_txt = [txt + ' <image>' for txt in dec_txt]
                class_sim = get_similarity_chameleon(None, batch, do_unconditional=True, img_first=img_first, prompts=dec_txt, images=dec_imgs)
                if torch.isinf(class_sim).any():
                    # NOTE(review): breakpoint() blocks a non-interactive / distributed run; consider raising instead.
                    rprint(f'[Chameleon] Inf found in class_sim, check transformers version')
                    breakpoint()
            elif self.parameterization == "ar":
                class_sim = get_similarity_ar(x0c, batch, do_unconditional=True)
            else:
                class_sim = get_similarity(x0c, batch, do_unconditional=True)
            # Correct iff row 0 (the matched pair) has the lowest score.
            topk = class_sim.topk(k=1, dim=0, largest=False)
            topk_indices = topk.indices
            topk_acc = (topk_indices == 0).float().mean().item()
            rprint(f"[{batch_idx}] Datacomp Correct Pair Retrieval Acc: {topk_acc} ({self.datacomp_img_acc.compute()})")
            self.datacomp_img_acc.update(topk_acc)
        else:
            if img_first:
                # image retrieval given text, so fix text
                x0_txt[:, self.config.model.img_length:] = x0[0, self.config.model.img_length:] # make all texts the first text
                # text retrieval given image
                x0_img[:, :self.config.model.img_length] = x0[0, :self.config.model.img_length] # make all images the first image
            else:
                # image retrieval given text, so fix text
                x0_txt[:, :self.config.model.txt_length] = x0[0, :self.config.model.txt_length] # make all texts the first text
                # text retrieval given image
                x0_img[:, self.config.model.txt_length:] = x0[0, self.config.model.txt_length:] # make all images the first image
            if self.parameterization == "ar":
                txt_class_sim = get_similarity_ar(x0_txt, batch, txt_cond=True)
                img_class_sim = get_similarity_ar(x0_img, batch, txt_cond=True) # TODO MAYBE REVERT?
            else:
                txt_class_sim = get_similarity(x0_txt, batch, txt_cond=True)
                img_class_sim = get_similarity(x0_img, batch, txt_cond=False)
            # Correct iff row 0 (the original pair) has the lowest score. img_acc comes from the
            # fixed-image batch (text retrieval); txt_acc from the fixed-text batch (image retrieval).
            img_topk = img_class_sim.topk(k=1, dim=0, largest=False)
            txt_topk = txt_class_sim.topk(k=1, dim=0, largest=False)
            img_topk_indices = img_topk.indices
            txt_topk_indices = txt_topk.indices
            img_acc = (img_topk_indices == 0).float().mean().item()
            txt_acc = (txt_topk_indices == 0).float().mean().item()
            rprint(f"[{batch_idx}] Datacomp Text Retrieval Acc: {img_acc}, Datacomp Image Retrieval Accuracy: {txt_acc}")
            self.datacomp_img_acc.update(img_acc)
            self.datacomp_txt_acc.update(txt_acc)
            # img_class_sim is (B) - argmin since loss txt_conds
@torch.no_grad()
def validation_step(self, batch, batch_idx):
    """Run one validation step and return the validation loss.

    Depending on ``self.config`` this also:
    - logs GPU memory statistics (``config.mode == "eval"``);
    - visualizes samples;
    - computes generative perplexity of the ground-truth text;
    - logs FLOPs;
    - updates the FID and img-to-txt MAUVE/CLIP accumulators;
    - collects AR top-k statistics;
    - samples from the model in continuous mode;
    - visualizes masking.

    Args:
        batch: Raw validation batch; normalized through ``self.update_batch``.
        batch_idx: Index of this batch within the validation epoch.

    Returns:
        ``self.compute_loss(batch, prefix="val", batch_idx=batch_idx)``, or ``None``
        when ``eval.visualize_data_only`` short-circuits the step.

    Raises:
        ValueError: if ``fid_eval`` is enabled with an unknown ``eval.fid_mode``.
    """
    batch = self.update_batch(batch)
    continuous_mode = self.config.trainer.image_mode == "continuous"

    if self.config.mode == "eval":
        logs = {}
        logs["gpu_max_mem_reserved_gb"] = torch.cuda.max_memory_reserved() / (1024**3)
        logs["gpu_cur_mem_reserved_gb"] = torch.cuda.memory_reserved() / (1024**3)
        logs["gpu_max_mem_allocated_gb"] = torch.cuda.max_memory_allocated() / (1024**3)
        logs["gpu_cur_mem_allocated_gb"] = torch.cuda.memory_allocated() / (1024**3)
        log({**logs, **self.get_step_metrics()})

    if self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10)) \
        and self.image_model \
        and (batch_idx == 0 or self.config.eval.visualize_data_only) \
        and not continuous_mode:
        self.visualize_samples(batch, batch_idx)
        # NOTE(review): nesting reconstructed from an unindented copy; confirm this
        # early return belongs inside the visualization branch.
        if self.config.eval.visualize_data_only:
            return

    if batch_idx < self.config.eval.num_sample_batches and self.config.eval.compute_generative_perplexity:
        if continuous_mode:
            # todo update to use modality once multimodal batches update is done by alex
            # In continuous mode input_ids is for images, so decode text_tokens instead.
            gt_text_samples = wrapped_batch_decode(self.tokenizer, batch['text_tokens'][:, :self.config.model.txt_length], skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        else:
            # Keep only positions with modality == 0; every other position becomes a pad token.
            input_ids = batch["input_ids"]
            pad_tokens = torch.full_like(input_ids, self.tokenizer.pad_token_id)
            text_tokens = torch.where(batch["modality"] == 0, input_ids, pad_tokens)
            gt_text_samples = wrapped_batch_decode(self.tokenizer, text_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        if getattr(self.config.trainer, "disable_text_modality", False):
            gt_text_samples = [' ']
        self.compute_generative_perplexity(gt_text_samples, gt=True)

    if getattr(self.config.trainer, "log_flops", False) \
        and batch_idx == 0 \
        and self.current_run_global_step <= 1 \
        and self.config.trainer.fsdp is False:
        self.log_flops(batch=batch, batch_idx=batch_idx)

    if self.fid_eval:
        if self.config.eval.fid_mode == "inline":
            self.update_inline_fid(batch, batch_idx)
        elif self.config.eval.fid_mode == "clean":
            self.update_clean_fid(batch, batch_idx)
        else:
            raise ValueError(f"Invalid FID mode: {self.config.eval.fid_mode}")

    if getattr(self.config.eval, "get_top_k", False) and self.config.parameterization == "ar":
        self.get_top_k(batch, batch_idx)

    try:
        if self.config.eval.compute_img_to_txt_mauve_clip and not self.config.eval.unconditional_fid:
            self.update_img_to_txt_mauve_clip(batch, batch_idx)
    except Exception as e:
        # Deliberately non-fatal: free cached device memory, report, and keep validating.
        empty_device_cache()
        rprint(f"Error in update_img_to_txt_mauve_clip: {e}")

    if (self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
        and continuous_mode
        and self.config.eval.generate_samples
        and not self.config.eval.test_eval_speed):
        # todo remove this from here and move to on_validation_epoch_end
        data = self.sample_transfusion(batch_size_per_gpu=batch['input_ids'].shape[0])
        # TODO @sid support batching. prob pass list of lists to be general.
        rec_embs = [data.xt_img_embed[i, data.modality[i] == 1] for i in range(data.shape[0])]
        # Stack the per-sample image embeddings (modality == 1) into one tensor.
        rec_embs = torch.stack(rec_embs)
        rec_txt = data.xt_ids[data.modality == 0][None]
        recon_image = decode_latents(self.config, self.get_vae(), rec_embs, batched=True)  # TODO @sid support batching e.g. not just first element. prob pass list of lists to be general.
        txt = wrapped_batch_decode(self.tokenizer, rec_txt, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        # Preview the first sample truncated to 50 chars (`txt[:1][:50]` sliced the
        # list twice and never truncated the string).
        rprint(f"Sampled {len(txt)} text samples:\n {[t[:50] for t in txt[:1]]}")
        image_list = [wandb.Image(img) for img in recon_image]
        # NOTE(review): the loss is computed again by the return statement below;
        # if compute_loss updates metrics, this batch may be counted twice — confirm.
        val_loss = self.compute_loss(batch, prefix="val")
        log({"val/gen_img": image_list, "val/loss": val_loss, **self.get_step_metrics()})

    if (
        self.get_every_n_evals(getattr(self.config.eval, "log_every_n_evals", 10))
        and (self.unified_model or self.cub_model or self.vggface_model)
        and batch_idx < getattr(self.config.eval, "num_masking_viz_batches", 1)
        and not continuous_mode  # todo add masking val support s
    ):
        self.sample_masking(batch=batch, batch_idx=batch_idx)

    return self.compute_loss(batch, prefix="val", batch_idx=batch_idx)
@try_except(write_error_to_file=True)
@torch.no_grad()
def zero_shot_eval_epoch_end(self, example_batch=None):
    """Compute, print and log the accumulated zero-shot metrics.

    The metric set is chosen by ``self.config.data.train``:
      * ``"nlphuji/flickr30k"``: CIDEr score.
      * ``"facebook/winoground"``: text / image / group accuracy.
      * anything else: DataComp retrieval accuracies.

    ``example_batch`` is accepted for interface compatibility and is not used.
    """
    train_dataset = self.config.data.train
    dprint("zero_shot_eval_epoch_end")

    if train_dataset == "nlphuji/flickr30k":
        cider = self.cider_score.compute()
        rprint('Flickr30k CIDEr score: ', cider)
        log({'val/cider_score': cider})
        return

    if train_dataset == "facebook/winoground":
        # (printed label, logged key, metric attribute). All metrics are computed
        # first, in this order, then printed, then logged in a single call.
        winoground_spec = (
            ('Winoground Text Accuracy', 'val/win_text_accuracy', 'win_text_accuracy'),
            ('Winoground Image Accuracy', 'val/win_image_accuracy', 'win_image_accuracy'),
            ('Winoground Group Accuracy', 'val/win_group_accuracy', 'win_group_accuracy'),
        )
        scored = [
            (label, key, getattr(self, attr).compute())
            for label, key, attr in winoground_spec
        ]
        for label, _, value in scored:
            rprint(f'{label}: {value}')
        log({key: value for _, key, value in scored})
        return

    # DataComp. In the image/text retrieval path of the DataComp eval step,
    # ``datacomp_img_acc`` is updated with the text-retrieval accuracy (images held
    # fixed) and ``datacomp_txt_acc`` with the image-retrieval accuracy. That is
    # why the names look swapped below.
    text_retrieval_acc = self.datacomp_img_acc.compute()
    image_retrieval_acc = self.datacomp_txt_acc.compute()
    rprint(f'Datacomp Text Accuracy: {text_retrieval_acc}')
    rprint(f'Datacomp Image Accuracy: {image_retrieval_acc}')
    log({
        'val/datacomp_text_retr_acc': text_retrieval_acc,
        'val/datacomp_img_retr_acc': image_retrieval_acc,
    })
@try_except(write_error_to_file=True)
@torch.no_grad()
def get_img_text_saturation_batch(self, example_batch):
    """Sample at several step counts and score each result with Chameleon perplexity.

    For each step count, this:
    1. jointly samples text and image tokens with ``self._sample(num_steps=step)``;
    2. decodes both;
    3. scores the decoded pair with ``self.calculate_chameleon_perplexity``.

    Step counts come from ``config.eval.img_text_saturation_steps`` when set.
    Otherwise they default to ``[1, 2, 4, 8, 16, 32, 64]``.

    Args:
        example_batch: Batch whose ``input_ids`` fixes the per-GPU batch size and
            whose ``modality`` is passed as ``sample_modality``.

    Returns:
        ``(dec_txt_list, dec_img_list, pplx_per_step)``. The i-th entry of each
        corresponds to the i-th step count. Each ``pplx_per_step`` entry is
        ``(step, tot_ppl, tot_loss, img_loss, txt_loss)``.
    """
    batch_size_per_gpu = example_batch["input_ids"].shape[0]
    # Only reported in the log line below.
    do_standalone = getattr(self.config.eval, "cham_standalone", False)
    # Other sweeps used previously: [1, 2, 4, ..., 1024] or
    # np.linspace(1, config.model.length, 10). Override via config instead of editing code.
    configured_steps = getattr(self.config.eval, "img_text_saturation_steps", None)
    steps = [int(s) for s in (configured_steps or [1, 2, 4, 8, 16, 32, 64])]
    rprint(f"do_standalone: {do_standalone} with steps: {steps}")

    dec_txt_list = []
    dec_img_list = []
    pplx_per_step = []
    for step in steps:
        rprint(f"Step: {step}")
        # The NFE count is returned because return_nfe=True but is not needed here.
        (txt_tokens, img_tokens), _nfe_cnt = self._sample(text_only=False, batch_size_per_gpu=batch_size_per_gpu, sample_modality=example_batch["modality"], return_nfe=True, num_steps=step)
        decoded_img = Im(decode_latents(self.config, self.get_vae(), img_tokens)).pil
        decoded_txt = wrapped_batch_decode(self.tokenizer, txt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True, disable_mask_after_eos=self.config.data.disable_mask_after_eos)
        # Normalize single results to lists so every step contributes a list.
        if not isinstance(decoded_img, list):
            decoded_img = [decoded_img]
        if not isinstance(decoded_txt, list):
            decoded_txt = [decoded_txt]
        dec_txt_list.append(decoded_txt)
        dec_img_list.append(decoded_img)
        tot_ppl, tot_loss, img_loss, txt_loss = self.calculate_chameleon_perplexity(self.chameleon_model, self.chameleon_processor, prompts=decoded_txt, images=decoded_img, return_all=True)[0]
        rprint(f"Step {step} - Total PPL: {tot_ppl} | Total Loss: {tot_loss} | Img Loss: {img_loss} | Txt Loss: {txt_loss}")
        pplx_per_step.append((step, tot_ppl, tot_loss, img_loss, txt_loss))
        # Release cached device memory between step counts.
        empty_device_cache()
    return dec_txt_list, dec_img_list, pplx_per_step
@torch.no_grad()
@try_except(write_error_to_file=True)
@torch.no_grad()
def on_validation_epoch_end(self, example_batch=None):
dprint("on_validation_epoch_end")
if self.config.eval.compute_val_metrics_standalone:
self.compute_val_metrics_standalone()
all_val_metrics = self.get_step_metrics()
all_val_metrics.update(self.valid_metrics.compute())
if hasattr(self, "valid_txt_metrics"):
valid_txt_metrics = self.valid_txt_metrics.compute()
valid_img_metrics = self.valid_img_metrics.compute()
all_val_metrics.update({
**{f"val/txt_{k.split('/')[-1]}": v for k, v in replace_nan_dict(valid_txt_metrics).items()},
**{f"val/img_{k.split('/')[-1]}": v for k, v in replace_nan_dict(valid_img_metrics).items()},
})
log(all_val_metrics)
gprint("example_batch['input_ids'].ndim: ", example_batch['input_ids'].ndim)
if example_batch['input_ids'].ndim == 3:
combined_batches = example_batch
example_batch = self.update_batch(example_batch[0])
else:
example_batch = self.update_batch(example_batch)
if self.config.eval.auto_enhance:
self.auto_enhance(combined_batches)
continuous_mode = self.config.trainer.image_mode == "continuous"
compute_chameleon_perplexity = getattr(self.config.eval, "compute_chameleon_perplexity", False)
all_images = []
with try_except(write_error_to_file=True, clear_cuda_cache=True):
if self.fid_eval:
if self.config.eval.fid_mode == "inline":
self.compute_inline_fid_eval()
elif self.config.eval.fid_mode == "clean":
self.compute_clean_fid_eval()
else:
raise ValueError(f"Invalid FID mode: {self.config.eval.fid_mode}")
if self.config.eval.calculate_clip_score:
prefix = "unconditional" if self.config.eval.unconditional_fid else "fid"
self.compute_clip_score(self.fid_gen_dir, f"{prefix}_gen")
self.compute_clip_score(self.fid_gt_dir, f"{prefix}_gt")
if self.config.trainer.ar_inpainting:
import shutil
target_dir = Path(self.fid_gt_dir).parent / "fid_inpainting"
target_dir.mkdir(parents=True, exist_ok=True)
for img_file in Path(self.fid_gt_dir).rglob("*.png"):
shutil.copy2(img_file, target_dir / img_file.name)
for json_file in Path(self.fid_gen_dir).rglob("*.json"):
shutil.copy2(json_file, target_dir / json_file.name)
self.compute_clip_score(target_dir, f"{prefix}_inpainting")
if self.config.eval.unconditional_fid and \
self.config.eval.compute_img_to_txt_mauve_during_unconditional_fid and self.config.eval.compute_img_to_txt_mauve_clip:
rprint("Computing img to txt mauve during unconditional fid")
# CLIP score is the same as the fid clip score so we don't need to compute it again
gen_txt_tokens = self.gather_tokens(self.saved_tokens["unconditional_gen_txt_tokens"])
gt_txt_tokens = self.gather_tokens(self.saved_tokens["unconditional_gt_txt_tokens"])
if not getattr(self.config.eval, "global_disable_mauve", False):
self.compute_mauve_entropy(self.fid_gen_dir, self.fid_gt_dir, gen_txt_tokens, gt_txt_tokens, "unconditional")
elif self.config.eval.compute_img_to_txt_mauve_clip:
gen_txt_tokens = self.gather_tokens(self.saved_tokens["img_to_txt_gen_txt_tokens"])
gt_txt_tokens = self.gather_tokens(self.saved_tokens["img_to_txt_gt_txt_tokens"])
if not getattr(self.config.eval, "global_disable_mauve", False):
self.compute_mauve_entropy(self.img_to_txt_mauve_gen_dir, self.img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, "img_to_txt")
if self.config.eval.calculate_clip_score:
self.compute_clip_score(self.img_to_txt_mauve_gen_dir, "img_to_txt_mauve_gen")
self.compute_clip_score(self.img_to_txt_mauve_gt_dir, "img_to_txt_mauve_gt")
self.compute_mauve_entropy(self.img_to_txt_mauve_gen_dir, self.img_to_txt_mauve_gt_dir, gen_txt_tokens, gt_txt_tokens, "img_to_txt")
should_eval_speed = getattr(self.config.eval, "test_eval_speed", False)
if self.config.eval.generate_samples:
with try_except(write_error_to_file=True):
empty_device_cache()
if getattr(self.config.eval, 'set_random_gen_seed', False):
new_seed = get_rank() * 10 + 32
torch.manual_seed(new_seed)
torch.cuda.manual_seed(new_seed)
random.seed(new_seed)
np.random.seed(new_seed)
tot_time_per_sample = []
tot_token_time_per_token = []
tot_nfe_cnt = 0
batch_size_per_gpu = self.config.loader.eval_batch_size
sampling_steps = self.config.sampling.steps
num_batches = self.config.eval.num_sample_batches
gen_ppl_max_batches = 1e8
compute_entropy = getattr(self.config.eval, "compute_entropy", False)
compute_gen_ppl = self.config.eval.compute_generative_perplexity
entropies = []
if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
mauve_N = self.config.eval.mauve_num_samples
# we need to generate this many samples distributed over the batch size * num_gpus
# if not clean division, generate one extra batch we can discard later
num_batches = math.ceil(mauve_N / (batch_size_per_gpu * get_num_gpus()))
should_eval_speed = True # if we are generating this many samples might as well time it
gen_ppl_max_batches = getattr(self.config.eval, "gen_ppl_max_batches", 1e8) # since we are generating a lot of samples, we can compute gen ppl for a few batches but not all since that'll be slow with eval_mode = llama
compute_entropy = True
compute_gen_ppl = True
rprint(f"[MAUVE] Generating {mauve_N} samples with batch size {batch_size_per_gpu}, sampling steps {sampling_steps}, total length {self.config.model.length}, num_batches: {num_batches}, max_gen_ppl_batches: {gen_ppl_max_batches}")
rprint(f"Generating {num_batches} samples with batch size {batch_size_per_gpu}, sampling steps {sampling_steps}, total length {self.config.model.length}, compute_entropy: {compute_entropy}, compute_gen_ppl: {compute_gen_ppl}")
all_samples = []
get_img_text_saturation = getattr(self.config.eval, "get_img_text_saturation", False)
for i in tqdm(range(num_batches), desc="Generating samples"):
if get_img_text_saturation:
dec_txt_list, dec_img_list, all_vals = self.get_img_text_saturation_batch(example_batch)
# Prepare data for logging
df = pd.DataFrame(all_vals, columns=["step", "tot_ppl", "tot_loss", "img_loss", "txt_loss"])
df.to_csv(Path(self.config.output_dir) / f"img_text_saturation_batch_{i}.csv", index=False)
rprint(f"Saved img_text_saturation_batch_{i}.csv to {Path(self.config.output_dir) / f'img_text_saturation_batch_{i}.csv'}")
log_data = []
for (step, tot_ppl, tot_loss, img_loss, txt_loss), dec_txt, dec_img in zip(all_vals, dec_txt_list, dec_img_list):
concatenated_text = ' | '.join(dec_txt)
concatenated_image = dec_img[0]
log_data.append([step, tot_ppl, tot_loss, img_loss, txt_loss, concatenated_text, wandb.Image(concatenated_image)])
# Log to wandb
log_table = wandb.Table(columns=["Step", "Total PPL", "Total Loss", "Image Loss", "Text Loss", "Generated Text", "Generated Image"], data=log_data)
wandb.log({"img_text_saturation": log_table, "trainer/global_step": self.global_step})
rprint("Logged img_text_saturation table to wandb")
# log (step, Im)
# make it into pd df and store in output_dir
break
if should_eval_speed:
start_time = start_timing(sync=True, enable=True, message="Evaluating inference speed")
if self.parameterization == "ar" and continuous_mode:
data = self.sample_transfusion(text_only=True, batch_size_per_gpu=batch_size_per_gpu)
txt_tokens = data.xt_ids[:, self.static_txt_sl]
else:
(txt_tokens, img_tokens), nfe_cnt = self._sample(
text_only=False,
batch_size_per_gpu=batch_size_per_gpu,
sample_modality=example_batch["modality"],
return_nfe=True,
)
tot_nfe_cnt += nfe_cnt
if should_eval_speed:
tot_time = end_timing(start_time, enable=True, sync=True)
if continuous_mode: assert (data.modality == 0).all()
tot_time_per_sample.append(tot_time)
tot_token_time_per_token.append((tot_time) / self.config.model.length)
if compute_entropy:
entropies.append(self.compute_entropy(txt_tokens).item())
if compute_chameleon_perplexity:
all_images.extend(Im(decode_latents(self.config, self.get_vae(), img_tokens)).pil)
text_samples = wrapped_batch_decode(self.tokenizer, txt_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
if self.config.eval.compute_standalone_mauve and not getattr(self.config.eval, "global_disable_mauve", False):
self.mauve_predictions.extend(text_samples)
if len(text_samples) > 0 and len(text_samples[0]) > 0 and self.config.eval.compute_generative_perplexity and i <= gen_ppl_max_batches:
self.compute_generative_perplexity(text_samples)
rprint(f"Generated {len(text_samples)} samples - {[text_samples[i][:200] for i in range(min(len(text_samples), 5))]}")
all_samples.extend(text_samples)
# TODO: @ssg2 is this needed?
# Log the last generated samples
# if not compute_chameleon_perplexity:
# text_samples = all_samples[:self.config.sampling.num_sample_log]
# all_images = all_images[:self.config.sampling.num_sample_log]
avg_nfe_cnt = tot_nfe_cnt / num_batches
if should_eval_speed:
# TODO: @ssg2 is this needed?
# data_dict = {
# f"samples": wandb.Table(columns=["Generated Samples", "Time per sample", "Time per token", "Generated Images"], data=[[s, t, tt, wandb.Image(img)] for s, t, tt, img in zip(text_samples, tot_time_per_sample, tot_token_time_per_token, all_images )]),
# "trainer/global_step": self.global_step,
# }
data_dict = {
f"samples": wandb.Table(columns=["Generated Samples", "Generated Images"], data=[[s, wandb.Image(img)] for s, img in zip(all_samples[:self.config.sampling.num_sample_log], all_images[:self.config.sampling.num_sample_log])]),
"trainer/global_step": self.global_step,
}
assert len(tot_time_per_sample) == len(tot_token_time_per_token)
if len(tot_time_per_sample) > 1:
tot_time_per_sample = tot_time_per_sample[1:] # exclude warmup
tot_token_time_per_token = tot_token_time_per_token[1:]
print(f'Have {len(tot_time_per_sample)} samples')