Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion segmentation_models_pytorch/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
class_weights: Optional[List[float]] = None,
):
"""Dice loss for image segmentation task.
It supports binary, multiclass and multilabel cases
Expand All @@ -32,6 +33,9 @@ def __init__(
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
class_weights: Array of weights for each class. If not None, the loss for each class
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class_weights: Array of weights for each class. If not None, the loss for each class
class_weights: List of weights for each class. If not ``None``, the loss for each class

is multiplied by the corresponding weight. Only supported for multiclass and
multilabel modes. Weights do not need to be normalized.

Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -43,6 +47,8 @@ def __init__(
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super(DiceLoss, self).__init__()
self.mode = mode
if class_weights is not None and mode == BINARY_MODE:
raise ValueError("class_weights are not supported with mode=binary")
if classes is not None:
assert mode != BINARY_MODE, (
"Masking classes is not supported with mode=binary"
Expand All @@ -55,6 +61,9 @@ def __init__(
self.eps = eps
self.log_loss = log_loss
self.ignore_index = ignore_index
self.class_weights = (
to_tensor(class_weights, dtype=torch.float) if class_weights is not None else None
)

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.size(0) == y_pred.size(0)
Expand Down Expand Up @@ -128,7 +137,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

return self.aggregate_loss(loss)

def aggregate_loss(self, loss):
def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Reduce the per-class loss vector to a single scalar.

    Args:
        loss: Per-class loss tensor of shape (C,)

    Returns:
        Scalar loss value: a plain mean when no ``class_weights`` are set,
        otherwise a weighted mean normalized by the sum of the weights.
    """
    if self.class_weights is None:
        return loss.mean()
    weights = self.class_weights.to(loss.device)
    if self.classes is not None:
        # The loss vector was already filtered to `classes`; keep only the
        # matching weights so the two tensors line up element-wise.
        weights = weights[self.classes]
    return (loss * weights).sum() / weights.sum()

