Skip to content

Commit 326a222

Browse files
committed
change the --rewrite flag's default value to true and reduce redundant tensor allocations in inference pipelines
1 parent 90a4e81 commit 326a222

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def main():
314314
'--save_pre_sr_video false/0 to disable'
315315
)
316316
parser.add_argument(
317-
'--rewrite', type=str_to_bool, nargs='?', const=True, default=False,
317+
'--rewrite', type=str_to_bool, nargs='?', const=True, default=True,
318318
help='Enable prompt rewriting (default: true). '
319319
'Use --rewrite or --rewrite true/1 to enable, --rewrite false/0 to disable'
320320
)

hyvideo/pipelines/hunyuan_video_pipeline.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def _prepare_vision_states(self, reference_image, target_resolution, latents, de
709709
torch.Tensor or None: Vision states tensor or None if vision encoder is unavailable.
710710
"""
711711
if reference_image is None:
712-
vision_states = torch.zeros(latents.shape[0], self.config.vision_num_semantic_tokens, self.config.vision_states_dim).to(latents.device)
712+
vision_states = torch.zeros(latents.shape[0], self.config.vision_num_semantic_tokens, self.config.vision_states_dim, device=latents.device, dtype=latents.dtype)
713713
else:
714714
reference_image = np.array(reference_image) if isinstance(reference_image, Image.Image) else reference_image
715715
if len(reference_image.shape) == 4:
@@ -753,11 +753,11 @@ def _prepare_cond_latents(self, task_type, cond_latents, latents, multitask_mask
753753
latents_concat = cond_latents.repeat(1, 1, latents.shape[2], 1, 1)
754754
latents_concat[:, :, 1:, :, :] = 0.0
755755
else:
756-
latents_concat = torch.zeros(latents.shape[0], latents.shape[1], latents.shape[2], latents.shape[3], latents.shape[4]).to(latents.device)
756+
latents_concat = torch.zeros_like(latents)
757757

758-
mask_zeros = torch.zeros(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4])
759-
mask_ones = torch.ones(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4])
760-
mask_concat = merge_tensor_by_mask(mask_zeros.cpu(), mask_ones.cpu(), mask=multitask_mask.cpu(), dim=2).to(device=latents.device)
758+
mask_zeros = torch.zeros(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4], device=latents.device)
759+
mask_ones = torch.ones(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4], device=latents.device)
760+
mask_concat = merge_tensor_by_mask(mask_zeros, mask_ones, mask=multitask_mask.to(device=latents.device), dim=2)
761761

762762
cond_latents = torch.concat([latents_concat, mask_concat], dim=1)
763763

@@ -1203,23 +1203,23 @@ def __call__(
12031203

12041204
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
12051205

1206-
t_expand = t.repeat(latent_model_input.shape[0])
1206+
t_expand = t.expand(latent_model_input.shape[0])
12071207
if self.use_meanflow:
12081208
if i == len(timesteps) - 1:
12091209
timesteps_r = torch.tensor([0.0], device=self.execution_device)
12101210
else:
12111211
timesteps_r = timesteps[i + 1]
1212-
timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
1212+
timesteps_r = timesteps_r.expand(latent_model_input.shape[0])
12131213
else:
12141214
timesteps_r = None
12151215

12161216
guidance_expand = (
1217-
torch.tensor(
1218-
[embedded_guidance_scale] * latent_model_input.shape[0],
1219-
dtype=torch.float32,
1217+
torch.full(
1218+
(latent_model_input.shape[0],),
1219+
embedded_guidance_scale * 1000.0,
1220+
dtype=self.target_dtype,
12201221
device=device,
1221-
).to(self.target_dtype)
1222-
* 1000.0
1222+
)
12231223
if embedded_guidance_scale is not None
12241224
else None
12251225
)

hyvideo/pipelines/hunyuan_video_sr_pipeline.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def _prepare_lq_cond_latents(self, lq_latents):
156156
torch.Tensor: Low-resolution conditional latent tensor.
157157
"""
158158
b, _, f, h, w = lq_latents.shape
159-
mask_ones = torch.ones(b, 1, f, h, w).to(lq_latents.device)
159+
mask_ones = torch.ones(b, 1, f, h, w, device=lq_latents.device, dtype=lq_latents.dtype)
160160
cond_latents = torch.concat([lq_latents, mask_ones], dim=1)
161161

162162
return cond_latents
@@ -379,23 +379,23 @@ def __call__(
379379

380380
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
381381

382-
t_expand = t.repeat(latent_model_input.shape[0])
382+
t_expand = t.expand(latent_model_input.shape[0])
383383
if not self.use_meanflow:
384384
timesteps_r = None
385385
else:
386386
if i == len(timesteps) - 1:
387387
timesteps_r = torch.tensor([0.0], device=self.execution_device)
388388
else:
389389
timesteps_r = timesteps[i + 1]
390-
timesteps_r = timesteps_r.repeat(latent_model_input.shape[0])
390+
timesteps_r = timesteps_r.expand(latent_model_input.shape[0])
391391

392392
guidance_expand = (
393-
torch.tensor(
394-
[embedded_guidance_scale] * latent_model_input.shape[0],
395-
dtype=torch.float32,
393+
torch.full(
394+
(latent_model_input.shape[0],),
395+
embedded_guidance_scale * 1000.0,
396+
dtype=self.target_dtype,
396397
device=device,
397-
).to(self.target_dtype)
398-
* 1000.0
398+
)
399399
if embedded_guidance_scale is not None
400400
else None
401401
)

0 commit comments

Comments (0)