@@ -140,26 +140,27 @@ def get_stats(
140140
141141 if output .shape != target .shape :
142142 # Check if user accidentally passed a one-hot / logits tensor in multiclass mode
143- if mode == "multiclass" and output .ndim == target .ndim + 1 :
144- raise ValueError (
145- f"In 'multiclass' mode, ``output`` should contain class indices of shape "
146- f"(N, ...), but got shape { output .shape } . "
147- f"It looks like you passed a one-hot or logits tensor. "
148- f"Please convert it first with ``output.argmax(dim=1)``."
149- )
150- if mode == "multiclass" and target .ndim == output .ndim + 1 :
151- raise ValueError (
152- f"In 'multiclass' mode, ``target`` should contain class indices of shape "
153- f"(N, ...), but got shape { target .shape } . "
154- f"It looks like you passed a one-hot tensor. "
155- f"Please convert it first with ``target.argmax(dim=1)``."
156- )
143+ if mode == "multiclass" :
144+ if output .ndim >= 3 and output .shape [1 ] == num_classes :
145+ raise ValueError (
146+ f"In 'multiclass' mode, ``output`` should contain class indices of shape "
147+ f"(N, H, W) or (N,), but got shape { tuple (output .shape )} . "
148+ f"It looks like you passed a one-hot or logits tensor of shape (N, C, ...). "
149+ f"For that use case, please use mode='multilabel' instead, "
150+ f"or convert your tensor with ``output.argmax(dim=1)`` first."
151+ )
152+ if target .ndim >= 3 and target .shape [1 ] == num_classes :
153+ raise ValueError (
154+ f"In 'multiclass' mode, ``target`` should contain class indices of shape "
155+ f"(N, H, W) or (N,), but got shape { tuple (target .shape )} . "
156+ f"It looks like you passed a one-hot encoded tensor of shape (N, C, ...). "
157+ f"Convert it with ``target.argmax(dim=1)`` first."
158+ )
157159 raise ValueError (
158160 "Dimensions should match, but ``output`` shape is not equal to ``target`` "
159161 + f"shape, { output .shape } != { target .shape } "
160162 )
161163
162-
163164 if mode != "multiclass" and ignore_index is not None :
164165 raise ValueError (
165166 f"``ignore_index`` parameter is not supported for '{ mode } ' mode"
@@ -179,21 +180,37 @@ def get_stats(
179180 )
180181
181182 if mode == "multiclass" :
182- if output .ndim > 1 and output .shape [1 ] == num_classes and output . ndim >= 3 :
183+ if output .ndim >= 3 and output .shape [1 ] == num_classes :
183184 raise ValueError (
184185 f"In 'multiclass' mode, ``output`` should contain class indices of shape "
185186 f"(N, H, W) or (N,), but got shape { tuple (output .shape )} . "
186187 f"It looks like you passed a one-hot or logits tensor of shape (N, C, ...). "
187188 f"For that use case, please use mode='multilabel' instead, "
188189 f"or convert your tensor with ``output.argmax(dim=1)`` first."
189190 )
190- if target .ndim > 1 and target .shape [1 ] == num_classes and target . ndim >= 3 :
191+ if target .ndim >= 3 and target .shape [1 ] == num_classes :
191192 raise ValueError (
192193 f"In 'multiclass' mode, ``target`` should contain class indices of shape "
193194 f"(N, H, W) or (N,), but got shape { tuple (target .shape )} . "
194195 f"It looks like you passed a one-hot encoded tensor of shape (N, C, ...). "
195196 f"Convert it with ``target.argmax(dim=1)`` first."
196197 )
198+ # Additional robustness: a 4D tensor (N, C, H, W) is always incorrect in
199+ # 'multiclass' mode, regardless of whether shape[1] matches num_classes
200+ if output .ndim == 4 :
201+ raise ValueError (
202+ f"In 'multiclass' mode, ``output`` should contain class indices of shape "
203+ f"(N, H, W), but got shape { tuple (output .shape )} . "
204+ f"A 4D tensor is always incorrect in 'multiclass' mode. "
205+ f"Convert it with ``output.argmax(dim=1)`` first."
206+ )
207+ if target .ndim == 4 :
208+ raise ValueError (
209+ f"In 'multiclass' mode, ``target`` should contain class indices of shape "
210+ f"(N, H, W), but got shape { tuple (target .shape )} . "
211+ f"A 4D tensor is always incorrect in 'multiclass' mode. "
212+ f"Convert it with ``target.argmax(dim=1)`` first."
213+ )
197214 tp , fp , fn , tn = _get_stats_multiclass (
198215 output , target , num_classes , ignore_index
199216 )
0 commit comments