---
# Example configuration for Flow Matching training
#
# NOTE(review): leading indentation was lost in transit; the nesting below is
# reconstructed from the key structure (module/params component pattern) —
# verify against the upstream file.
model:
  # Structured component specifications: each component names a Python class
  # ("module") and its constructor arguments ("params").
  vae:
    module: "models.flux_vae.AutoEncoder"
    params:
      in_channels: 3
      z_channels: 16  # latent channels; patchifier.vae_latent_channels must agree
      scale_factor: 0.3611
      shift_factor: 0.1159
  text_encoder:
    module: "models.vanilla_embedder.VanillaEmbedder"
    params:
      vocab_size: 1001  # presumably 1000 ImageNet classes + 1 extra token — confirm
      embedding_dim: 768
      return_datum_lens: true
      # Check the "data" section of this file for the correct paths
      embeddings_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::clip_embeddings"
      txt_to_label_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::txt_to_label"
  clip_encoder:
    module: "models.vanilla_embedder.VanillaEmbedder"
    params:
      vocab_size: 1001
      embedding_dim: 768
      return_datum_lens: false
      # Check the "data" section of this file for the correct paths
      embeddings_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::clip_embeddings"
      txt_to_label_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::txt_to_label"
  patchifier:
    module: "models.patchifier.Patchifier"
    params:
      patch_size: [1, 2, 2]  # [frames, height, width] - DiT typical
      vae_latent_channels: 16  # VAE latent channels; must agree with vae.params.z_channels
      vae_compression_factors: [1, 8, 8]  # VAE compression factors [frames, height, width]; must agree with vae
  denoiser:
    module: "models.flux_denoiser.FluxDenoiser"
    params:
      d_model: 1024
      d_head: 64
      # Full-size block counts kept for reference; this "tiny" variant uses 8/16:
      # n_ds_blocks: 19
      # n_ss_blocks: 38
      n_ds_blocks: 8
      n_ss_blocks: 16
      d_txt: 768  # must match text_encoder embedding_dim
      d_vec: 768  # must match clip_encoder embedding_dim
      # Per-patch image dim: vae_latent_channels * prod(patch_size) = 16 * (1*2*2) = 64.
      # NOTE(review): the previous comment said prod(vae_compression_factors), which
      # would give 16 * 64 = 1024 and contradicts the value below — corrected.
      d_img: 64
      # must have sum equal to d_head (8 + 28 + 28 = 64);
      # must have number of elements equal to len(patch_size) in patchifier
      rope_axis_dim: [8, 28, 28]  # tyx coordinates
      guidance_embed: false
  fsdp:
    meta_device_init: true
    shard_size: 1
    param_dtype: "bf16"
    reduce_dtype: "fp32"
    ac_freq: 0  # activation-checkpointing frequency; presumably 0 disables — confirm
    blocks_attr: ["double_blocks", "single_blocks"]  # denoiser attributes holding the block lists
    reshard_after_forward_policy: "default"
    blocks_per_shard_group: 12  # NOTE(review): original hinted "-1" as an alternative — presumably "all blocks in one group"; confirm
  time_sampler:
    module: "utils_fm.noiser.TimeSampler"
    params:
      use_logit_normal: true
      mu: 0.0  # Mean of the logit normal distribution
      sigma: 1.0  # Standard deviation of the logit normal distribution
  time_warper:
    module: "utils_fm.noiser.TimeWarper"
    params:
      base_len: 256  # Base sequence length
      base_shift: 0.5  # Base shift parameter for time warping
      max_len: 4096  # Maximum sequence length
      max_shift: 1.15  # Maximum shift parameter for time warping
  time_weighter:
    module: "utils_fm.noiser.TimeWeighter"
    params:
      use_logit_normal: false  # mu/sigma below apply only when this is true
      mu: 0.0  # Mean of the logit normal distribution
      sigma: 1.0  # Standard deviation of the logit normal distribution
  flow_noiser:
    module: "utils_fm.noiser.FlowNoiser"
    params:
      compute_dtype: "fp32"  # Internal computation dtype: "fp32", "fp16", "bf16"
  balancer:
    use_dit_balancer: false  # Use DIT balancer for sequence length balancing
    dit_balancer_specs: "g1n8"  # Bag specifications for DIT balancer
    dit_balancer_gamma: 0.5  # Gamma parameter for DIT workload estimator
