Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion segmentation_models_pytorch/losses/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
class_weights: Optional[List[float]] = None,
):
"""Dice loss for image segmentation task.
It supports binary, multiclass and multilabel cases
Expand All @@ -32,6 +33,9 @@ def __init__(
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
class_weights: List of weights for each class. If not ``None``, the loss for each class
is multiplied by the corresponding weight. Only supported for multiclass and
multilabel modes. Weights do not need to be normalized.

Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -43,6 +47,8 @@ def __init__(
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super(DiceLoss, self).__init__()
self.mode = mode
if class_weights is not None and mode == BINARY_MODE:
raise ValueError("class_weights are not supported with mode=binary")
if classes is not None:
assert mode != BINARY_MODE, (
"Masking classes is not supported with mode=binary"
Expand All @@ -55,6 +61,11 @@ def __init__(
self.eps = eps
self.log_loss = log_loss
self.ignore_index = ignore_index
self.class_weights = (
to_tensor(class_weights, dtype=torch.float)
if class_weights is not None
else None
)

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.size(0) == y_pred.size(0)
Expand Down Expand Up @@ -128,7 +139,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

return self.aggregate_loss(loss)

def aggregate_loss(self, loss):
def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Reduce per-class losses to a single scalar.

    Uses a normalized weighted average when ``class_weights`` was given,
    otherwise a plain mean.

    Args:
        loss: Per-class loss tensor of shape (C,)

    Returns:
        Scalar loss value
    """
    if self.class_weights is None:
        return loss.mean()
    weights = self.class_weights.to(loss.device)
    # The `classes` filter (if any) was already applied to `loss`,
    # so the weight vector must be sliced the same way.
    if self.classes is not None:
        weights = weights[self.classes]
    return (loss * weights).sum() / weights.sum()

