Add from_logits support to FocalLoss and eps for API consistency with DiceLoss #1268
Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
The current fix keeps the existing implementation unchanged, but it still reconstructs logits from probabilities, which is not strictly equivalent to the original logits formulation. The OpenMMLab implementation handles logits and activated probabilities separately and also provides an optimized CUDA implementation. I kept this PR minimal, but I can open a separate issue to discuss the differences between the two implementations if that would be helpful.
Please run
Thanks for keeping it minimal and easy to review. I would prefer to merge it as is; I don't want to add optimized CUDA kernels atm.
What's the issue here?
No issues, merging!
Summary
This PR adds `from_logits` and `eps` parameters to `FocalLoss` to align its API with `DiceLoss`.

Currently, `DiceLoss` supports both logits and probability inputs via the `from_logits` flag, while `FocalLoss` always assumes raw logits. This creates an inconsistency when using models configured with an activation function (e.g., `activation="softmax"`), requiring users to manually remove activations or wrap the loss.

This change introduces a `from_logits` flag to `FocalLoss` to support both logits and probability inputs in a consistent manner.

Motivation
Addresses issue #1263.
Example of current behavior:
Because `FocalLoss` always assumes logits, using it with probability outputs results in an incorrect loss computation. This PR resolves that inconsistency.
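To make the mismatch concrete, here is a minimal scalar sketch in plain Python. It is not the library's code: `binary_focal_loss_from_logit` is a hypothetical helper, and the alpha weighting is left out for brevity. It only illustrates what happens when a probability is passed where a logit is expected.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def binary_focal_loss_from_logit(logit: float, target: float, gamma: float = 2.0) -> float:
    """Hypothetical helper: focal loss for one binary prediction given as a raw logit."""
    p = sigmoid(logit)
    # Probability assigned to the true class
    p_t = p if target == 1.0 else 1.0 - p
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logit = 3.0
prob = sigmoid(logit)  # what a model with a sigmoid activation would output

# Intended usage: the loss receives the raw logit
correct = binary_focal_loss_from_logit(logit, target=1.0)

# Misuse: the loss receives the probability, so sigmoid is applied twice
wrong = binary_focal_loss_from_logit(prob, target=1.0)

print(f"loss from logit: {correct:.6f}")
print(f"loss from probability treated as logit: {wrong:.6f}")
```

In this example the second value is far larger than the first, because the double sigmoid squashes a confident prediction back toward 0.5.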
Implementation Details
Added a `from_logits: bool = True` parameter to `FocalLoss`.
When `from_logits=False`:
For binary/multilabel modes:
For multiclass mode:
No changes were made to the underlying focal formulation.
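As a rough illustration of what reconstructing logits from probabilities can look like, the sketch below uses the standard inverse-sigmoid for the binary/multilabel case and log-probabilities for the multiclass case. The function names, the `eps` default, and the clamping strategy are assumptions made for this example, not a copy of the PR's code.

```python
import math
from typing import List


def probs_to_logit(p: float, eps: float = 1e-7) -> float:
    """Binary/multilabel (assumed approach): invert the sigmoid, clamping p away from 0 and 1."""
    p = min(max(p, eps), 1.0 - eps)
    return math.log(p / (1.0 - p))


def probs_to_log_probs(probs: List[float], eps: float = 1e-7) -> List[float]:
    """Multiclass (assumed approach): log-probabilities behave like logits,
    because softmax(log p) = p when p sums to 1."""
    return [math.log(max(p, eps)) for p in probs]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Round trip for a moderate logit recovers the original value
original_logit = 2.0
p = 1.0 / (1.0 + math.exp(-original_logit))
recovered_logit = probs_to_logit(p)

# A saturated probability is clamped, so the recovered logit is finite
# but no longer equal to whatever large logit produced it
saturated_logit = probs_to_logit(1.0)

# Multiclass round trip: softmax of the log-probabilities returns the probabilities
class_probs = [0.7, 0.2, 0.1]
round_trip = softmax(probs_to_log_probs(class_probs))
```

The saturated case shows why a reconstruction of this kind is not strictly equivalent to working with the original logits: once a probability is clamped by `eps`, the information about how large the logit was is lost.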
Backward compatibility is preserved (`from_logits=True` by default).

Tests
Added tests to verify:
`FocalLoss` works correctly with `from_logits=False`.
All existing tests pass.
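A test of this property might look something like the sketch below. It uses a hypothetical scalar stand-in, `focal_loss`, for a single multiclass sample rather than the real `FocalLoss` class, and assumes the probability path converts to log-probabilities before the usual computation. The checks are that both input types give the same loss and that the default matches `from_logits=True`.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss(values: List[float], target: int, gamma: float = 2.0,
               from_logits: bool = True, eps: float = 1e-7) -> float:
    """Hypothetical stand-in: multiclass focal loss for one sample."""
    if not from_logits:
        # Assumed approach: turn probabilities into log-probabilities,
        # then reuse the logits code path unchanged
        values = [math.log(max(p, eps)) for p in values]
    p_t = softmax(values)[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def test_probability_path_matches_logit_path() -> None:
    logits = [2.0, -1.0, 0.5]
    probs = softmax(logits)
    for target in range(len(logits)):
        from_logit = focal_loss(logits, target)
        from_prob = focal_loss(probs, target, from_logits=False)
        assert math.isclose(from_logit, from_prob, rel_tol=1e-9)


def test_default_is_from_logits() -> None:
    logits = [2.0, -1.0, 0.5]
    assert focal_loss(logits, 0) == focal_loss(logits, 0, from_logits=True)


test_probability_path_matches_logit_path()
test_default_is_from_logits()
```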
Backward Compatibility
Default behavior is unchanged; the new code path applies only when `from_logits=False`.