@@ -115,15 +115,15 @@ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
115115
116116
117117class DummyDataset (Dataset ):
118- def __init__ (self , num_samples = 100000 , width = 1024 , height = 576 , sample_frames = 25 ):
118+ def __init__ (self , base_folder : str , num_samples = 100000 , width = 1024 , height = 576 , sample_frames = 25 ):
119119 """
120120 Args:
121121 num_samples (int): Number of samples in the dataset.
122122 channels (int): Number of channels, default is 3 for RGB.
123123 """
124124 self .num_samples = num_samples
125125 # Define the path to the folder containing video frames
126- self .base_folder = 'bdd100k/images/track/mini'
126+ self .base_folder = base_folder
127127 self .folders = os .listdir (self .base_folder )
128128 self .channels = 3
129129 self .width = width
@@ -342,6 +342,11 @@ def parse_args():
342342 parser = argparse .ArgumentParser (
343343 description = "Script to train Stable Diffusion XL for InstructPix2Pix."
344344 )
345+ parser .add_argument (
346+ "--base_folder" ,
347+ required = True ,
348+ type = str ,
349+ )
345350 parser .add_argument (
346351 "--pretrained_model_name_or_path" ,
347352 type = str ,
@@ -711,6 +716,10 @@ def main():
711716 variant = "fp16" ,
712717 )
713718
719+ # Unwrap DataParallel / DistributedDataParallel so attributes are read from the underlying model
720+ if isinstance (unet , (torch .nn .DataParallel , torch .nn .parallel .DistributedDataParallel )):
721+ unet = unet .module
722+
714723 # Freeze vae and image_encoder
715724 vae .requires_grad_ (False )
716725 image_encoder .requires_grad_ (False )
@@ -853,7 +862,7 @@ def load_model_hook(models, input_dir):
853862 # DataLoaders creation:
854863 args .global_batch_size = args .per_gpu_batch_size * accelerator .num_processes
855864
856- train_dataset = DummyDataset (width = args .width , height = args .height , sample_frames = args .num_frames )
865+ train_dataset = DummyDataset (args . base_folder , width = args .width , height = args .height , sample_frames = args .num_frames )
857866 sampler = RandomSampler (train_dataset )
858867 train_dataloader = torch .utils .data .DataLoader (
859868 train_dataset ,
@@ -946,9 +955,9 @@ def _get_add_time_ids(
946955 ):
947956 add_time_ids = [fps , motion_bucket_id , noise_aug_strength ]
948957
949- passed_add_embed_dim = unet .module . config .addition_time_embed_dim * \
958+ passed_add_embed_dim = unet .config .addition_time_embed_dim * \
950959 len (add_time_ids )
951- expected_add_embed_dim = unet .module . add_embedding .linear_1 .in_features
960+ expected_add_embed_dim = unet .add_embedding .linear_1 .in_features
952961
953962 if expected_add_embed_dim != passed_add_embed_dim :
954963 raise ValueError (
0 commit comments