Skip to content

Commit f72d8c2

Browse files
authored
feat(losses): add class_weights support to Dice, Jaccard, Tversky and Focal losses (#1290)
* feat(losses): add class_weights support to Dice, Jaccard, Tversky and Focal losses * fix(losses): apply docstring suggestions from code review --------- Co-authored-by: Raphael Lapertot <raphael.lapertot@gmail.com>
1 parent 2a13579 commit f72d8c2

5 files changed

Lines changed: 214 additions & 12 deletions

File tree

segmentation_models_pytorch/losses/dice.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
smooth: float = 0.0,
2020
ignore_index: Optional[int] = None,
2121
eps: float = 1e-7,
22+
class_weights: Optional[List[float]] = None,
2223
):
2324
"""Dice loss for image segmentation task.
2425
It supports binary, multiclass and multilabel cases
@@ -32,6 +33,9 @@ def __init__(
3233
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
3334
eps: A small epsilon for numerical stability to avoid zero division error
3435
(denominator will be always greater or equal to eps)
36+
class_weights: List of weights for each class. If not ``None``, the loss for each class
37+
is multiplied by the corresponding weight. Only supported for multiclass and
38+
multilabel modes. Weights do not need to be normalized.
3539
3640
Shape
3741
- **y_pred** - torch.Tensor of shape (N, C, H, W)
@@ -43,6 +47,8 @@ def __init__(
4347
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
4448
super(DiceLoss, self).__init__()
4549
self.mode = mode
50+
if class_weights is not None and mode == BINARY_MODE:
51+
raise ValueError("class_weights are not supported with mode=binary")
4652
if classes is not None:
4753
assert mode != BINARY_MODE, (
4854
"Masking classes is not supported with mode=binary"
@@ -55,6 +61,11 @@ def __init__(
5561
self.eps = eps
5662
self.log_loss = log_loss
5763
self.ignore_index = ignore_index
64+
self.class_weights = (
65+
to_tensor(class_weights, dtype=torch.float)
66+
if class_weights is not None
67+
else None
68+
)
5869

5970
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6071
assert y_true.size(0) == y_pred.size(0)
@@ -128,7 +139,21 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
128139

129140
return self.aggregate_loss(loss)
130141

131-
def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Reduce the per-class loss vector to a single scalar.

    Args:
        loss: Per-class loss tensor of shape (C,)

    Returns:
        Scalar loss value
    """
    if self.class_weights is None:
        return loss.mean()
    class_weights = self.class_weights.to(loss.device)
    # A classes filter narrows the loss vector; narrow the weights to match.
    if self.classes is not None:
        class_weights = class_weights[self.classes]
    weighted_total = (loss * class_weights).sum()
    return weighted_total / class_weights.sum()
133158

