Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions segmentation_models_pytorch/metrics/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,27 @@ def get_stats(
)

if output.shape != target.shape:
# Check if user accidentally passed a one-hot / logits tensor in multiclass mode
if mode == "multiclass" and output.ndim == target.ndim + 1:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not super robust, what if both are passed as (N, C, H, W), but better than nothing 👍

raise ValueError(
f"In 'multiclass' mode, ``output`` should contain class indices of shape "
f"(N, ...), but got shape {output.shape}. "
f"It looks like you passed a one-hot or logits tensor. "
f"Please convert it first with ``output.argmax(dim=1)``."
)
if mode == "multiclass" and target.ndim == output.ndim + 1:
raise ValueError(
f"In 'multiclass' mode, ``target`` should contain class indices of shape "
f"(N, ...), but got shape {target.shape}. "
f"It looks like you passed a one-hot tensor. "
f"Please convert it first with ``target.argmax(dim=1)``."
)
raise ValueError(
"Dimensions should match, but ``output`` shape is not equal to ``target`` "
+ f"shape, {output.shape} != {target.shape}"
)


if mode != "multiclass" and ignore_index is not None:
raise ValueError(
f"``ignore_index`` parameter is not supported for '{mode}' mode"
Expand All @@ -163,6 +179,21 @@ def get_stats(
)

if mode == "multiclass":
if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3:
if output.ndim >= 3 and output.shape[1] == num_classes:

raise ValueError(
f"In 'multiclass' mode, ``output`` should contain class indices of shape "
f"(N, H, W) or (N,), but got shape {tuple(output.shape)}. "
f"It looks like you passed a one-hot or logits tensor of shape (N, C, ...). "
f"For that use case, please use mode='multilabel' instead, "
f"or convert your tensor with ``output.argmax(dim=1)`` first."
)
if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3:
Copy link
Copy Markdown
Collaborator

@qubvel qubvel Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3:
if target.ndim >= 3 and target.shape[1] == num_classes:

raise ValueError(
f"In 'multiclass' mode, ``target`` should contain class indices of shape "
f"(N, H, W) or (N,), but got shape {tuple(target.shape)}. "
f"It looks like you passed a one-hot encoded tensor of shape (N, C, ...). "
f"Convert it with ``target.argmax(dim=1)`` first."
)
tp, fp, fn, tn = _get_stats_multiclass(
output, target, num_classes, ignore_index
)
Expand Down
54 changes: 54 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch
import segmentation_models_pytorch as smp


class TestGetStatsMulticlass:
"""Tests for get_stats in multiclass mode."""

def test_correct_input(self):
"""Class index tensors of shape (N, ...) should work correctly."""
tp, fp, fn, tn = smp.metrics.get_stats(
output=torch.tensor([[0, 1, 2, 1]]),
target=torch.tensor([[0, 1, 2, 2]]),
mode="multiclass",
num_classes=3,
)
assert tp.shape == (1, 3)
assert fp.tolist() == [[0, 1, 0]]
assert fn.tolist() == [[0, 0, 1]]

def test_onehot_output_raises(self):
"""Passing a one-hot encoded output (N, C, ...) should raise ValueError with hint."""
with pytest.raises(ValueError, match="output.argmax"):
smp.metrics.get_stats(
output=torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]),
target=torch.tensor([[0, 1, 2]]),
mode="multiclass",
num_classes=3,
)

def test_onehot_target_raises(self):
"""Passing a one-hot encoded target (N, C, ...) should raise ValueError with hint."""
with pytest.raises(ValueError, match="target.argmax"):
smp.metrics.get_stats(
output=torch.tensor([[0, 1, 2]]),
target=torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]]),
mode="multiclass",
num_classes=3,
)


def test_argmax_fix_gives_perfect_iou(self):
"""Correcting a one-hot tensor with argmax(dim=1) should yield IoU=1.0."""
output_onehot = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]])
target_onehot = torch.tensor([[[1, 0, 0], [0, 1, 0], [0, 0, 1]]])

tp, fp, fn, tn = smp.metrics.get_stats(
output=output_onehot.argmax(dim=1),
target=target_onehot.argmax(dim=1),
mode="multiclass",
num_classes=3,
)
iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")
assert iou.item() == pytest.approx(1.0)