Commit 41754f2

Fix #2661 ... don't skip reset_parameters/init when a meta device is detected, as skipping breaks use of accelerate and similar dispatch-override context managers
1 parent bdab30b · commit 41754f2

18 files changed: 88 additions & 85 deletions
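
The removed pattern guarded initialization behind a meta-tensor check. Below is a minimal sketch of why that guard can misfire under dispatch-override context managers such as accelerate's init_empty_weights; the module and field names are illustrative, not timm internals, and the exact failure mode is the subject of #2661:

# Illustrative sketch -- hypothetical module, not timm code.
import torch
import torch.nn as nn

class WindowBias(nn.Module):
    def __init__(self, dim: int, device=None, dtype=None):
        super().__init__()
        self.proj = nn.Linear(dim, dim, device=device, dtype=dtype)
        # Non-persistent buffer: excluded from state_dict, so it only ever
        # gets real values from reset_parameters(), never from a checkpoint.
        self.register_buffer('idx', torch.empty(dim, device=device, dtype=dtype), persistent=False)
        # Old pattern (what this commit removes):
        #   if not self.proj.weight.is_meta:
        #       self.reset_parameters()
        # Under accelerate's init_empty_weights(), register_parameter is
        # patched so self.proj.weight lands on the meta device even though
        # `device` was never 'meta'; the guard then fires and the buffer,
        # which may still be a real uninitialized tensor, is never filled.
        # New pattern: always initialize; in-place init ops on meta tensors
        # are effectively no-ops, so this stays cheap under meta-device tricks.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)
        self.idx.copy_(torch.arange(self.idx.numel(), dtype=self.idx.dtype))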

timm/layers/blur_pool.py

Lines changed: 2 additions & 2 deletions
@@ -51,8 +51,8 @@ def __init__(
         filt_shape = (channels or 1, 1, filt_size, filt_size)
         self.register_buffer('filt', torch.empty(filt_shape, device=device, dtype=dtype), persistent=False)

-        if not self.filt.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize buffers."""

timm/layers/lambda_layer.py

Lines changed: 2 additions & 2 deletions
@@ -125,8 +125,8 @@ def __init__(

         self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

-        if not self.qkv.weight.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

timm/layers/pos_embed_rel.py

Lines changed: 6 additions & 6 deletions
@@ -299,8 +299,8 @@ def __init__(
             persistent=False,
         )

-        if not self.relative_position_bias_table.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

@@ -420,8 +420,8 @@ def __init__(
             persistent=False,
         )

-        if not self.mlp.fc1.weight.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def get_bias(self) -> torch.Tensor:
         relative_position_bias = self.mlp(self.rel_coords_log)

@@ -554,8 +554,8 @@ def __init__(
         self.register_buffer('height_lookup', torch.empty(height_lookup_shape, **dd), persistent=False)
         self.register_buffer('width_lookup', torch.empty(width_lookup_shape, **dd), persistent=False)

-        if not self.relative_position_bias_table.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

timm/layers/pos_embed_sincos.py

Lines changed: 10 additions & 12 deletions
@@ -11,8 +11,6 @@
 from ._fx import register_notrace_function
 from .grid import ndgrid
 from .trace_utils import _assert
-from .weight_init import is_meta_device
-

 def pixel_freq_bands(
     num_bands: int,

@@ -188,8 +186,8 @@ def __init__(
         self.keep_spatial = keep_spatial
         self.register_buffer('bands', torch.empty(num_bands, device=device, dtype=dtype), persistent=False)

-        if not is_meta_device(device):
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

@@ -447,8 +445,8 @@ def __init__(
         self.register_buffer('pos_embed_sin', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)
         self.register_buffer('pos_embed_cos', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)

-        if not is_meta_device(device):
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

@@ -583,8 +581,8 @@ def __init__(
         emb_shape = (num_pos, dim * 2)  # concatenated sin & cos
         self.register_buffer('pos_embed', torch.empty(emb_shape, device=device, dtype=dtype), persistent=False)

-        if not is_meta_device(device):
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

@@ -851,8 +849,8 @@ def __init__(
                 num_pos *= s
             self.register_buffer('t_x', torch.empty(num_pos, device=device, dtype=dtype), persistent=False)
             self.register_buffer('t_y', torch.empty(num_pos, device=device, dtype=dtype), persistent=False)
-            if not is_meta_device(device):
-                self._init_buffers()
+            # TODO: skip init when on meta device when safe to do so
+            self._init_buffers()
         else:
             self.t_x = self.t_y = None

@@ -1087,8 +1085,8 @@ def __init__(
         else:
             self.pos_embed_cached = None

-        if not is_meta_device(device):
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

timm/models/beit.py

Lines changed: 8 additions & 8 deletions
@@ -179,8 +179,8 @@ def __init__(
         self.proj = nn.Linear(all_head_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

-        if not self.proj.weight.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def _get_rel_pos_bias(self) -> torch.Tensor:
         """Get relative position bias for the attention window.

@@ -362,8 +362,8 @@ def __init__(
         else:
             self.gamma_1, self.gamma_2 = None, None

-        if not self.norm1.weight.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters."""

@@ -416,8 +416,8 @@ def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dt
             persistent=False,
         )

-        if not self.relative_position_bias_table.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

@@ -569,8 +569,8 @@ def __init__(
         self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
         self.head_init_scale = head_init_scale

-        if not self.patch_embed.proj.weight.is_meta:
-            self.init_weights(needs_reset=False)
+        # TODO: skip init when on meta device when safe to do so
+        self.init_weights(needs_reset=False)

     def init_weights(self, needs_reset: bool = True) -> None:
         """Initialize model weights.

timm/models/csatv2.py

Lines changed: 4 additions & 4 deletions
@@ -226,8 +226,8 @@ def __init__(
         self.register_buffer('imagenet_mean', torch.empty(3, 1, 1, device=device, dtype=dtype), persistent=False)
         self.register_buffer('imagenet_std', torch.empty(3, 1, 1, device=device, dtype=dtype), persistent=False)

-        if not self.mean.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize buffers."""

@@ -609,8 +609,8 @@ def __init__(

         self.head = NormMlpClassifierHead(dims[-1], num_classes, pool_type=global_pool, **dd)

-        if not self.stem_dct.conv_y.weight.is_meta:
-            self.init_weights(needs_reset=False)
+        # TODO: skip init when on meta device when safe to do so
+        self.init_weights(needs_reset=False)

     def init_weights(self, needs_reset: bool = True):
         self.apply(partial(self._init_weights, needs_reset=needs_reset))

timm/models/deit.py

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@ def __init__(self, *args, **kwargs):
         self.distilled_training = False  # must set this True to train w/ distillation token

         self.weight_init_mode = 'reset' if weight_init == 'skip' else weight_init
-        if weight_init != 'skip' and not next(self.parameters()).is_meta:
+        # TODO: skip init when on meta device when safe to do so
+        if weight_init != 'skip':
             self.init_weights(needs_reset=False)

     def init_weights(self, mode='', needs_reset=True):
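
Note the deit change keeps the explicit opt-out: only the implicit meta-device detection is dropped, while weight_init='skip' still bypasses init_weights. A hedged usage sketch, assuming the weight_init kwarg is forwarded through create_model and using a model name purely for illustration:

import timm

# Explicit opt-out still works; meta-device auto-detection no longer
# skips initialization implicitly.
model = timm.create_model('deit_base_patch16_224', weight_init='skip')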

timm/models/efficientformer_v2.py

Lines changed: 6 additions & 6 deletions
@@ -159,8 +159,8 @@ def __init__(
         )
         self.attention_bias_cache = {}

-        if not self.attention_biases.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     @torch.no_grad()
     def train(self, mode=True):

@@ -300,8 +300,8 @@ def __init__(
         )
         self.attention_bias_cache = {}

-        if not self.attention_biases.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     @torch.no_grad()
     def train(self, mode=True):

@@ -719,8 +719,8 @@ def __init__(
         else:
             self.head_dist = None

-        if not self.norm.weight.is_meta:
-            self.init_weights(needs_reset=False)
+        # TODO: skip init when on meta device when safe to do so
+        self.init_weights(needs_reset=False)

         self.distilled_training = False

timm/models/efficientvit_msra.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207
)
208208
self.attention_bias_cache = {}
209209

210-
if not self.attention_bias_idxs.is_meta:
211-
self.reset_parameters()
210+
# TODO: skip init when on meta device when safe to do so
211+
self.reset_parameters()
212212

213213
def reset_parameters(self) -> None:
214214
"""Initialize parameters and buffers."""
@@ -537,8 +537,8 @@ def __init__(
537537
self.head = NormLinear(
538538
self.num_features, num_classes, drop=self.drop_rate, **dd) if num_classes > 0 else torch.nn.Identity()
539539

540-
if not self.patch_embed.conv1.conv.weight.is_meta:
541-
self.init_weights(needs_reset=False)
540+
# TODO: skip init when on meta device when safe to do so
541+
self.init_weights(needs_reset=False)
542542

543543
def init_weights(self, needs_reset: bool = True):
544544
self.apply(partial(self._init_weights, needs_reset=needs_reset))

timm/models/eva.py

Lines changed: 6 additions & 6 deletions
@@ -178,8 +178,8 @@ def __init__(
         self.proj = nn.Linear(attn_dim, dim, **dd)
         self.proj_drop = nn.Dropout(proj_drop)

-        if not self.proj.weight.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters and buffers."""

@@ -369,8 +369,8 @@ def __init__(
         self.gamma_2 = nn.Parameter(torch.empty(dim, **dd)) if init_values is not None else None
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

-        if not self.norm1.weight.is_meta:
-            self.reset_parameters()
+        # TODO: skip init when on meta device when safe to do so
+        self.reset_parameters()

     def reset_parameters(self) -> None:
         """Initialize parameters."""

@@ -738,8 +738,8 @@ def __init__(
         self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
         self.head_init_scale = head_init_scale

-        if not self.patch_embed.proj.weight.is_meta:
-            self.init_weights(needs_reset=False)
+        # TODO: skip init when on meta device when safe to do so
+        self.init_weights(needs_reset=False)

     def init_weights(self, needs_reset: bool = True):
         self.apply(partial(self._init_weights, needs_reset=needs_reset))
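
Taken together, these changes are meant to keep flows like the following working. A hedged sketch, assuming accelerate is installed; the model name is used purely for illustration:

from accelerate import init_empty_weights
import timm

# Build the model skeleton with parameters on the meta device. With the
# old is_meta guards, reset_parameters()/init_weights() were silently
# skipped here, which could leave non-persistent buffers without values.
with init_empty_weights():
    model = timm.create_model('beit_base_patch16_224')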
