fix: validate tensor shapes in get_stats for multiclass mode#1284
fix: validate tensor shapes in get_stats for multiclass mode#1284lapertor wants to merge 3 commits intoqubvel-org:mainfrom
Conversation
When using mode='multiclass', get_stats expected predictions of shape (N, H, W) and targets of shape (N, H, W). However, no validation was performed, so passing a 4D tensor (e.g. (N, C, H, W)) would silently produce incorrect results instead of raising a clear error. This commit adds explicit shape validation with an informative error message.
qubvel
left a comment
There was a problem hiding this comment.
Thanks! please check some minor comments
| ) | ||
|
|
||
| if mode == "multiclass": | ||
| if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3: |
There was a problem hiding this comment.
| if output.ndim > 1 and output.shape[1] == num_classes and output.ndim >= 3: | |
| if output.ndim >= 3 and output.shape[1] == num_classes: |
| f"For that use case, please use mode='multilabel' instead, " | ||
| f"or convert your tensor with ``output.argmax(dim=1)`` first." | ||
| ) | ||
| if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3: |
There was a problem hiding this comment.
| if target.ndim > 1 and target.shape[1] == num_classes and target.ndim >= 3: | |
| if target.ndim >= 3 and target.shape[1] == num_classes: |
|
|
||
| if output.shape != target.shape: | ||
| # Check if user accidentally passed a one-hot / logits tensor in multiclass mode | ||
| if mode == "multiclass" and output.ndim == target.ndim + 1: |
There was a problem hiding this comment.
not super robust, what if both are passed as (N, C, H, W), but better than nothing 👍
|
I've applied both simplifications as suggested. Regarding the |
|
please fix style issue, thanks |
|
Done! I ran |
Summary
Fixes a silent bug in
get_statswhen usingmode='multiclass'.Problem
When calling
get_statswithmode='multiclass', the function expects:However, no shape validation was performed. Passing a 4D
targettensor of shape(N, C, H, W)(as one would use inmultilabelmode) would silently produce incorrect results instead of raising a clear, informative error.Reproduction
Before the fix (silent wrong results):
After the fix (clear, actionable error):
Fix
Added explicit shape validation in
get_statsformulticlassmode that raises a descriptiveValueErrorwhen the target shape is incorrect, guiding the user toward the correct usage.Related Issues
Closes #863