forked from HiDream-ai/HiDream-I1
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmonkey_patch_cat.py
More file actions
49 lines (40 loc) · 1.55 KB
/
Copy pathmonkey_patch_cat.py
File metadata and controls
49 lines (40 loc) · 1.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import collections.abc
import functools

import torch
# Save the original torch.cat so patched_cat can delegate to it and
# remove_patch can put it back.
# If this module is re-executed (e.g. via importlib.reload) while the patch is
# active, torch.cat is patched_cat at that moment; capturing it again would
# make patched_cat call itself forever and make remove_patch a no-op.
# importlib.reload keeps the module's globals, so reuse a previously captured
# original when one exists and only read torch.cat on the first execution.
original_cat = globals().get("original_cat", torch.cat)
# Helper function to synchronize devices
def sync_tensors_to_same_device(tensor_list):
"""Ensure all tensors are on the same device"""
if not tensor_list:
return tensor_list
# Find the target device (we choose the device of the first non-None tensor)
target_device = None
for t in tensor_list:
if t is not None and hasattr(t, 'device'):
target_device = t.device
break
if target_device is None:
return tensor_list
# Move all tensors to the same device
result = []
for t in tensor_list:
if t is not None and hasattr(t, 'device') and t.device != target_device:
# print(f"Moving tensor from {t.device} to {target_device}")
result.append(t.to(target_device))
else:
result.append(t)
return result
# Drop-in replacement for torch.cat: first co-locate every input tensor, then
# defer to the saved original with the remaining arguments (dim, out, ...)
# forwarded untouched. functools.wraps copies torch.cat's __name__, __doc__,
# etc. onto this wrapper (and sets __wrapped__), so the docstring below is
# only visible in source.
@functools.wraps(original_cat)
def patched_cat(tensors, *args, **kwargs):
    """Concatenate *tensors* after moving them onto a common device."""
    return original_cat(sync_tensors_to_same_device(tensors), *args, **kwargs)
# Install the device-synchronizing wrapper in place of torch.cat.
def apply_patch():
    """Replace ``torch.cat`` with ``patched_cat`` and announce it on stdout."""
    announcement = "Applied torch.cat patch, will automatically synchronize tensors on different devices"
    setattr(torch, "cat", patched_cat)
    print(announcement)
# Undo apply_patch by reinstating the saved torch.cat.
def remove_patch():
    """Put ``original_cat`` back as ``torch.cat`` and announce it on stdout."""
    announcement = "Removed torch.cat patch, restored original behavior"
    setattr(torch, "cat", original_cat)
    print(announcement)