
fix: validate tensor shapes in get_stats for multiclass mode #1284

Open
lapertor wants to merge 3 commits into qubvel-org:main from lapertor:fix/get-stats-multiclass-shape-validation

Conversation

@lapertor
Contributor

Summary

Fixes a silent bug in get_stats when using mode='multiclass'.

Problem

When calling get_stats with mode='multiclass', the function expects:

  • output of shape (N, H, W) containing predicted class indices (logits or probabilities of shape (N, C, H, W) must first be reduced with argmax(dim=1))
  • target of shape (N, H, W) containing class indices

However, no shape validation was performed. Passing a 4D target tensor of shape (N, C, H, W) (as one would use in multilabel mode) would silently produce incorrect results instead of raising a clear, informative error.

Reproduction

Before the fix (silent wrong results):

import torch
import segmentation_models_pytorch as smp

# Output: class indices (N, H, W) (correct for multiclass)
output = torch.randint(0, 3, [10, 256, 256])
# Target: wrong shape (N, C, H, W) (one-hot tensor passed by mistake)
target = torch.randint(0, 3, [10, 3, 256, 256])

# Used to silently produce incorrect tp/fp/fn/tn values
tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multiclass', num_classes=3)

After the fix (clear, actionable error):

ValueError: In 'multiclass' mode, ``target`` should contain class indices of shape (N, ...),
but got shape torch.Size([10, 3, 256, 256]). It looks like you passed a one-hot tensor.
Please convert it first with ``target.argmax(dim=1)``.
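
To make the suggested conversion concrete, here is a small sketch of what taking the argmax over the class axis does to a one-hot target. It uses plain nested lists for a single sample so that it runs without PyTorch; the helper name one_hot_to_indices is hypothetical and not part of the library. With real tensors, the single call target.argmax(dim=1) from the error message plays this role on a batched (N, C, H, W) tensor.

```python
def one_hot_to_indices(one_hot):
    """Collapse a (C, H, W) one-hot nested list into an (H, W) map of class indices.

    Hypothetical helper for illustration only; it mirrors an argmax over the
    class axis for a single sample.
    """
    num_classes = len(one_hot)
    height = len(one_hot[0])
    width = len(one_hot[0][0])
    return [
        [
            max(range(num_classes), key=lambda c: one_hot[c][y][x])
            for x in range(width)
        ]
        for y in range(height)
    ]


# One sample, 3 classes, 2x2 pixels: channel c is 1 where the pixel belongs to class c.
sample = [
    [[1, 0], [0, 0]],  # class 0
    [[0, 1], [0, 0]],  # class 1
    [[0, 0], [1, 1]],  # class 2
]
indices = one_hot_to_indices(sample)
```

The result has one fewer dimension than the input, which is the (N, ...) layout of class indices that multiclass mode expects for target.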

Fix

Added explicit shape validation in get_stats for multiclass mode that raises a descriptive ValueError when the target shape is incorrect, guiding the user toward the correct usage.
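
As a rough illustration of this kind of validation, the sketch below checks a plain shape tuple and raises a ValueError modeled on the message quoted above. The function name check_multiclass_target_shape and the use of a tuple instead of a tensor are assumptions made so the example runs on its own; it is not the code added by this PR.

```python
def check_multiclass_target_shape(target_shape, num_classes):
    """Raise ValueError if a multiclass target looks like a one-hot tensor.

    Hypothetical helper: it works on a plain shape tuple and is not the
    code from this PR.
    """
    # Heuristic: at least 3 dims and a second axis whose size equals num_classes.
    looks_one_hot = len(target_shape) >= 3 and target_shape[1] == num_classes
    if looks_one_hot:
        raise ValueError(
            "In 'multiclass' mode, ``target`` should contain class indices "
            f"of shape (N, ...), but got shape {tuple(target_shape)}. "
            "It looks like you passed a one-hot tensor. "
            "Please convert it first with ``target.argmax(dim=1)``."
        )


# A target of class indices with shape (N, H, W) passes without raising.
check_multiclass_target_shape((10, 256, 256), num_classes=3)

# A one-hot style target with shape (N, C, H, W) is rejected.
try:
    check_multiclass_target_shape((10, 3, 256, 256), num_classes=3)
except ValueError as err:
    message = str(err)
```

One limitation of a size-based heuristic like this is that a genuine (N, H, W) target whose height happens to equal num_classes would also be flagged.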

Related Issues

Closes #863

When using mode='multiclass', get_stats expected predictions of shape (N, H, W) and targets of shape (N, H, W). However, no validation was performed, so passing a 4D tensor (e.g. (N, C, H, W)) would silently produce incorrect results instead of raising a clear error. This commit adds explicit shape validation with an informative error message.
Collaborator

@qubvel qubvel left a comment


Thanks! Please check some minor comments.

        )

    if mode == "multiclass":
        if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3:
Collaborator

Suggested change
- if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3:
+ if output.ndim >= 3 and output.shape[1] == num_classes:
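
The suggested change above should be a pure simplification, since ndim >= 3 already implies ndim > 1. The sketch below checks that claim by brute force over small shapes. The function names and the use of shape tuples in place of tensors are assumptions made for illustration.

```python
from itertools import product


def original_check(shape, num_classes):
    # Condition as first written in the PR.
    return len(shape) > 1 and shape[1] == num_classes and len(shape) >= 3


def simplified_check(shape, num_classes):
    # Condition after the suggested change. The length test runs first, so
    # shape[1] is never read on a shape with fewer than 3 dimensions.
    return len(shape) >= 3 and shape[1] == num_classes


num_classes = 3
# Every shape with 0 to 4 dimensions whose sizes are drawn from {1, 3, 4}.
shapes = [
    dims
    for ndim in range(0, 5)
    for dims in product([1, 3, 4], repeat=ndim)
]
mismatches = [
    s for s in shapes
    if original_check(s, num_classes) != simplified_check(s, num_classes)
]
```

An empty mismatches list means the two conditions agree on every shape tried, which is consistent with the shorter form being a drop-in replacement.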

                f"For that use case, please use mode='multilabel' instead, "
                f"or convert your tensor with ``output.argmax(dim=1)`` first."
            )
        if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3:
Collaborator

@qubvel qubvel Mar 12, 2026


Suggested change
- if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3:
+ if target.ndim >= 3 and target.shape[1] == num_classes:


    if output.shape != target.shape:
        # Check if user accidentally passed a one-hot / logits tensor in multiclass mode
        if mode == "multiclass" and output.ndim == target.ndim + 1:
Collaborator

Not super robust: what if both are passed as (N, C, H, W)? But better than nothing 👍
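
The gap raised in this comment can be sketched with shape tuples. The function below is a hypothetical stand-in for the reviewed branch, not the PR's code: it only fires when the shapes differ and output has exactly one more dimension than target.

```python
def caught_by_ndim_heuristic(output_shape, target_shape):
    # Mirrors the reviewed branch: it is only entered when the shapes differ,
    # and only fires when output has exactly one more dimension than target.
    if output_shape != target_shape:
        return len(output_shape) == len(target_shape) + 1
    return False


# Logits-shaped output with an index-shaped target: detected.
detected = caught_by_ndim_heuristic((10, 3, 256, 256), (10, 256, 256))

# Both tensors passed as (N, C, H, W): the shapes match, so the branch is skipped.
missed = caught_by_ndim_heuristic((10, 3, 256, 256), (10, 3, 256, 256))
```

In the second call both inputs are wrong for multiclass mode, yet nothing inside a shape-mismatch branch can notice, which is the case the comment points at.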

@lapertor
Contributor Author

lapertor commented Mar 13, 2026

I've applied both simplifications as suggested.

Regarding the case where both output and target have shape (N, C, H, W): I've addressed it in the new commit by adding an unconditional ndim == 4 check inside the if mode == "multiclass" block. That block is always reached, regardless of whether output.shape == target.shape. I believe this covers that edge case; it did in my own tests.
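
A minimal sketch of such an unconditional check, assuming 2-D images so that valid multiclass inputs are (N, H, W). The function names, the error wording, and the use of shape tuples are all hypothetical; the actual commit may differ.

```python
def validate_multiclass_ndim(output_shape, target_shape):
    """Reject 4-D inputs in multiclass mode, whether or not the shapes match.

    Hypothetical sketch, assuming 2-D images so that valid inputs are (N, H, W).
    """
    for name, shape in (("output", output_shape), ("target", target_shape)):
        if len(shape) == 4:
            raise ValueError(
                f"In 'multiclass' mode, ``{name}`` should have shape (N, H, W), "
                f"but got {tuple(shape)}."
            )


def is_rejected(output_shape, target_shape):
    # Small wrapper so both outcomes can be inspected as booleans.
    try:
        validate_multiclass_ndim(output_shape, target_shape)
    except ValueError:
        return True
    return False


# Both inputs passed as (N, C, H, W): the shapes match, but the check still fires.
same_shape_4d = is_rejected((10, 3, 256, 256), (10, 3, 256, 256))

# Correct class-index inputs pass.
valid_3d = is_rejected((10, 256, 256), (10, 256, 256))
```

One possible trade-off of a strict ndim == 4 test is that it would also reject volumetric class-index inputs of shape (N, D, H, W), so whether it fits depends on which input layouts the library intends to support.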

@qubvel
Collaborator

qubvel commented Mar 19, 2026

Please fix the style issue, thanks.

@lapertor
Contributor Author

Done! I ran ruff to fix the style issues. The CI checks should be green now.


2 participants