Skip to content

Commit 3d9aef1

Browse files
Support transformer-style tu- encoders in UnetPlusPlus and Linknet (#1294)
1 parent 4bf6ec0 commit 3d9aef1

4 files changed

Lines changed: 42 additions & 2 deletions

File tree

segmentation_models_pytorch/decoders/linknet/decoder.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def forward(
5050
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
5151
) -> torch.Tensor:
5252
x = self.block(x)
53-
if skip is not None:
53+
if skip is not None and skip.shape[1] != 0:
5454
x = x + skip
5555
return x
5656

@@ -71,6 +71,12 @@ def __init__(
7171
encoder_channels = encoder_channels[::-1]
7272

7373
channels = list(encoder_channels) + [prefinal_channels]
74+
for i in range(1, len(channels) - 1):
75+
# Transformer-style encoders may expose a 0-channel placeholder for the
76+
# missing 1/2-scale skip. Keep the decoder stream non-empty and just
77+
# skip feature fusion at that stage.
78+
if channels[i] == 0:
79+
channels[i] = channels[i - 1]
7480

7581
self.blocks = nn.ModuleList(
7682
[

segmentation_models_pytorch/decoders/unetplusplus/decoder.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,17 @@ def __init__(
1818
interpolation_mode: str = "nearest",
1919
):
2020
super().__init__()
21+
self.out_channels = out_channels
22+
self.is_empty = out_channels == 0
23+
self.interpolation_mode = interpolation_mode
24+
25+
if self.is_empty:
26+
self.conv1 = None
27+
self.attention1 = None
28+
self.conv2 = None
29+
self.attention2 = None
30+
return
31+
2132
self.conv1 = md.Conv2dReLU(
2233
in_channels + skip_channels,
2334
out_channels,
@@ -36,12 +47,14 @@ def __init__(
3647
use_norm=use_norm,
3748
)
3849
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
39-
self.interpolation_mode = interpolation_mode
4050

4151
def forward(
4252
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
4353
) -> torch.Tensor:
4454
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
55+
if self.is_empty:
56+
height, width = x.shape[2:]
57+
return x.new_empty(x.shape[0], 0, height, width)
4558
if skip is not None:
4659
x = torch.cat([x, skip], dim=1)
4760
x = self.attention1(x)

tests/models/test_linknet.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,18 @@
1+
import torch
2+
3+
import segmentation_models_pytorch as smp
4+
15
from tests.models import base
26

37

48
class TestLinknetModel(base.BaseModelTester):
59
test_model_type = "linknet"
610
files_for_diff = [r"decoders/linknet/", r"base/"]
11+
12+
def test_timm_transformer_style_encoder(self):
13+
model = smp.Linknet("tu-convnext_atto", encoder_weights=None).eval()
14+
15+
with torch.inference_mode():
16+
output = model(torch.rand(1, 3, 256, 256))
17+
18+
assert output.shape == (1, 1, 256, 256)

tests/models/test_unetplusplus.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import segmentation_models_pytorch as smp
2+
import torch
23

34
from tests.models import base
45

@@ -7,6 +8,14 @@ class TestUnetPlusPlusModel(base.BaseModelTester):
78
test_model_type = "unetplusplus"
89
files_for_diff = [r"decoders/unetplusplus/", r"base/"]
910

11+
def test_timm_transformer_style_encoder(self):
12+
model = smp.UnetPlusPlus("tu-convnext_atto", encoder_weights=None).eval()
13+
14+
with torch.inference_mode():
15+
output = model(torch.rand(1, 3, 256, 256))
16+
17+
assert output.shape == (1, 1, 256, 256)
18+
1019
def test_interpolation(self):
1120
# test bilinear
1221
model_1 = smp.create_model(

0 commit comments

Comments (0)