Skip to content

Commit 4bf6ec0

Browse files
Harsh-2005d and qubvel authored
Add from_logits support to FocalLoss and eps for API consistency with DiceLoss (#1268)
* add
* Update segmentation_models_pytorch/losses/focal.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* Update segmentation_models_pytorch/losses/focal.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* Update segmentation_models_pytorch/losses/focal.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* Update segmentation_models_pytorch/losses/focal.py (Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>)
* fixed formatting issue
---------
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
1 parent 506a444 commit 4bf6ec0

2 files changed

Lines changed: 45 additions & 0 deletions

File tree

segmentation_models_pytorch/losses/focal.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ def __init__(
1616
alpha: Optional[float] = None,
1717
gamma: Optional[float] = 2.0,
1818
ignore_index: Optional[int] = None,
19+
from_logits: bool = True,
20+
eps: float = 1e-7,
1921
reduction: Optional[str] = "mean",
2022
normalized: bool = False,
2123
reduced_threshold: Optional[float] = None,
@@ -25,6 +27,8 @@ def __init__(
2527
2628
Args:
2729
mode: Loss mode 'binary', 'multiclass' or 'multilabel'
30+
from_logits: If True, assumes input is raw logits
31+
eps: Small value used for numerical stability when converting probabilities to logits.
2832
alpha: Prior probability of having positive value in target.
2933
gamma: Power factor for dampening weight (focal strength).
3034
ignore_index: If not None, targets may contain values to be ignored.
@@ -51,8 +55,11 @@ def __init__(
5155
raise ValueError("class_weights are not supported with mode=binary")
5256

5357
self.mode = mode
58+
self.from_logits = from_logits
5459
self.ignore_index = ignore_index
5560
self.reduction = reduction
61+
self.eps = eps
62+
5663
self.focal_loss_fn = partial(
5764
focal_loss_with_logits,
5865
alpha=alpha,
@@ -68,6 +75,18 @@ def __init__(
6875
)
6976

7077
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
78+
79+
if not self.from_logits:
80+
y_pred = torch.clamp(y_pred, self.eps, 1 - self.eps)
81+
82+
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
83+
# inverse sigmoid
84+
y_pred = torch.log(y_pred / (1 - y_pred))
85+
86+
elif self.mode == MULTICLASS_MODE:
87+
# convert softmax probabilities to log-space
88+
y_pred = torch.log(y_pred)
89+
7190
if self.mode == BINARY_MODE:
7291
y_true = y_true.reshape(-1)
7392
y_pred = y_pred.reshape(-1)

tests/test_losses.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,32 @@
1313
)
1414

1515

16+
def test_focal_loss_from_logits_false_multiclass():
    """FocalLoss fed softmax probabilities (from_logits=False) should roughly
    match FocalLoss fed the raw logits (from_logits=True) in multiclass mode."""
    torch.manual_seed(0)

    # Each row strongly favors its target class, so both paths are well-behaved.
    raw_scores = torch.tensor(
        [[0.0, 10.0, 0.0], [10.0, 0.0, 0.0], [0.0, 0.0, 10.0]]
    ).float()
    labels = torch.tensor([1, 0, 2]).long()
    # Convert to probabilities
    probabilities = torch.softmax(raw_scores, dim=1)

    loss_from_logits = smp.losses.FocalLoss(
        mode="multiclass",
        from_logits=True,
    )(raw_scores, labels)

    loss_from_probs = smp.losses.FocalLoss(
        mode="multiclass",
        from_logits=False,
    )(probabilities, labels)

    # They should be close (not exact due to constant shift issue)
    assert torch.isfinite(loss_from_probs)
    assert torch.isfinite(loss_from_logits)
    assert abs(loss_from_logits - loss_from_probs) < 0.2
1642
def test_focal_loss_with_logits():
1743
input_good = torch.tensor([10, -10, 10]).float()
1844
input_bad = torch.tensor([-1, 2, 0]).float()

0 commit comments

Comments
 (0)