143 changes: 138 additions & 5 deletions segmentation_models_pytorch/encoders/timm_universal.py
@@ -13,17 +13,20 @@
- Automatic alignment for inconsistent feature scales:
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
- VGG-style models (include scale-1 features): Align outputs for compatibility.
- ViT-style models (single-scale): Use adapter to generate multi-scale features.
- Easy access to feature scale information via the `reduction` property.

Feature Scale Differences:
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
- VGG-style models: Include scale-1 features (input resolution).
- ViT-style models: Single-scale output, adapted to multi-scale via learnable layers.

Notes:
- `output_stride` is unsupported in some models, especially transformer-based architectures.
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
- VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs.
- ViT-style models use `_is_vit_adapter_style` with adapter layers for multi-scale output.
"""

from typing import Any
@@ -33,6 +36,67 @@
import torch.nn as nn


class ViTFeatureAdapter(nn.Module):
"""
Adapter module to convert single-scale ViT features to multi-scale hierarchical features.

ViT models output features at a single scale (e.g., 1/16). This adapter generates
features at multiple scales (1/4, 1/8, 1/16, 1/32) via transposed convolutions
(upsampling) and strided convolutions (downsampling), scaling channel widths accordingly.
"""

def __init__(self, in_channels: int, vit_reduction: int, target_reductions: list[int]):
"""
Args:
in_channels: Number of channels in ViT output features.
vit_reduction: The reduction factor of ViT features (e.g., 16 for patch16).
target_reductions: List of target reduction factors (e.g., [4, 8, 16, 32]).
"""
super().__init__()
self.vit_reduction = vit_reduction
self.target_reductions = target_reductions

self.adapters = nn.ModuleList()
self.out_channels_list = []

for target_red in target_reductions:
if target_red < vit_reduction:
scale_factor = vit_reduction // target_red
out_ch = in_channels // scale_factor
out_ch = max(out_ch, 1)
adapter = nn.Sequential(
nn.ConvTranspose2d(in_channels, out_ch, kernel_size=scale_factor, stride=scale_factor),
nn.BatchNorm2d(out_ch),
nn.GELU(),
)
elif target_red == vit_reduction:
out_ch = in_channels
adapter = nn.Identity()
else:
scale_factor = target_red // vit_reduction
out_ch = in_channels * scale_factor
adapter = nn.Sequential(
nn.Conv2d(in_channels, out_ch, kernel_size=3, stride=scale_factor, padding=1),
nn.BatchNorm2d(out_ch),
nn.GELU(),
)

self.adapters.append(adapter)
self.out_channels_list.append(out_ch)

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
Args:
x: ViT feature tensor of shape (B, C, H, W).

Returns:
List of feature tensors at different scales.
"""
features = []
for adapter in self.adapters:
features.append(adapter(x))
return features
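
# Illustrative usage sketch (assumes a vit_base-sized feature map: 768 channels at
# 1/16 scale for a 224x224 input); the expected shapes follow from the rules above:
#
#   adapter = ViTFeatureAdapter(in_channels=768, vit_reduction=16, target_reductions=[4, 8, 16, 32])
#   feats = adapter(torch.randn(1, 768, 14, 14))
#
#   reduction 4  -> (1, 192, 56, 56)    ConvTranspose2d upsampling, channels divided by 4
#   reduction 8  -> (1, 384, 28, 28)
#   reduction 16 -> (1, 768, 14, 14)    identity
#   reduction 32 -> (1, 1536, 7, 7)     strided Conv2d downsampling, channels doubled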


class TimmUniversalEncoder(nn.Module):
"""
A universal encoder leveraging the `timm` library for feature extraction from
@@ -104,23 +168,63 @@ def __init__(
# Determine the model's downsampling pattern and set hierarchy flags
encoder_stage = len(tmp_model.feature_info.reduction())
reduction_scales = list(tmp_model.feature_info.reduction())
feature_channels = list(tmp_model.feature_info.channels())

# Initialize style flags
self._is_transformer_style = False
self._is_vgg_style = False
self._is_vit_adapter_style = False

if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
# Transformer-style downsampling: scales (4, 8, 16, 32)
self._is_transformer_style = True
elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
# Traditional-style downsampling: scales (2, 4, 8, 16, 32)
pass
elif reduction_scales == [2**i for i in range(encoder_stage)]:
# VGG-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
self._is_vgg_style = True
elif len(set(reduction_scales)) == 1:
# ViT-style models: all stages report a single repeated reduction (e.g., 16)
self._is_vit_adapter_style = True
else:
raise ValueError("Unsupported model downsampling pattern.")

if self._is_vit_adapter_style:
vit_reduction = reduction_scales[0]
vit_channels = feature_channels[-1]

# Decoder-facing target scales (4, 8, 16, ...), one per stage below the input resolution.
target_reductions = [2 ** (i + 2) for i in range(depth - 1)] if depth > 1 else []

common_kwargs.pop("features_only", None)
common_kwargs.pop("out_indices", None)

if output_stride != 32:
raise ValueError(f"ViT adapter style does not support output_stride={output_stride}. Only 32 is supported.")

timm_model_kwargs = _merge_kwargs_no_duplicates(common_kwargs, kwargs)
self.model = timm.create_model(name, **timm_model_kwargs)

if not hasattr(self.model, "forward_intermediates"):
raise ValueError(f"Model {name} does not support forward_intermediates, required for ViT adapter.")

if hasattr(self.model, "blocks"):
if depth > len(self.model.blocks):
raise ValueError(f"Depth {depth} exceeds model blocks {len(self.model.blocks)}")

self.vit_adapter = ViTFeatureAdapter(
in_channels=vit_channels,
vit_reduction=vit_reduction,
target_reductions=target_reductions,
)

self._out_channels = (
[in_channels] + [0] + self.vit_adapter.out_channels_list
)
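# Illustrative values (assuming vit_base_patch16, in_channels=3, depth=5):
#   vit_channels = 768, target_reductions = [4, 8, 16, 32]
#   self._out_channels -> [3, 0, 192, 384, 768, 1536]
# The 0 entry is the empty dummy slot at 1/2 scale, filled in during forward.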

elif self._is_transformer_style:
# Transformer-like models (start at scale 4)
if "tresnet" in name:
# 'tresnet' models start feature extraction at stage 1,
@@ -157,6 +261,32 @@ def __init__(
self._depth = depth
self._output_stride = output_stride

# ViT adapter style models are not TorchScript compatible due to forward_intermediates
if self._is_vit_adapter_style:
self._is_torch_scriptable = False

@torch.jit.unused
def _forward_vit_adapter(self, x: torch.Tensor) -> list[torch.Tensor]:
intermediates = self.model.forward_intermediates(
x,
indices=[-1],
intermediates_only=True,
)
vit_feature = intermediates[-1]
if isinstance(vit_feature, tuple):
vit_feature = vit_feature[0]

if self._is_channel_last:
vit_feature = vit_feature.permute(0, 3, 1, 2).contiguous()

features = self.vit_adapter(vit_feature)

B, _, H, W = x.shape
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
features = [x, dummy] + features

return features
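
# Illustrative output (assuming vit_base_patch16, a 3-channel 224x224 input, depth=5):
#   features[0]: (B, 3, 224, 224)   original input, scale 1
#   features[1]: (B, 0, 112, 112)   empty dummy at scale 1/2
#   features[2]: (B, 192, 56, 56)   adapter output at 1/4
#   features[3]: (B, 384, 28, 28)   adapter output at 1/8
#   features[4]: (B, 768, 14, 14)   adapter output at 1/16 (identity)
#   features[5]: (B, 1536, 7, 7)    adapter output at 1/32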

def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
Forward pass to extract multi-stage features.
@@ -167,6 +297,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
Returns:
list[torch.Tensor]: List of feature maps at different scales.
"""
if self._is_vit_adapter_style:
return self._forward_vit_adapter(x)

features = self.model(x)

# Convert NHWC to NCHW if needed
13 changes: 13 additions & 0 deletions tests/encoders/test_vit_adapter.py
@@ -0,0 +1,13 @@
from tests.encoders import base
from tests.utils import has_timm_test_models

class TestViTAdapterEncoder(base.BaseEncoderTester):
encoder_names = ["tu-vit_base_patch16_224", "tu-vit_tiny_patch16_224", "tu-vit_large_patch16_224"]

default_height = 224
default_width = 224

supports_dilated = False

depth_to_test = [3, 4, 5]
in_channels_to_test = [1, 3, 4]
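

# A hedged end-to-end sketch: it assumes the "tu-" prefix routes this encoder through
# TimmUniversalEncoder's ViT adapter path and that the installed timm version exposes
# forward_intermediates for ViT models.
def test_vit_adapter_unet_forward_sketch():
    import torch

    import segmentation_models_pytorch as smp

    model = smp.Unet(
        encoder_name="tu-vit_tiny_patch16_224", encoder_weights=None, classes=2
    )
    output = model(torch.randn(1, 3, 224, 224))
    # Expected full-resolution logits if the adapter path behaves as described above.
    assert output.shape == (1, 2, 224, 224)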