Skip to content

Commit d6d438b

Browse files
committed
fix: raise clear ValueError when one-hot tensors passed in multiclass mode
1 parent 6adf220 commit d6d438b

1 file changed

Lines changed: 34 additions & 17 deletions

File tree

segmentation_models_pytorch/metrics/functional.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,26 +140,27 @@ def get_stats(
140140

141141
if output.shape != target.shape:
142142
# Check if user accidentally passed a one-hot / logits tensor in multiclass mode
143-
if mode == "multiclass" and output.ndim == target.ndim + 1:
144-
raise ValueError(
145-
f"In 'multiclass' mode, ``output`` should contain class indices of shape "
146-
f"(N, ...), but got shape {output.shape}. "
147-
f"It looks like you passed a one-hot or logits tensor. "
148-
f"Please convert it first with ``output.argmax(dim=1)``."
149-
)
150-
if mode == "multiclass" and target.ndim == output.ndim + 1:
151-
raise ValueError(
152-
f"In 'multiclass' mode, ``target`` should contain class indices of shape "
153-
f"(N, ...), but got shape {target.shape}. "
154-
f"It looks like you passed a one-hot tensor. "
155-
f"Please convert it first with ``target.argmax(dim=1)``."
156-
)
143+
if mode == "multiclass":
144+
if output.ndim >= 3 and output.shape[1] == num_classes:
145+
raise ValueError(
146+
f"In 'multiclass' mode, ``output`` should contain class indices of shape "
147+
f"(N, H, W) or (N,), but got shape {tuple(output.shape)}. "
148+
f"It looks like you passed a one-hot or logits tensor of shape (N, C, ...). "
149+
f"For that use case, please use mode='multilabel' instead, "
150+
f"or convert your tensor with ``output.argmax(dim=1)`` first."
151+
)
152+
if target.ndim >= 3 and target.shape[1] == num_classes:
153+
raise ValueError(
154+
f"In 'multiclass' mode, ``target`` should contain class indices of shape "
155+
f"(N, H, W) or (N,), but got shape {tuple(target.shape)}. "
156+
f"It looks like you passed a one-hot encoded tensor of shape (N, C, ...). "
157+
f"Convert it with ``target.argmax(dim=1)`` first."
158+
)
157159
raise ValueError(
158160
"Dimensions should match, but ``output`` shape is not equal to ``target`` "
159161
+ f"shape, {output.shape} != {target.shape}"
160162
)
161163

162-
163164
if mode != "multiclass" and ignore_index is not None:
164165
raise ValueError(
165166
f"``ignore_index`` parameter is not supported for '{mode}' mode"
@@ -179,21 +180,37 @@ def get_stats(
179180
)
180181

181182
if mode == "multiclass":
182-
if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3:
183+
if output.ndim >= 3 and output.shape[1] == num_classes:
183184
raise ValueError(
184185
f"In 'multiclass' mode, ``output`` should contain class indices of shape "
185186
f"(N, H, W) or (N,), but got shape {tuple(output.shape)}. "
186187
f"It looks like you passed a one-hot or logits tensor of shape (N, C, ...). "
187188
f"For that use case, please use mode='multilabel' instead, "
188189
f"or convert your tensor with ``output.argmax(dim=1)`` first."
189190
)
190-
if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3:
191+
if target.ndim >= 3 and target.shape[1] == num_classes:
191192
raise ValueError(
192193
f"In 'multiclass' mode, ``target`` should contain class indices of shape "
193194
f"(N, H, W) or (N,), but got shape {tuple(target.shape)}. "
194195
f"It looks like you passed a one-hot encoded tensor of shape (N, C, ...). "
195196
f"Convert it with ``target.argmax(dim=1)`` first."
196197
)
198+
# Additional robustness: a 4D tensor (N, C, H, W) is always incorrect in
199+
# 'multiclass' mode, regardless of whether shape[1] matches num_classes
200+
if output.ndim == 4:
201+
raise ValueError(
202+
f"In 'multiclass' mode, ``output`` should contain class indices of shape "
203+
f"(N, H, W), but got shape {tuple(output.shape)}. "
204+
f"A 4D tensor is always incorrect in 'multiclass' mode. "
205+
f"Convert it with ``output.argmax(dim=1)`` first."
206+
)
207+
if target.ndim == 4:
208+
raise ValueError(
209+
f"In 'multiclass' mode, ``target`` should contain class indices of shape "
210+
f"(N, H, W), but got shape {tuple(target.shape)}. "
211+
f"A 4D tensor is always incorrect in 'multiclass' mode. "
212+
f"Convert it with ``target.argmax(dim=1)`` first."
213+
)
197214
tp, fp, fn, tn = _get_stats_multiclass(
198215
output, target, num_classes, ignore_index
199216
)

0 commit comments

Comments
 (0)