134159
def compute_score(

segmentation_models_pytorch/losses/focal.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from typing import Optional
1+
from typing import Optional, List
22
from functools import partial
33

44
import torch
55
from torch.nn.modules.loss import _Loss
6-
from ._functional import focal_loss_with_logits
6+
from ._functional import focal_loss_with_logits, to_tensor
77
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
88

99
__all__ = ["FocalLoss"]
@@ -19,6 +19,7 @@ def __init__(
1919
reduction: Optional[str] = "mean",
2020
normalized: bool = False,
2121
reduced_threshold: Optional[float] = None,
22+
class_weights: Optional[List[float]] = None,
2223
):
2324
"""Compute Focal loss
2425
@@ -31,6 +32,9 @@ def __init__(
3132
normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
3233
reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
3334
should use `reduction="sum"`.
35+
class_weights: List of weights for each class. If not ``None``, the loss for each class
36+
is multiplied by the corresponding weight. Only supported for multiclass and
37+
multilabel modes. Weights do not need to be normalized.
3438
3539
Shape
3640
- **y_pred** - torch.Tensor of shape (N, C, H, W)
@@ -43,6 +47,9 @@ def __init__(
4347
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
4448
super().__init__()
4549

50+
if class_weights is not None and mode == BINARY_MODE:
51+
raise ValueError("class_weights are not supported with mode=binary")
52+
4653
self.mode = mode
4754
self.ignore_index = ignore_index
4855
self.reduction = reduction
@@ -54,9 +61,14 @@ def __init__(
5461
reduction=reduction,
5562
normalized=normalized,
5663
)
64+
self.class_weights = (
65+
to_tensor(class_weights, dtype=torch.float)
66+
if class_weights is not None
67+
else None
68+
)
5769

5870
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
59-
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
71+
if self.mode == BINARY_MODE:
6072
y_true = y_true.reshape(-1)
6173
y_pred = y_pred.reshape(-1)
6274

@@ -68,22 +80,32 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6880

6981
loss = self.focal_loss_fn(y_pred, y_true)
7082

71-
elif self.mode == MULTICLASS_MODE:
83+
elif self.mode in {MULTILABEL_MODE, MULTICLASS_MODE}:
7284
num_classes = y_pred.size(1)
73-
loss = 0
7485

7586
# Filter anchors with -1 label from loss computation
7687
if self.ignore_index is not None:
7788
not_ignored = y_true != self.ignore_index
7889

90+
class_losses = []
7991
for cls in range(num_classes):
80-
cls_y_true = (y_true == cls).long()
92+
if self.mode == MULTICLASS_MODE:
93+
cls_y_true = (y_true == cls).long()
94+
else:
95+
cls_y_true = y_true[:, cls, ...]
8196
cls_y_pred = y_pred[:, cls, ...]
8297

8398
if self.ignore_index is not None:
8499
cls_y_true = cls_y_true[not_ignored]
85100
cls_y_pred = cls_y_pred[not_ignored]
86101

87-
loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
102+
class_losses.append(self.focal_loss_fn(cls_y_pred, cls_y_true))
103+
class_losses = torch.stack(class_losses) # shape (C,)
104+
105+
if self.class_weights is not None:
106+
weights = self.class_weights.to(class_losses.device)
107+
loss = (class_losses * weights).sum() / weights.sum()
108+
else:
109+
loss = class_losses.mean()
88110

89111
return loss

segmentation_models_pytorch/losses/jaccard.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
smooth: float = 0.0,
2020
ignore_index: Optional[int] = None,
2121
eps: float = 1e-7,
22+
class_weights: Optional[List[float]] = None,
2223
):
2324
"""Jaccard loss for image segmentation task.
2425
It supports binary, multiclass and multilabel cases
@@ -31,6 +32,9 @@ def __init__(
3132
smooth: Smoothness constant for dice coefficient
3233
eps: A small epsilon for numerical stability to avoid zero division error
3334
(denominator will be always greater or equal to eps)
35+
class_weights: List of weights for each class. If not ``None``, the loss for each class
36+
is multiplied by the corresponding weight. Only supported for multiclass and
37+
multilabel modes. Weights do not need to be normalized.
3438
3539
Shape
3640
- **y_pred** - torch.Tensor of shape (N, C, H, W)
@@ -43,6 +47,8 @@ def __init__(
4347
super(JaccardLoss, self).__init__()
4448

4549
self.mode = mode
50+
if class_weights is not None and mode == BINARY_MODE:
51+
raise ValueError("class_weights are not supported with mode=binary")
4652
if classes is not None:
4753
assert mode != BINARY_MODE, (
4854
"Masking classes is not supported with mode=binary"
@@ -55,6 +61,11 @@ def __init__(
5561
self.ignore_index = ignore_index
5662
self.eps = eps
5763
self.log_loss = log_loss
64+
self.class_weights = (
65+
to_tensor(class_weights, dtype=torch.float)
66+
if class_weights is not None
67+
else None
68+
)
5869

5970
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
6071
assert y_true.size(0) == y_pred.size(0)
@@ -130,4 +141,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
130141
if self.classes is not None:
131142
loss = loss[self.classes]
132143

144+
if self.class_weights is not None:
145+
weights = self.class_weights.to(loss.device)
146+
if self.classes is not None:
147+
weights = weights[self.classes]
148+
return (loss * weights).sum() / weights.sum()
149+
133150
return loss.mean()

segmentation_models_pytorch/losses/tversky.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TverskyLoss(DiceLoss):
1717
Args:
1818
mode: Metric mode {'binary', 'multiclass', 'multilabel'}
1919
classes: Optional list of classes that contribute in loss computation;
20-
By default, all channels are included.
20+
By default, all channels are included.
2121
log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
2222
from_logits: If True assumes input is raw logits
2323
smooth:
@@ -26,6 +26,9 @@ class TverskyLoss(DiceLoss):
2626
alpha: Weight constant that penalize model for FPs (False Positives)
2727
beta: Weight constant that penalize model for FNs (False Negatives)
2828
gamma: Constant that squares the error function. Defaults to ``1.0``
29+
class_weights: List of weights for each class. If not ``None``, the loss for each class
30+
is multiplied by the corresponding weight. Only supported for multiclass and
31+
multilabel modes. Weights do not need to be normalized.
2932
3033
Return:
3134
loss: torch.Tensor
@@ -35,7 +38,7 @@ class TverskyLoss(DiceLoss):
3538
def __init__(
3639
self,
3740
mode: str,
38-
classes: List[int] = None,
41+
classes: Optional[List[int]] = None,
3942
log_loss: bool = False,
4043
from_logits: bool = True,
4144
smooth: float = 0.0,
@@ -44,16 +47,37 @@ def __init__(
4447
alpha: float = 0.5,
4548
beta: float = 0.5,
4649
gamma: float = 1.0,
50+
class_weights: Optional[List[float]] = None,
4751
):
4852
assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
4953
super().__init__(
50-
mode, classes, log_loss, from_logits, smooth, ignore_index, eps
54+
mode,
55+
classes,
56+
log_loss,
57+
from_logits,
58+
smooth,
59+
ignore_index,
60+
eps,
61+
class_weights,
5162
)
5263
self.alpha = alpha
5364
self.beta = beta
5465
self.gamma = gamma
5566

56-
def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
    """Reduce the per-class loss vector to a scalar raised to the power gamma.

    Args:
        loss: Per-class loss tensor of shape (C,)

    Returns:
        Scalar loss value
    """
    if self.class_weights is None:
        return loss.mean() ** self.gamma
    class_weights = self.class_weights.to(loss.device)
    # A classes filter narrows the loss vector; narrow the weights to match.
    if self.classes is not None:
        class_weights = class_weights[self.classes]
    weighted_mean = (loss * class_weights).sum() / class_weights.sum()
    return weighted_mean ** self.gamma
5882

5983
def compute_score(

tests/test_losses.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
SoftCrossEntropyLoss,
1010
TverskyLoss,
1111
MCCLoss,
12+
FocalLoss,
1213
)
1314

1415

@@ -332,3 +333,116 @@ def test_binary_mcc_loss():
332333

333334
loss = criterion(y_pred, y_true)
334335
assert float(loss) == pytest.approx(0.5, abs=eps)
336+
337+
338+
@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multiclass():
    """Uniform class_weights should produce the same loss as no weights (multiclass)."""
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        plain_criterion = loss_cls(mode=smp.losses.MULTICLASS_MODE)
        uniform_criterion = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0]
        )
        loss_no_w = plain_criterion(y_pred, y_true)
        loss_uniform = uniform_criterion(y_pred, y_true)
        assert torch.allclose(loss_no_w, loss_uniform, atol=eps), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )
354+
355+
356+
@torch.inference_mode()
def test_class_weights_uniform_equivalent_to_no_weights_multilabel():
    """Uniform class_weights should produce the same loss as no weights (multilabel)."""
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 2, (2, 3, 4, 4)).float()

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        plain_criterion = loss_cls(mode=smp.losses.MULTILABEL_MODE)
        uniform_criterion = loss_cls(
            mode=smp.losses.MULTILABEL_MODE, class_weights=[1.0, 1.0, 1.0]
        )
        loss_no_w = plain_criterion(y_pred, y_true)
        loss_uniform = uniform_criterion(y_pred, y_true)
        assert torch.allclose(loss_no_w, loss_uniform, atol=eps), (
            f"Uniform weights should be equivalent to no weights for {loss_cls.__name__}"
        )
372+
373+
374+
@torch.inference_mode()
def test_class_weights_nonuniform_changes_loss_multiclass():
    """Non-uniform class_weights should change the loss value (multiclass)."""
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        plain_criterion = loss_cls(mode=smp.losses.MULTICLASS_MODE)
        weighted_criterion = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
        )
        loss_no_w = plain_criterion(y_pred, y_true)
        loss_weighted = weighted_criterion(y_pred, y_true)
        assert not torch.allclose(loss_no_w, loss_weighted, atol=1e-6), (
            f"Non-uniform weights should change the loss for {loss_cls.__name__}"
        )
389+
390+
391+
@torch.inference_mode()
def test_class_weights_scale_invariant_multiclass():
    """Scaling all weights by a constant should not change the loss (multiclass)."""
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))

    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        base_criterion = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
        )
        scaled_criterion = loss_cls(
            mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0]
        )
        loss_w = base_criterion(y_pred, y_true)
        loss_w_scaled = scaled_criterion(y_pred, y_true)
        assert torch.allclose(loss_w, loss_w_scaled, atol=eps), (
            f"Loss should be scale-invariant w.r.t. class_weights for {loss_cls.__name__}"
        )
