Commit bdab30b

Remove torch.jit calls except for .ignore and .is_scripting to avoid 2.11 deprecation. Fix #2663
1 parent 9171d82 commit bdab30b

5 files changed: 2 additions & 69 deletions

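Background: per the commit message, `@torch.jit.ignore` and `torch.jit.is_scripting()` are the only two `torch.jit` calls kept; everything else removed below (`_overload_method`, `interface`, `unused`) is TorchScript machinery covered by the 2.11 deprecation. A minimal sketch of what the two retained calls do (the module and method names are illustrative, not timm code):

```python
import torch
import torch.nn as nn


class Example(nn.Module):
    """Illustrative module, not timm code."""

    @torch.jit.ignore
    def python_only_helper(self) -> str:
        # @torch.jit.ignore keeps this method out of any scripted
        # graph rather than failing to compile it.
        return 'eager only'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not torch.jit.is_scripting():
            # torch.jit.is_scripting() guards code paths (asserts,
            # logging, checkpointing) that only make sense in eager mode.
            assert x.ndim >= 2
        return x
```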

tests/test_models.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@
 EXCLUDE_FILTERS = ['*enormous*', '*_7b_*']
 NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*', '*_7b_*']
 
-EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*']
+EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*', 'hrnet*', 'dpn*', 'densenet*', 'selecsls*']
 
 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128
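With their TorchScript overloads gone, the four model families are added to the JIT exclusion list so the scripting tests skip them. A sketch of how wildcard filters like these are typically applied; the `fnmatch`-based helper below is an assumption about the test harness, not code from `test_models.py`:

```python
from fnmatch import fnmatch

EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*', '*_7b_*', 'hrnet*', 'dpn*', 'densenet*', 'selecsls*']


def excluded_from_jit(model_name: str) -> bool:
    # A model is skipped in the TorchScript tests if any pattern matches.
    return any(fnmatch(model_name, pattern) for pattern in EXCLUDE_JIT_FILTERS)


assert excluded_from_jit('dpn68')
assert excluded_from_jit('hrnet_w18')
assert not excluded_from_jit('resnet50')
```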

timm/models/densenet.py

Lines changed: 0 additions & 11 deletions
@@ -72,24 +72,13 @@ def any_requires_grad(self, x: List[torch.Tensor]) -> bool:
                 return True
         return False
 
-    @torch.jit.unused  # noqa: T484
     def call_checkpoint_bottleneck(self, x: List[torch.Tensor]) -> torch.Tensor:
         """Call bottleneck function with gradient checkpointing."""
         def closure(*xs):
             return self.bottleneck_fn(xs)
 
         return checkpoint(closure, *x)
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (List[torch.Tensor]) -> (torch.Tensor)
-        pass
-
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (torch.Tensor) -> (torch.Tensor)
-        pass
-
     # torchscript does not yet support *args, so we overload method
     # allowing it to take either a List[Tensor] or single Tensor
     def forward(self, x: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:  # noqa: F811
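`@torch.jit.unused` told the scripting compiler to skip `call_checkpoint_bottleneck`, which TorchScript cannot compile; with scripting support dropped for densenet, the decorator is dead weight. A hedged sketch of the surrounding idea, paraphrasing rather than quoting `DenseLayer` (the function name and `use_reentrant=False` are mine):

```python
from typing import Callable, List

import torch
from torch.utils.checkpoint import checkpoint


def maybe_checkpointed_bottleneck(
        bottleneck_fn: Callable[[List[torch.Tensor]], torch.Tensor],
        xs: List[torch.Tensor],
        memory_efficient: bool,
) -> torch.Tensor:
    # One of the two retained calls: is_scripting() keeps the
    # un-scriptable checkpoint path out of any scripted graph.
    if memory_efficient and not torch.jit.is_scripting():
        def closure(*tensors):
            # checkpoint() passes tensors positionally; re-pack as a list.
            return bottleneck_fn(list(tensors))
        # Recomputes the closure during backward to save activation memory.
        return checkpoint(closure, *xs, use_reentrant=False)
    return bottleneck_fn(xs)
```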

timm/models/dpn.py

Lines changed: 0 additions & 20 deletions
@@ -34,16 +34,6 @@ def __init__(
         super().__init__()
         self.bn = norm_layer(in_chs, eps=0.001, **dd)
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor)
-        pass
-
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (torch.Tensor) -> (torch.Tensor)
-        pass
-
     def forward(self, x):
         if isinstance(x, tuple):
             x = torch.cat(x, dim=1)
@@ -124,16 +114,6 @@ def __init__(
         self.c1x1_c1 = None
         self.c1x1_c2 = None
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
-        pass
-
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
-        pass
-
     def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
         if isinstance(x, tuple):
             x_in = torch.cat(x, dim=1)
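Both removed blocks in dpn.py declared the same overload shape: `forward` accepts either a `(residual, dense)` tensor pair or a single tensor, and the surviving code covers both with one `isinstance` check. An illustrative stand-in, simplified from the concat-then-norm input block shown in the first hunk (the class name and activation are assumptions):

```python
from typing import Tuple, Union

import torch
import torch.nn as nn


class CatBnActSketch(nn.Module):
    """Hypothetical stand-in for DPN's concat + norm + activation block."""

    def __init__(self, in_chs: int):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
        self.act = nn.ReLU(inplace=True)  # assumed activation

    def forward(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
        if isinstance(x, tuple):
            # Dual-path input: concatenate residual and dense branches.
            x = torch.cat(x, dim=1)
        return self.act(self.bn(x))
```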

timm/models/hrnet.py

Lines changed: 1 addition & 17 deletions
@@ -502,28 +502,12 @@ class SequentialList(nn.Sequential):
     def __init__(self, *args):
         super().__init__(*args)
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (List[torch.Tensor]) -> (List[torch.Tensor])
-        pass
-
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (torch.Tensor) -> (List[torch.Tensor])
-        pass
-
     def forward(self, x) -> List[torch.Tensor]:
         for module in self:
             x = module(x)
         return x
 
 
-@torch.jit.interface
-class ModuleInterface(torch.nn.Module):
-    def forward(self, input: torch.Tensor) -> torch.Tensor:  # `input` has a same name in Sequential forward
-        pass
-
-
 block_types_dict = {
     'BASIC': BasicBlock,
     'BOTTLENECK': Bottleneck
@@ -816,7 +800,7 @@ def forward_features(self, x):
             if y is None:
                 y = incre(yl[i])
             else:
-                down: ModuleInterface = self.downsamp_modules[i - 1]  # needed for torchscript module indexing
+                down = self.downsamp_modules[i - 1]
                 y = incre(yl[i]) + down.forward(y)
 
         y = self.final_layer(y)
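The `ModuleInterface` class existed only so TorchScript could type a module fetched from an `nn.ModuleList` by runtime index, as the removed `# needed for torchscript module indexing` comment notes; without it, scripting cannot infer which `forward` signature the indexed module has. In eager mode plain indexing is enough, which is what the second hunk switches to. A small self-contained illustration (the `nn.Conv2d` stages stand in for HRNet's downsample modules):

```python
import torch
import torch.nn as nn

# Stand-ins for HRNet's downsample stages (really conv + norm blocks).
downsamp_modules = nn.ModuleList([
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1) for _ in range(3)
])

y = torch.randn(1, 16, 64, 64)
for i in range(1, len(downsamp_modules) + 1):
    down = downsamp_modules[i - 1]  # no interface annotation needed in eager mode
    y = down(y)
print(y.shape)  # torch.Size([1, 16, 8, 8])
```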

timm/models/selecsls.py

Lines changed: 0 additions & 20 deletions
@@ -27,16 +27,6 @@ class SequentialList(nn.Sequential):
     def __init__(self, *args):
         super().__init__(*args)
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (List[torch.Tensor]) -> (List[torch.Tensor])
-        pass
-
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (torch.Tensor) -> (List[torch.Tensor])
-        pass
-
     def forward(self, x) -> List[torch.Tensor]:
         for module in self:
             x = module(x)
@@ -49,16 +39,6 @@ def __init__(self, mode='index', index=0):
         self.mode = mode
         self.index = index
 
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (List[torch.Tensor]) -> (torch.Tensor)
-        pass
-
-    @torch.jit._overload_method  # noqa: F811
-    def forward(self, x):
-        # type: (Tuple[torch.Tensor]) -> (torch.Tensor)
-        pass
-
     def forward(self, x) -> torch.Tensor:
         if self.mode == 'index':
             return x[self.index]
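The second hunk's selector module is the simplest instance of the pattern: pick one tensor from a list (or merge them) based on configuration. A hedged stand-in for the eager-only version; the `torch.cat` fallback for the non-`'index'` mode is an assumption from the signature, since that branch is not shown in the diff:

```python
from typing import List

import torch
import torch.nn as nn


class SelectSeqSketch(nn.Module):
    """Illustrative stand-in for selecsls' list-selecting module."""

    def __init__(self, mode: str = 'index', index: int = 0):
        super().__init__()
        self.mode = mode
        self.index = index

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        if self.mode == 'index':
            return x[self.index]  # select a single tensor from the list
        return torch.cat(x, dim=1)  # assumed fallback: merge along channels
```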
