|
11 | 11 | from ._fx import register_notrace_function |
12 | 12 | from .grid import ndgrid |
13 | 13 | from .trace_utils import _assert |
14 | | -from .weight_init import is_meta_device |
15 | | - |
16 | 14 |
|
17 | 15 | def pixel_freq_bands( |
18 | 16 | num_bands: int, |
@@ -188,8 +186,8 @@ def __init__( |
188 | 186 | self.keep_spatial = keep_spatial |
189 | 187 | self.register_buffer('bands', torch.empty(num_bands, device=device, dtype=dtype), persistent=False) |
190 | 188 |
|
191 | | - if not is_meta_device(device): |
192 | | - self.reset_parameters() |
| 189 | + # TODO: skip init when on meta device when safe to do so |
| 190 | + self.reset_parameters() |
193 | 191 |
|
194 | 192 | def reset_parameters(self) -> None: |
195 | 193 | """Initialize parameters and buffers.""" |
@@ -447,8 +445,8 @@ def __init__( |
447 | 445 | self.register_buffer('pos_embed_sin', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False) |
448 | 446 | self.register_buffer('pos_embed_cos', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False) |
449 | 447 |
|
450 | | - if not is_meta_device(device): |
451 | | - self.reset_parameters() |
| 448 | + # TODO: skip init when on meta device when safe to do so |
| 449 | + self.reset_parameters() |
452 | 450 |
|
453 | 451 | def reset_parameters(self) -> None: |
454 | 452 | """Initialize parameters and buffers.""" |
@@ -583,8 +581,8 @@ def __init__( |
583 | 581 | emb_shape = (num_pos, dim * 2) # concatenated sin & cos |
584 | 582 | self.register_buffer('pos_embed', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False) |
585 | 583 |
|
586 | | - if not is_meta_device(device): |
587 | | - self.reset_parameters() |
| 584 | + # TODO: skip init when on meta device when safe to do so |
| 585 | + self.reset_parameters() |
588 | 586 |
|
589 | 587 | def reset_parameters(self) -> None: |
590 | 588 | """Initialize parameters and buffers.""" |
@@ -851,8 +849,8 @@ def __init__( |
851 | 849 | num_pos *= s |
852 | 850 | self.register_buffer('t_x', torch.empty(num_pos, device=device, dtype=dtype), persistent=False) |
853 | 851 | self.register_buffer('t_y', torch.empty(num_pos, device=device, dtype=dtype), persistent=False) |
854 | | - if not is_meta_device(device): |
855 | | - self._init_buffers() |
| 852 | + # TODO: skip init when on meta device when safe to do so |
| 853 | + self._init_buffers() |
856 | 854 | else: |
857 | 855 | self.t_x = self.t_y = None |
858 | 856 |
|
@@ -1087,8 +1085,8 @@ def __init__( |
1087 | 1085 | else: |
1088 | 1086 | self.pos_embed_cached = None |
1089 | 1087 |
|
1090 | | - if not is_meta_device(device): |
1091 | | - self.reset_parameters() |
| 1088 | + # TODO: skip init when on meta device when safe to do so |
| 1089 | + self.reset_parameters() |
1092 | 1090 |
|
1093 | 1091 | def reset_parameters(self) -> None: |
1094 | 1092 | """Initialize parameters and buffers.""" |
|
0 commit comments