def compute_score(
Expand Down
36 changes: 29 additions & 7 deletions segmentation_models_pytorch/losses/focal.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional
from typing import Optional, List
from functools import partial

import torch
from torch.nn.modules.loss import _Loss
from ._functional import focal_loss_with_logits
from ._functional import focal_loss_with_logits, to_tensor
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE

__all__ = ["FocalLoss"]
Expand All @@ -19,6 +19,7 @@ def __init__(
reduction: Optional[str] = "mean",
normalized: bool = False,
reduced_threshold: Optional[float] = None,
class_weights: Optional[List[float]] = None,
):
"""Compute Focal loss

Expand All @@ -31,6 +32,9 @@ def __init__(
normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
should use `reduction="sum"`.
class_weights: List of weights for each class. If not ``None``, the loss for each class
is multiplied by the corresponding weight. Only supported for multiclass and
multilabel modes. Weights do not need to be normalized.

Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -43,6 +47,9 @@ def __init__(
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__()

if class_weights is not None and mode == BINARY_MODE:
raise ValueError("class_weights are not supported with mode=binary")

self.mode = mode
self.ignore_index = ignore_index
self.reduction = reduction
Expand All @@ -54,9 +61,14 @@ def __init__(
reduction=reduction,
normalized=normalized,
)
self.class_weights = (
to_tensor(class_weights, dtype=torch.float)
if class_weights is not None
else None
)

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
if self.mode == BINARY_MODE:
y_true = y_true.reshape(-1)
y_pred = y_pred.reshape(-1)

Expand All @@ -68,22 +80,32 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:

loss = self.focal_loss_fn(y_pred, y_true)

elif self.mode == MULTICLASS_MODE:
elif self.mode in {MULTILABEL_MODE, MULTICLASS_MODE}:
num_classes = y_pred.size(1)
loss = 0

# Filter anchors with -1 label from loss computation
if self.ignore_index is not None:
not_ignored = y_true != self.ignore_index

class_losses = []
for cls in range(num_classes):
cls_y_true = (y_true == cls).long()
if self.mode == MULTICLASS_MODE:
cls_y_true = (y_true == cls).long()
else:
cls_y_true = y_true[:, cls, ...]
cls_y_pred = y_pred[:, cls, ...]

if self.ignore_index is not None:
cls_y_true = cls_y_true[not_ignored]
cls_y_pred = cls_y_pred[not_ignored]

loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
class_losses.append(self.focal_loss_fn(cls_y_pred, cls_y_true))
class_losses = torch.stack(class_losses) # shape (C,)

if self.class_weights is not None:
weights = self.class_weights.to(class_losses.device)
loss = (class_losses * weights).sum() / weights.sum()
else:
loss = class_losses.mean()

return loss
17 changes: 17 additions & 0 deletions segmentation_models_pytorch/losses/jaccard.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
smooth: float = 0.0,
ignore_index: Optional[int] = None,
eps: float = 1e-7,
class_weights: Optional[List[float]] = None,
):
"""Jaccard loss for image segmentation task.
It supports binary, multiclass and multilabel cases
Expand All @@ -31,6 +32,9 @@ def __init__(
smooth: Smoothness constant for dice coefficient
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
class_weights: List of weights for each class. If not ``None``, the loss for each class
is multiplied by the corresponding weight. Only supported for multiclass and
multilabel modes. Weights do not need to be normalized.

Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
Expand All @@ -43,6 +47,8 @@ def __init__(
super(JaccardLoss, self).__init__()

self.mode = mode
if class_weights is not None and mode == BINARY_MODE:
raise ValueError("class_weights are not supported with mode=binary")
if classes is not None:
assert mode != BINARY_MODE, (
"Masking classes is not supported with mode=binary"
Expand All @@ -55,6 +61,11 @@ def __init__(
self.ignore_index = ignore_index
self.eps = eps
self.log_loss = log_loss
self.class_weights = (
to_tensor(class_weights, dtype=torch.float)
if class_weights is not None
else None
)

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
assert y_true.size(0) == y_pred.size(0)
Expand Down Expand Up @@ -130,4 +141,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if self.classes is not None:
loss = loss[self.classes]

if self.class_weights is not None:
weights = self.class_weights.to(loss.device)
if self.classes is not None:
weights = weights[self.classes]
return (loss * weights).sum() / weights.sum()

return loss.mean()
32 changes: 28 additions & 4 deletions segmentation_models_pytorch/losses/tversky.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class TverskyLoss(DiceLoss):
Args:
mode: Metric mode {'binary', 'multiclass', 'multilabel'}
classes: Optional list of classes that contribute in loss computation;
By default, all channels are included.
By default, all channels are included.
log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
from_logits: If True assumes input is raw logits
smooth:
Expand All @@ -26,6 +26,9 @@ class TverskyLoss(DiceLoss):
alpha: Weight constant that penalize model for FPs (False Positives)
beta: Weight constant that penalize model for FNs (False Negatives)
gamma: Constant that squares the error function. Defaults to ``1.0``
class_weights: List of weights for each class. If not ``None``, the loss for each class
is multiplied by the corresponding weight. Only supported for multiclass and
multilabel modes. Weights do not need to be normalized.

Return:
loss: torch.Tensor
Expand All @@ -35,7 +38,7 @@ class TverskyLoss(DiceLoss):
def __init__(
self,
mode: str,
classes: List[int] = None,
classes: Optional[List[int]] = None,
log_loss: bool = False,
from_logits: bool = True,
smooth: float = 0.0,
Expand All @@ -44,16 +47,37 @@ def __init__(
alpha: float = 0.5,
beta: float = 0.5,
gamma: float = 1.0,
class_weights: Optional[List[float]] = None,
):
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
super().__init__(
mode, classes, log_loss, from_logits, smooth, ignore_index, eps
mode,
classes,
log_loss,
from_logits,
smooth,
ignore_index,
eps,
class_weights,
)
self.alpha = alpha
self.beta = beta
self.gamma = gamma

def aggregate_loss(self, loss):
def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Reduce per-class losses to one scalar and raise it to ``gamma``.

    Args:
        loss: Per-class loss tensor of shape (C,)

    Returns:
        Scalar loss value
    """
    if self.class_weights is None:
        aggregated = loss.mean()
    else:
        weights = self.class_weights.to(loss.device)
        # Keep the weight vector aligned with the (possibly filtered) class axis.
        if self.classes is not None:
            weights = weights[self.classes]
        aggregated = (loss * weights).sum() / weights.sum()
    return aggregated ** self.gamma

def compute_score(
Expand Down
114 changes: 114 additions & 0 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SoftCrossEntropyLoss,
TverskyLoss,
MCCLoss,
FocalLoss,
)


Expand Down Expand Up @@ -332,3 +333,116 @@ def test_binary_mcc_loss():

loss = criterion(y_pred, y_true)
assert float(loss) == pytest.approx(0.5, abs=eps)


@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multiclass():
    """Uniform class_weights should produce the same loss as no weights (multiclass)."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        baseline = loss_cls(mode=smp.losses.MULTICLASS_MODE)(logits, target)
        weighted = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0]
        )(logits, target)
        assert torch.allclose(baseline, weighted, atol=tol), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multilabel():
    """Uniform class_weights should produce the same loss as no weights (multilabel)."""
    tol = 1e-5
    torch.manual_seed(0)
    predictions = torch.randn(2, 3, 4, 4)
    labels = torch.randint(0, 2, (2, 3, 4, 4)).float()

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        unweighted = loss_cls(mode=smp.losses.MULTILABEL_MODE)(predictions, labels)
        uniform = loss_cls(
            mode=smp.losses.MULTILABEL_MODE, class_weights=[1.0, 1.0, 1.0]
        )(predictions, labels)
        assert torch.allclose(unweighted, uniform, atol=tol), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_nonuniform_changes_loss_multiclass():
    """Non-uniform class_weights should change the loss value (multiclass)."""
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        plain = loss_cls(mode=smp.losses.MULTICLASS_MODE)(logits, target)
        weighted = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
        )(logits, target)
        assert not torch.allclose(plain, weighted, atol=1e-6), (
            f"Non-uniform weights should change the loss for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_scale_invariant_multiclass():
    """Scaling all weights by a constant should not change the loss (multiclass)."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        base = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
        )(logits, target)
        scaled = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0]
        )(logits, target)
        assert torch.allclose(base, scaled, atol=tol), (
            f"Loss should be scale-invariant w.r.t. class_weights for {loss_cls.__name__}"
        )


@torch.inference_mode()
def test_class_weights_binary_mode_raises():
    """class_weights should raise an error when used with binary mode."""
    # FocalLoss validates class_weights against binary mode the same way the
    # Dice-family losses do, so it belongs in this coverage loop too.
    for loss_cls in [DiceLoss, JaccardLoss, TverskyLoss, FocalLoss]:
        with pytest.raises(ValueError):
            loss_cls(mode=smp.losses.BINARY_MODE, class_weights=[1.0, 2.0])


@torch.inference_mode()
def test_focal_class_weights_uniform_equivalent_to_no_weights():
    """Uniform class_weights should produce a loss equivalent to no-weights loss."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    unweighted = FocalLoss(mode=smp.losses.MULTICLASS_MODE)(logits, target)
    uniform = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0]
    )(logits, target)
    assert torch.allclose(unweighted, uniform, atol=tol)


@torch.inference_mode()
def test_focal_class_weights_scale_invariant():
    """Scaling all weights by a constant should not change FocalLoss."""
    tol = 1e-5
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)
    target = torch.randint(0, 3, (2, 4, 4))

    base = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
    )(logits, target)
    scaled = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0]
    )(logits, target)
    assert torch.allclose(base, scaled, atol=tol)