Use .expand() instead of .repeat() to avoid unnecessary memory allocation in scatter_reduce_ index construction #65

@Mulanss

🚀 Suggestion: Use .expand() Instead of .repeat() in scatter_reduce_ Index Construction

Hi team 👋,

In hi_diffusers/models/moe.py at line 153, the index tensor for scatter_reduce_ is constructed with the following expression:

exp_token_idx.view(-1, 1).repeat(1, x.shape[-1])

This can be safely replaced with .expand(-1, x.shape[-1]), which yields an index tensor with the same shape and values but returns a broadcasted view instead of allocating a full copy, as repeat() does. Since this index tensor is only read after creation and never modified, .expand() achieves the same broadcasting effect with less memory and without the copy.


✅ Suggested Change

exp_token_idx.view(-1, 1).expand(-1, x.shape[-1])
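To make the equivalence concrete, here is a minimal sketch comparing both index constructions in a scatter_reduce_ call. The tensor shapes, the variable names `expert_out` and `num_tokens`, and the choice of `reduce="sum"` are assumptions for illustration and are not taken from moe.py.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for the tensors in moe.py; shapes are illustrative only.
num_tokens, hidden = 4, 8
x = torch.randn(num_tokens, hidden)
exp_token_idx = torch.tensor([0, 2, 2, 3, 1, 0])         # destination row per routed entry
expert_out = torch.randn(exp_token_idx.numel(), hidden)  # one output row per routed entry

# Current construction: materializes a full (6, hidden) copy of the indices.
idx_repeat = exp_token_idx.view(-1, 1).repeat(1, x.shape[-1])
# Suggested construction: a broadcasted view over the original 6 indices.
idx_expand = exp_token_idx.view(-1, 1).expand(-1, x.shape[-1])

# reduce="sum" is an assumption here; the comparison works the same for other modes.
out_repeat = torch.zeros_like(x).scatter_reduce_(0, idx_repeat, expert_out, reduce="sum")
out_expand = torch.zeros_like(x).scatter_reduce_(0, idx_expand, expert_out, reduce="sum")

same_index = torch.equal(idx_repeat, idx_expand)
same_output = torch.allclose(out_repeat, out_expand)
print(same_index, same_output)
```

The outputs are compared with allclose rather than exact equality because floating-point accumulation order may differ between the two code paths.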

✅ Benefits

  • 🧠 Reduces memory usage (especially when x.shape[-1] is large)
  • 🚀 Slightly improves performance (no data duplication)
  • 🧩 Same behavior and output as before
  • 📦 Cleaner and more efficient indexing logic

Note: This is safe because the tensor is only used as a read-only index, and expand() creates a broadcasted view without copying data.
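As a rough check of the "view without copying" claim, the sketch below inspects strides and data pointers of both results. The sizes and the name `hidden` (standing in for x.shape[-1]) are hypothetical.

```python
import torch

hidden = 8                       # hypothetical stand-in for x.shape[-1]
exp_token_idx = torch.arange(6)  # illustrative index tensor

col = exp_token_idx.view(-1, 1)
idx_expand = col.expand(-1, hidden)
idx_repeat = col.repeat(1, hidden)

# expand(): reuses the original buffer, with stride 0 along the broadcast dimension.
expand_shares = idx_expand.data_ptr() == exp_token_idx.data_ptr()
print(idx_expand.stride(), expand_shares, idx_expand.is_contiguous())

# repeat(): allocates a new contiguous buffer holding 6 * hidden elements.
repeat_shares = idx_repeat.data_ptr() == exp_token_idx.data_ptr()
print(idx_repeat.stride(), repeat_shares, idx_repeat.is_contiguous())
```

Because the expanded tensor is non-contiguous, any later call that needs a contiguous layout (for example .contiguous()) would materialize the copy after all, so the saving holds only as long as the index is consumed as a view.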


Thanks for your excellent work on HiDream‑I1! 🙌
