-
Notifications
You must be signed in to change notification settings - Fork 58
Expand file tree
/
Copy pathconfig.py
More file actions
403 lines (315 loc) · 12.7 KB
/
config.py
File metadata and controls
403 lines (315 loc) · 12.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
from pathlib import Path
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator, model_validator
from ltxv_trainer.model_loader import LtxvModelVersion
from ltxv_trainer.quantization import QuantizationOptions
class ConfigBaseModel(BaseModel):
    # Shared base for every configuration section declared in this file.
    # extra="forbid" makes pydantic reject keys that are not declared as fields,
    # so a misspelled option surfaces as a validation error instead of being ignored.
    model_config = ConfigDict(extra="forbid")
class ModelConfig(ConfigBaseModel):
    """Configuration for the base model and training mode"""

    # Where to obtain the base model. Strings are first offered to LtxvModelVersion
    # by the validator below; anything it does not accept is kept unchanged.
    model_source: str | Path | LtxvModelVersion = Field(
        description="Model source - can be a HuggingFace repo ID, local path, or LtxvModelVersion",
        default=LtxvModelVersion.latest(),
    )

    # "lora" or "full"; the top-level config cross-checks this against the LoRA section.
    training_mode: Literal["lora", "full"] = Field(
        description="Training mode - either LoRA fine-tuning or full model fine-tuning",
        default="lora",
    )

    # Optional checkpoint to load; None means no checkpoint is loaded.
    load_checkpoint: str | Path | None = Field(
        description=(
            "Path to a checkpoint file or directory to load from. "
            "If a directory is provided, the latest checkpoint will be used."
        ),
        default=None,
    )

    # noinspection PyNestedDecorators
    @field_validator("model_source", mode="before")
    @classmethod
    def validate_model_source(cls, v):  # noqa: ANN001, ANN206
        """Coerce the raw model source into an LtxvModelVersion when it names one.

        Runs before pydantic's own validation. Values that are neither ``str`` nor
        ``LtxvModelVersion`` are returned untouched, as are strings for which
        ``LtxvModelVersion(...)`` raises ``ValueError``.
        """
        if not isinstance(v, (str, LtxvModelVersion)):
            return v
        try:
            version = LtxvModelVersion(v)
        except ValueError:
            # Not a known version identifier - leave it for the str/Path branch of the union.
            return v
        return version
class LoraConfig(ConfigBaseModel):
    """Configuration for LoRA fine-tuning"""

    # Rank of the LoRA adaptation; validation rejects values below 2.
    rank: int = Field(
        default=64,
        description="Rank of LoRA adaptation",
        ge=2,
    )

    # LoRA alpha scaling factor; must be at least 1.
    alpha: int = Field(
        default=64,
        description="Alpha scaling factor for LoRA",
        ge=1,
    )

    # Dropout probability applied to LoRA layers, constrained to [0.0, 1.0].
    dropout: float = Field(
        default=0.0,
        description="Dropout probability for LoRA layers",
        ge=0.0,
        le=1.0,
    )

    # Names of the modules to adapt. The default is produced by a factory so that it
    # is a real list, matching the list[str] annotation: pydantic does not validate
    # defaults, so a tuple literal here would reach callers as a tuple. The factory
    # also gives each instance its own list object.
    target_modules: list[str] = Field(
        default_factory=lambda: ["to_k", "to_q", "to_v", "to_out.0"],
        description="List of modules to target with LoRA",
    )
class ConditioningConfig(ConfigBaseModel):
    """Configuration for conditioning during training"""

    # Conditioning type: "none" or "reference_video".
    mode: Literal["none", "reference_video"] = Field(
        description="Type of conditioning to use during training",
        default="none",
    )

    # Probability, constrained to [0.0, 1.0], of conditioning on the first frame.
    first_frame_conditioning_p: float = Field(
        description="Probability of conditioning on the first frame during training",
        default=0.1,
        le=1.0,
        ge=0.0,
    )

    # Directory name holding reference-video latents (relevant to reference_video mode).
    reference_latents_dir: str = Field(
        description="Directory name for latents of reference videos when using reference_video mode",
        default="ref_latents",
    )