def compute_score(
Expand Down
23 changes: 19 additions & 4 deletions segmentation_models_pytorch/losses/focal.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional
from typing import Optional, List
from functools import partial

import torch
from torch.nn.modules.loss import _Loss
from ._functional import focal_loss_with_logits
from ._functional import focal_loss_with_logits, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

__all__ = ["FocalLoss"]
Expand All @@ -19,6 +19,7 @@ def __init__(
reduction: Optional[str] = "mean",
normalized: bool = False,
reduced_threshold: Optional[float] = None,
class_weights: Optional[List[float]] = None,
):
"""Compute Focal loss

Expand All @@ -31,6 +32,9 @@ def __init__(
normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
should use `reduction="sum"`.
class_weights: Array of weights for each class. If not None, the loss for each class
is multiplied by the corresponding weight. Only supported for multiclass mode.
Weights do not need to be normalized.

Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -43,6 +47,9 @@ def __init__(
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__()

if class_weights is not None and mode != MULTICLASS_MODE:
raise ValueError("class_weights are only supported with mode=multiclass")

self.mode = mode
self.ignore_index = ignore_index
self.reduction = reduction
Expand All @@ -54,6 +61,7 @@ def __init__(
reduction=reduction,
normalized=normalized,
)
self.class_weights = to_tensor(class_weights, dtype=torch.float) if class_weights is not None else None

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
Expand All @@ -70,12 +78,12 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

elif self.mode == MULTICLASS_MODE:
num_classes = y_pred.size(1)
loss = 0

# Filter anchors with -1 label from loss computation
if self.ignore_index is not None:
not_ignored = y_true != self.ignore_index

class_losses = []
for cls in range(num_classes):
cls_y_true = (y_true == cls).long()
cls_y_pred = y_pred[:, cls, ...]
Expand All @@ -84,6 +92,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
cls_y_true = cls_y_true[not_ignored]
cls_y_pred = cls_y_pred[not_ignored]

loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
class_losses.append(self.focal_loss_fn(cls_y_pred, cls_y_true))
class_losses = torch.stack(class_losses) # shape (C,)

if self.class_weights is not None:
weights = self.class_weights.to(class_losses.device)
loss = (class_losses * weights).sum() / weights.sum()
else:
loss = class_losses.sum()

return loss
15 changes: 15 additions & 0 deletions segmentation_models_pytorch/losses/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
class_weights: Optional[List[float]] = None,
):
"""Jaccard loss for image segmentation task.
It supports binary, multiclass and multilabel cases
Expand All @@ -31,6 +32,9 @@ def __init__(
smooth: Smoothness constant for dice coefficient
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
class_weights: Array of weights for each class. If not None, the loss for each class
is multiplied by the corresponding weight. Only supported for multiclass and
multilabel modes. Weights do not need to be normalized.

Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -43,6 +47,8 @@ def __init__(
super(JaccardLoss, self).__init__()

self.mode = mode
if class_weights is not None and mode == BINARY_MODE:
raise ValueError("class_weights are not supported with mode=binary")
if classes is not None:
assert mode != BINARY_MODE, (
"Masking classes is not supported with mode=binary"
Expand All @@ -55,6 +61,9 @@ def __init__(
self.ignore_index = ignore_index
self.eps = eps
self.log_loss = log_loss
self.class_weights = (
to_tensor(class_weights, dtype=torch.float) if class_weights is not None else None
)

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.size(0) == y_pred.size(0)
Expand Down Expand Up @@ -130,4 +139,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.classes is not None:
loss = loss[self.classes]

if self.class_weights is not None:
weights = self.class_weights.to(loss.device)
if self.classes is not None:
weights = weights[self.classes]
return (loss * weights).sum() / weights.sum()

return loss.mean()
10 changes: 8 additions & 2 deletions segmentation_models_pytorch/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ def __init__(
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
class_weights: Optional[List[float]] = None,
):
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__(
mode, classes, log_loss, from_logits, smooth, ignore_index, eps
mode, classes, log_loss, from_logits, smooth, ignore_index, eps, class_weights
)
self.alpha = alpha
self.beta = beta
self.gamma = gamma

def aggregate_loss(self, loss):
def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Aggregate per-class Tversky losses into a scalar raised to ``gamma``.

    Uses a weight-normalized mean when ``class_weights`` is set (slicing the
    weights to match any active ``classes`` filter), otherwise a plain mean.
    """
    if self.class_weights is None:
        return loss.mean() ** self.gamma
    weights = self.class_weights.to(loss.device)
    if self.classes is not None:
        weights = weights[self.classes]
    weighted_mean = (loss * weights).sum() / weights.sum()
    return weighted_mean ** self.gamma

def compute_score(
Expand Down
102 changes: 102 additions & 0 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SoftCrossEntropyLoss,
TverskyLoss,
MCCLoss,
FocalLoss,
)


Expand Down Expand Up @@ -332,3 +333,104 @@ def test_binary_mcc_loss():

loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.5, abs=eps)


@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multiclass():
    """Uniform class_weights should produce the same loss as no weights (multiclass)."""
    tolerance = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        unweighted = loss_cls(mode=smp.losses.MULTICLASS_MODE)
        uniform = loss_cls(mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0])
        baseline = unweighted(logits, target)
        result = uniform(logits, target)
        assert torch.allclose(baseline, result, atol=tolerance), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multilabel():
    """Uniform class_weights should produce the same loss as no weights (multilabel)."""
    tolerance = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    # Multilabel targets: one independent binary mask per class channel.
    target = torch.randint(0, 2, (2, 3, 4, 4)).float()

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        baseline = loss_cls(mode=smp.losses.MULTILABEL_MODE)(logits, target)
        result = loss_cls(mode=smp.losses.MULTILABEL_MODE, class_weights=[1.0, 1.0, 1.0])(logits, target)
        assert torch.allclose(baseline, result, atol=tolerance), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_nonuniform_changes_loss_multiclass():
    """Non-uniform class_weights should change the loss value (multiclass)."""
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        baseline = loss_cls(mode=smp.losses.MULTICLASS_MODE)(logits, target)
        reweighted = loss_cls(mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5])(logits, target)
        # The two values must actually differ, otherwise weighting is a no-op.
        assert not torch.allclose(baseline, reweighted, atol=1e-6), (
            f"Non-uniform weights should change the loss for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_scale_invariant_multiclass():
    """Scaling all weights by a constant should not change the loss (multiclass)."""
    tolerance = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        # Same relative weights, second set scaled by 10x — losses must match
        # because aggregation normalizes by the weight sum.
        base_weights = loss_cls(mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5])(logits, target)
        scaled_weights = loss_cls(mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0])(logits, target)
        assert torch.allclose(base_weights, scaled_weights, atol=tolerance), (
            f"Loss should be scale-invariant w.r.t. class_weights for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_binary_mode_raises():
    """class_weights should raise an error when used with binary mode."""
    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        with pytest.raises(ValueError):
            # Constructing with weights in binary mode is rejected up front.
            loss_cls(mode=smp.losses.BINARY_MODE, class_weights=[1.0, 2.0])


@torch.inference_mode()
def test_focal_class_weights_uniform_equivalent_to_no_weights():
    """Uniform class_weights should produce a loss proportional to no-weights loss.

    Unweighted multiclass FocalLoss *sums* the per-class losses, while the
    weighted path normalizes by the weight sum, so uniform weights yield the
    unweighted loss divided by the number of classes.
    """
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))
    num_classes = 3

    # FocalLoss is imported at module level; no function-local import needed.
    loss_no_w = FocalLoss(mode=smp.losses.MULTICLASS_MODE)(y_pred, y_true)
    loss_uniform = FocalLoss(mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0])(y_pred, y_true)
    assert torch.allclose(loss_no_w / num_classes, loss_uniform, atol=eps)



@torch.inference_mode()
def test_focal_class_weights_scale_invariant():
    """Scaling all weights by a constant should not change FocalLoss.

    The weighted aggregation divides by the weight sum, so only the relative
    magnitudes of ``class_weights`` matter.
    """
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))

    # FocalLoss is imported at module level; no function-local import needed.
    loss_w = FocalLoss(mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5])(y_pred, y_true)
    loss_w_scaled = FocalLoss(mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0])(y_pred, y_true)
    assert torch.allclose(loss_w, loss_w_scaled, atol=eps)
Loading