# NOTE(review): leading indentation was lost in transit; nesting reconstructed —
# in particular, confirm whether "inferencer" is a trainer param or top-level.
trainer:
  module: "trainers.dit_trainer.DiTTrainer"
  params:
    # Text dropout probability
    txt_drop_prob: 0.1
    # EMA Settings
    ema_decay: 0.999
    # Training Schedule
    # (Digit-separator underscores removed: "1000_000" is an int only under
    # YAML 1.1 parsers such as PyYAML; YAML 1.2 parsers read it as a string.)
    max_steps: 1000000
    warmup_steps: 1000
    # AdamW Optimizer Settings
    max_lr: 0.0003  # Maximum learning rate for AdamW optimizer
    min_lr: 0.00001  # Minimum learning rate for cosine decay schedule
    weight_decay: 0.0  # L2 regularization weight decay coefficient
    adam_betas: [0.9, 0.95]  # Beta coefficients for AdamW momentum terms [beta1, beta2]
    # Gradient accumulation settings
    total_batch_size: 1024
    # Gradient Safeguarding Settings
    gradient_clip_norm: 1.0
    grad_norm_spike_threshold: 2.0
    grad_norm_spike_detection_start_step: 1000
    # Checkpoint Settings
    init_ckpt: null  # Optional: "path/to/checkpoint"
    init_ckpt_load_plan: "ckpt_model:mem_model,ckpt_ema:mem_ema,ckpt_optimizer:mem_optimizer,ckpt_scheduler:mem_scheduler,ckpt_step:mem_step"
    ckpt_freq: 2000
    exp_dir: "./experiments/flux_tiny_imagenet"
    # Logging Settings
    wandb_mode: "disabled"  # online, offline, or disabled (disabled = no wandb logging)
    wandb_project: "minFM"
    wandb_name: "flux_tiny_imagenet"  # Optional: experiment name, defaults to wandb auto-naming
    # wandb_entity: <your-wandb-entity>  # Optional: wandb entity/organization
    # wandb_host: <your-wandb-host>  # Optional: hostname for a custom-hosted wandb setup
    log_freq: 20
    # Validation Settings
    val_freq: 10000
    val_num_samples: 10000
    # Inference Settings
    inference_at_start: false
    inference_then_exit: false
    inference_freq: 2000
    inferencer:
      ckpt_dir: "./experiments/flux_tiny_imagenet/checkpoints/step_00098000"
      inference_ops_args:
        use_ema: false
        prompt_file: "./resources/inference_imagenet_prompts.txt"
        output_dir: "./experiments/inference_results_flux_tiny_imagenet"
        img_fhw: [1, 256, 256]  # [frames, height, width] of generated samples
        samples_per_prompt: 4
        num_steps: 50
        neg_prompt: ""
        cfg_scale: 5.0
        eta: 1.0
        file_ext: "jpg"
        per_gpu_bs: 16
        guidance: null
        sample_method: "ddim"
        save_as_npz: false
    ### Use the following inference setup for computing FID scores
    ### You can try different cfg_scale
    ### Usually lower cfg_scale leads to better FID scores, but visual quality may be worse
    # inferencer:
    #   ckpt_dir: "./experiments/flux_tiny_imagenet/step_00380000"
    #   inference_ops_args:
    #     use_ema: true
    #     prompt_file: "./resources/inference_imagenet_1kcls.txt"
    #     output_dir: "./experiments/inference_results_flux_tiny_imagenet-cfg5"
    #     img_fhw: [1, 256, 256]
    #     samples_per_prompt: 50
    #     num_steps: 50
    #     neg_prompt: ""
    #     cfg_scale: 5.0
    #     eta: 1.0
    #     file_ext: "jpg"
    #     per_gpu_bs: 16
    #     guidance: null
    #     sample_method: "ddim"
    #     save_as_npz: true
# Data module specification (same module/params pattern as the model components).
data:
  module: "data.imagenet.ImagenetDataModule"
  params:
    batch_size: 128  # per-rank batch size (presumably; trainer.total_batch_size sets the global — confirm)
    resolution: 256
    num_workers: 16
    p_horizon_flip: 0.5  # probability of horizontal-flip augmentation
    data_root_dir: "$MINFM_DATA_DIR/imagenet"
    image_metas_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::image_metas"
    label_to_txt_path: "$MINFM_DATA_DIR/imagenet/ilsvrc2012_meta.pt::label_to_txt"
    use_precomputed_latents: true