Skip to content

Commit cfccf8a

Browse files
Adding attribute handling for DDP and non-DDP UNet models (#60)
1 parent 0e054db commit cfccf8a

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

train_svd.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -115,15 +115,15 @@ def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
115115

116116

117117
class DummyDataset(Dataset):
118-
def __init__(self, num_samples=100000, width=1024, height=576, sample_frames=25):
118+
def __init__(self, base_folder: str, num_samples=100000, width=1024, height=576, sample_frames=25):
119119
"""
120120
Args:
121121
num_samples (int): Number of samples in the dataset.
122122
channels (int): Number of channels, default is 3 for RGB.
123123
"""
124124
self.num_samples = num_samples
125125
# Define the path to the folder containing video frames
126-
self.base_folder = 'bdd100k/images/track/mini'
126+
self.base_folder = base_folder
127127
self.folders = os.listdir(self.base_folder)
128128
self.channels = 3
129129
self.width = width
@@ -342,6 +342,11 @@ def parse_args():
342342
parser = argparse.ArgumentParser(
343343
description="Script to train Stable Diffusion XL for InstructPix2Pix."
344344
)
345+
parser.add_argument(
346+
"--base_folder",
347+
required=True,
348+
type=str,
349+
)
345350
parser.add_argument(
346351
"--pretrained_model_name_or_path",
347352
type=str,
@@ -711,6 +716,10 @@ def main():
711716
variant="fp16",
712717
)
713718

719+
# attribute handling for models using DDP
720+
if isinstance(unet, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
721+
unet = unet.module
722+
714723
# Freeze vae and image_encoder
715724
vae.requires_grad_(False)
716725
image_encoder.requires_grad_(False)
@@ -853,7 +862,7 @@ def load_model_hook(models, input_dir):
853862
# DataLoaders creation:
854863
args.global_batch_size = args.per_gpu_batch_size * accelerator.num_processes
855864

856-
train_dataset = DummyDataset(width=args.width, height=args.height, sample_frames=args.num_frames)
865+
train_dataset = DummyDataset(args.base_folder, width=args.width, height=args.height, sample_frames=args.num_frames)
857866
sampler = RandomSampler(train_dataset)
858867
train_dataloader = torch.utils.data.DataLoader(
859868
train_dataset,
@@ -946,9 +955,9 @@ def _get_add_time_ids(
946955
):
947956
add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
948957

949-
passed_add_embed_dim = unet.module.config.addition_time_embed_dim * \
958+
passed_add_embed_dim = unet.config.addition_time_embed_dim * \
950959
len(add_time_ids)
951-
expected_add_embed_dim = unet.module.add_embedding.linear_1.in_features
960+
expected_add_embed_dim = unet.add_embedding.linear_1.in_features
952961

953962
if expected_add_embed_dim != passed_add_embed_dim:
954963
raise ValueError(

0 commit comments

Comments
 (0)