class OptimizationConfig(ConfigBaseModel):
    """Configuration for optimization parameters"""

    # Bounds on the numeric fields below make nonsensical values fail at
    # configuration time rather than later during training.

    # Learning rate; negative values are rejected.
    learning_rate: float = Field(
        default=5e-4,
        description="Learning rate for optimization",
        ge=0.0,
    )

    # Total number of training steps; must be positive.
    steps: int = Field(
        default=3000,
        description="Number of training steps",
        gt=0,
    )

    # Training batch size; must be positive.
    batch_size: int = Field(
        default=2,
        description="Batch size for training",
        gt=0,
    )

    # Number of steps over which gradients are accumulated; must be positive.
    gradient_accumulation_steps: int = Field(
        default=1,
        description="Number of steps to accumulate gradients",
        gt=0,
    )

    # Gradient-norm clipping threshold; negative values are rejected.
    max_grad_norm: float = Field(
        default=1.0,
        description="Maximum gradient norm for clipping",
        ge=0.0,
    )

    # Optimizer implementation to use.
    optimizer_type: Literal["adamw", "adamw8bit"] = Field(
        default="adamw",
        description="Type of optimizer to use for training",
    )

    # Learning-rate schedule to use.
    scheduler_type: Literal[
        "constant",
        "linear",
        "cosine",
        "cosine_with_restarts",
        "polynomial",
    ] = Field(
        default="linear",
        description="Type of scheduler to use for training",
    )

    # Free-form scheduler parameters; the keys are not validated here.
    scheduler_params: dict = Field(
        default_factory=dict,
        description="Parameters for the scheduler",
    )

    # Trades slower training for lower memory use when enabled.
    enable_gradient_checkpointing: bool = Field(
        default=False,
        description="Enable gradient checkpointing to save memory at the cost of slower training",
    )
class AccelerationConfig(ConfigBaseModel):
    """Configuration for hardware acceleration and compute optimization"""

    # Mixed-precision mode; both the string "no" and None are accepted values.
    mixed_precision_mode: Literal["no", "fp16", "bf16"] | None = Field(
        description="Mixed precision training mode",
        default="bf16",
    )

    # Optional quantization setting; None leaves quantization unset.
    quantization: QuantizationOptions | None = Field(
        description="Quantization precision to use",
        default=None,
    )

    # Load the text encoder in 8-bit precision to reduce memory use.
    load_text_encoder_in_8bit: bool = Field(
        description="Whether to load the text encoder in 8-bit precision to save memory",
        default=False,
    )

    # Whether to compile the model with Torch Inductor (on by default).
    compile_with_inductor: bool = Field(
        description="Compile the model with Torch Inductor",
        default=True,
    )

    # Compilation mode used for Torch Inductor.
    compilation_mode: Literal["default", "reduce-overhead", "max-autotune"] = Field(
        description="Compilation mode for Torch Inductor",
        default="reduce-overhead",
    )
class DataConfig(ConfigBaseModel):
    """Configuration for data loading and processing"""

    # Required: this field has no default, so DataConfig() without it fails validation.
    preprocessed_data_root: str = Field(
        description="Path to folder containing preprocessed training data",
    )

    # Number of background loader processes; 0 selects synchronous loading, negatives are rejected.
    num_dataloader_workers: int = Field(
        description="Number of background processes for data loading (0 means synchronous loading)",
        default=2,
        ge=0,
    )