409+
410+
411+
@torch.inference_mode()
def test_class_weights_binary_mode_raises():
    """class_weights should raise an error when used with binary mode."""
    for loss_cls in (DiceLoss, JaccardLoss, TverskyLoss):
        # Binary mode has a single implicit class, so per-class weights are rejected.
        with pytest.raises(ValueError):
            loss_cls(mode=smp.losses.BINARY_MODE, class_weights=[1.0, 2.0])
417+
418+
419+
@torch.inference_mode()
def test_focal_class_weights_uniform_equivalent_to_no_weights():
    """Uniform class_weights should produce a loss equivalent to no-weights loss."""
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))

    plain_criterion = FocalLoss(mode=smp.losses.MULTICLASS_MODE)
    uniform_criterion = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 1.0, 1.0]
    )
    loss_no_w = plain_criterion(y_pred, y_true)
    loss_uniform = uniform_criterion(y_pred, y_true)
    assert torch.allclose(loss_no_w, loss_uniform, atol=eps)
432+
433+
434+
@torch.inference_mode()
def test_focal_class_weights_scale_invariant():
    """Scaling all weights by a constant should not change FocalLoss."""
    eps = 1e-5
    torch.manual_seed(0)
    y_pred = torch.randn(2, 3, 4, 4)
    y_true = torch.randint(0, 3, (2, 4, 4))

    base_criterion = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[1.0, 2.0, 0.5]
    )
    scaled_criterion = FocalLoss(
        mode=smp.losses.MULTICLASS_MODE, class_weights=[10.0, 20.0, 5.0]
    )
    loss_w = base_criterion(y_pred, y_true)
    loss_w_scaled = scaled_criterion(y_pred, y_true)
    assert torch.allclose(loss_w, loss_w_scaled, atol=eps)

0 commit comments

Comments
 (0)