Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion segmentation_models_pytorch/decoders/linknet/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = self.block(x)
if skip is not None:
if skip is not None and skip.shape[1] != 0:
x = x + skip
return x

Expand All @@ -71,6 +71,12 @@ def __init__(
encoder_channels = encoder_channels[::-1]

channels = list(encoder_channels) + [prefinal_channels]
for i in range(1, len(channels) - 1):
# Transformer-style encoders may expose a 0-channel placeholder for the
# missing 1/2-scale skip. Keep the decoder stream non-empty and just
# skip feature fusion at that stage.
if channels[i] == 0:
channels[i] = channels[i - 1]

self.blocks = nn.ModuleList(
[
Expand Down
15 changes: 14 additions & 1 deletion segmentation_models_pytorch/decoders/unetplusplus/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def __init__(
interpolation_mode: str = "nearest",
):
super().__init__()
self.out_channels = out_channels
self.is_empty = out_channels == 0
self.interpolation_mode = interpolation_mode

if self.is_empty:
self.conv1 = None
self.attention1 = None
self.conv2 = None
self.attention2 = None
return

self.conv1 = md.Conv2dReLU(
in_channels + skip_channels,
out_channels,
Expand All @@ -36,12 +47,14 @@ def __init__(
use_norm=use_norm,
)
self.attention2 = md.Attention(attention_type, in_channels=out_channels)
self.interpolation_mode = interpolation_mode

def forward(
self, x: torch.Tensor, skip: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = F.interpolate(x, scale_factor=2.0, mode=self.interpolation_mode)
if self.is_empty:
height, width = x.shape[2:]
return x.new_empty(x.shape[0], 0, height, width)
if skip is not None:
x = torch.cat([x, skip], dim=1)
x = self.attention1(x)
Expand Down
12 changes: 12 additions & 0 deletions tests/models/test_linknet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import torch

import segmentation_models_pytorch as smp

from tests.models import base


class TestLinknetModel(base.BaseModelTester):
    """Shared model-test suite for the Linknet decoder (driven by BaseModelTester)."""

    test_model_type = "linknet"
    files_for_diff = [r"decoders/linknet/", r"base/"]

    def test_timm_transformer_style_encoder(self):
        # Transformer-style timm encoders (e.g. convnext) expose a 0-channel
        # placeholder for the missing 1/2-scale skip; the model must still
        # produce a full-resolution, single-channel mask.
        net = smp.Linknet("tu-convnext_atto", encoder_weights=None)
        net = net.eval()

        sample = torch.rand(1, 3, 256, 256)
        with torch.inference_mode():
            prediction = net(sample)

        assert prediction.shape == (1, 1, 256, 256)
9 changes: 9 additions & 0 deletions tests/models/test_unetplusplus.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import segmentation_models_pytorch as smp
import torch

from tests.models import base

Expand All @@ -7,6 +8,14 @@ class TestUnetPlusPlusModel(base.BaseModelTester):
test_model_type = "unetplusplus"
files_for_diff = [r"decoders/unetplusplus/", r"base/"]

def test_timm_transformer_style_encoder(self):
    """UnetPlusPlus must run end-to-end with a transformer-style timm encoder.

    Such encoders (e.g. convnext) lack the 1/2-scale skip feature; the
    decoder is expected to cope and still emit a full-size prediction.
    """
    net = smp.UnetPlusPlus("tu-convnext_atto", encoder_weights=None)
    net = net.eval()

    sample = torch.rand(1, 3, 256, 256)
    with torch.inference_mode():
        prediction = net(sample)

    assert prediction.shape == (1, 1, 256, 256)

def test_interpolation(self):
# test bilinear
model_1 = smp.create_model(
Expand Down