class ValidationConfig(ConfigBaseModel):
    """Configuration for validation during training"""

    # Prompts used for validation sampling. Declared before `images` and
    # `reference_videos` so that their validators can read it from info.data.
    prompts: list[str] = Field(
        description="List of prompts to use for validation",
        default_factory=list,
    )

    negative_prompt: str = Field(
        description="Negative prompt to use for validation examples",
        default="worst quality, inconsistent motion, blurry, jittery, distorted",
    )

    # Optional per-prompt images; when given, the count must equal len(prompts).
    images: list[str] | None = Field(
        description=(
            "List of image paths to use for validation. "
            "One image path must be provided for each validation prompt"
        ),
        default=None,
    )

    # Optional per-prompt reference videos; when given, the count must equal len(prompts).
    reference_videos: list[str] | None = Field(
        description=(
            "List of reference video paths to use for validation. "
            "One video path must be provided for each validation prompt"
        ),
        default=None,
    )

    # (width, height, frames) of the generated validation videos.
    video_dims: tuple[int, int, int] = Field(
        description="Dimensions of validation videos (width, height, frames)",
        default=(704, 480, 161),
    )

    seed: int = Field(
        description="Random seed used when sampling validation videos",
        default=42,
    )

    inference_steps: int = Field(
        description="Number of inference steps for validation",
        default=50,
        gt=0,
    )

    # None disables validation; otherwise a positive number of steps.
    interval: int | None = Field(
        description="Number of steps between validation runs. If None, validation is disabled.",
        default=100,
        gt=0,
    )

    videos_per_prompt: int = Field(
        description="Number of videos to generate per validation prompt",
        default=1,
        gt=0,
    )

    guidance_scale: float = Field(
        description="Guidance scale to use during validation",
        default=3.5,
        ge=1.0,
    )

    skip_initial_validation: bool = Field(
        description="Skip validation video sampling at step 0 (beginning of training)",
        default=False,
    )

    @field_validator("images")
    @classmethod
    def validate_num_images(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
        """Ensure that, when images are supplied, there is exactly one per prompt.

        Raises:
            ValueError: if the number of images differs from the number of prompts.
        """
        if v is None:
            return v
        # `prompts` is missing from info.data if it failed its own validation; treat that as zero prompts.
        expected = len(info.data.get("prompts", []))
        if len(v) == expected:
            return v
        raise ValueError(f"Number of images ({len(v)}) must match number of prompts ({expected})")

    @field_validator("reference_videos")
    @classmethod
    def validate_num_reference_videos(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
        """Ensure that, when reference videos are supplied, there is exactly one per prompt.

        Raises:
            ValueError: if the number of reference videos differs from the number of prompts.
        """
        if v is None:
            return v
        # `prompts` is missing from info.data if it failed its own validation; treat that as zero prompts.
        expected = len(info.data.get("prompts", []))
        if len(v) == expected:
            return v
        raise ValueError(f"Number of reference videos ({len(v)}) must match number of prompts ({expected})")
class CheckpointsConfig(ConfigBaseModel):
    """Configuration for model checkpointing during training"""

    # None disables intermediate checkpoints; otherwise a positive step count.
    interval: int | None = Field(
        description="Number of steps between checkpoint saves. If None, intermediate checkpoints are disabled.",
        default=None,
        gt=0,
    )

    # How many recent checkpoints to retain; -1 keeps all, values below -1 are rejected.
    keep_last_n: int = Field(
        description="Number of most recent checkpoints to keep. Set to -1 to keep all checkpoints.",
        default=1,
        ge=-1,
    )
class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

    # Whether to push the model weights to the Hub.
    push_to_hub: bool = Field(
        description="Whether to push the model weights to the Hugging Face Hub",
        default=False,
    )

    # Target repository; must be set (and non-empty) when push_to_hub is enabled.
    hub_model_id: str | None = Field(
        description="Hugging Face Hub repository ID (e.g., 'username/repo-name')",
        default=None,
    )

    @model_validator(mode="after")
    def validate_hub_config(self) -> "HubConfig":
        """Require a repository ID whenever pushing to the Hub is enabled.

        Raises:
            ValueError: if ``push_to_hub`` is true and ``hub_model_id`` is None or empty.
        """
        if not self.push_to_hub:
            return self
        if self.hub_model_id:
            return self
        raise ValueError("hub_model_id must be specified when push_to_hub is True")
class WandbConfig(ConfigBaseModel):
    """Configuration for Weights & Biases logging"""

    # W&B logging is off unless explicitly enabled.
    enabled: bool = Field(
        description="Whether to enable W&B logging",
        default=False,
    )

    # W&B project name.
    project: str = Field(
        description="W&B project name",
        default="ltxv-trainer",
    )

    # W&B username or team; optional.
    entity: str | None = Field(
        description="W&B username or team",
        default=None,
    )

    # Tags attached to the run; empty by default.
    tags: list[str] = Field(
        description="Tags to add to the W&B run",
        default_factory=list,
    )

    # Whether validation videos are logged to W&B.
    log_validation_videos: bool = Field(
        description="Whether to log validation videos to W&B",
        default=True,
    )
class FlowMatchingConfig(ConfigBaseModel):
    """Configuration for flow matching training"""

    # Timestep sampling mode: "uniform" or "shifted_logit_normal".
    timestep_sampling_mode: Literal["uniform", "shifted_logit_normal"] = Field(
        description="Mode to use for timestep sampling",
        default="shifted_logit_normal",
    )

    # Free-form parameters for the timestep sampler; the keys are not validated here.
    timestep_sampling_params: dict = Field(
        description="Parameters for timestep sampling",
        default_factory=dict,
    )
class LtxvTrainerConfig(ConfigBaseModel):
    """Unified configuration for LTXV training"""

    # Sub-configurations
    model: ModelConfig = Field(default_factory=ModelConfig)
    # Optional at the field level, but required by the validator below when training_mode is "lora".
    lora: LoraConfig | None = Field(default=None)
    conditioning: ConditioningConfig = Field(default_factory=ConditioningConfig)
    optimization: OptimizationConfig = Field(default_factory=OptimizationConfig)
    acceleration: AccelerationConfig = Field(default_factory=AccelerationConfig)
    # Required: DataConfig has a mandatory `preprocessed_data_root`, so it cannot be
    # built with no arguments and a default factory could never succeed. Declaring the
    # field as required reports a missing section against `data` itself.
    data: DataConfig
    validation: ValidationConfig = Field(default_factory=ValidationConfig)
    checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
    hub: HubConfig = Field(default_factory=HubConfig)
    flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
    wandb: WandbConfig = Field(default_factory=WandbConfig)

    # General configuration
    seed: int = Field(
        default=42,
        description="Random seed for reproducibility",
    )
    output_dir: str = Field(
        default="outputs",
        description="Directory to save model outputs",
    )

    # noinspection PyNestedDecorators
    @field_validator("output_dir")
    @classmethod
    def expand_output_path(cls, v: str) -> str:
        """Expand ``~`` in the output path and resolve it to an absolute path string."""
        return str(Path(v).expanduser().resolve())

    @model_validator(mode="after")
    def validate_conditioning_compatibility(self) -> "LtxvTrainerConfig":
        """Validate that the model, LoRA, conditioning and validation sections are compatible.

        The checks run in order and the first failure is raised.

        Raises:
            ValueError: if reference_video conditioning is used without validation
                reference videos, if LoRA training is selected without a LoRA section,
                or if reference_video conditioning is used with a non-LoRA training mode.
        """
        uses_reference_video = self.conditioning.mode == "reference_video"
        lora_training = self.model.training_mode == "lora"

        # Reference-video conditioning needs reference videos for validation sampling.
        if uses_reference_video and self.validation.reference_videos is None:
            raise ValueError(
                "reference_videos must be provided in validation config when conditioning.mode is 'reference_video'"
            )

        # LoRA training needs a LoRA section.
        if lora_training and self.lora is None:
            raise ValueError("LoRA configuration must be provided when training_mode is 'lora'")

        # Reference-video conditioning is only allowed together with LoRA training.
        if uses_reference_video and not lora_training:
            raise ValueError("Training mode must be 'lora' when using reference_video conditioning")

        return self