Skip to content

feat(losses): add class_weights support to Dice, Jaccard, Tversky and Focal losses#1290

Merged
qubvel merged 2 commits intoqubvel-org:mainfrom
lapertor:feat/class-weights-loss
Mar 19, 2026
Merged

feat(losses): add class_weights support to Dice, Jaccard, Tversky and Focal losses#1290
qubvel merged 2 commits intoqubvel-org:mainfrom
lapertor:feat/class-weights-loss

Conversation

@lapertor
Copy link
Copy Markdown
Contributor

Closes #881

Description

This PR adds a class_weights parameter to DiceLoss, JaccardLoss, TverskyLoss, and FocalLoss,
allowing users to weight the contribution of each class to the final loss.
This is useful for dealing with class imbalance.

Changes

  • DiceLoss, JaccardLoss, TverskyLoss: class_weights supported in multiclass and multilabel modes. Raises ValueError for binary mode.
  • FocalLoss: class_weights supported in multiclass mode only (the per-class loop already exists). Raises ValueError for other modes.
  • Weights do not need to be normalized; the implementation divides by weights.sum().
  • New tests covering: uniform weights equivalence, non-uniform weights effect, scale invariance, and invalid mode error.

Example

loss = DiceLoss(mode="multiclass", class_weights=[1.0, 2.0, 0.5])

Tests

All 29 tests pass:

  • 5 new tests for DiceLoss/JaccardLoss/TverskyLoss class_weights
  • 2 new tests for FocalLoss class_weights

Copy link
Copy Markdown
Collaborator

@qubvel qubvel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks for adding tests

ignore_index: Label that indicates ignored pixels (does not contribute to loss)
eps: A small epsilon for numerical stability to avoid zero division error
(denominator will be always greater or equal to eps)
class_weights: Array of weights for each class. If not None, the loss for each class
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class_weights: Array of weights for each class. If not None, the loss for each class
class_weights: List of weights for each class. If not ``None``, the loss for each class

@lapertor
Copy link
Copy Markdown
Contributor Author

Thanks for the review! I've applied your docstring suggestion across all four loss files and also fixed the failing style check.

@qubvel qubvel merged commit f72d8c2 into qubvel-org:main Mar 19, 2026
17 checks passed
@lapertor lapertor deleted the feat/class-weights-loss branch March 23, 2026 13:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants