Fix: ZeroDivisionError in SwinUNETR Auto3DSeg training when validation fold contains classes with no samples. #427
daxellwells wants to merge 6 commits into Project-MONAI:main
Conversation
…fold has classes with no samples
Walkthrough
Validation in the swinunetr training script now guards per-class dice against zero denominators by producing a NaN fallback instead of dividing.

Changes
Validation Metric Guard & Averaging Fix
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
for more information, see https://pre-commit.ci
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py`:
- Around line 659-665: The code computes avg_metric by summing per-class ratios
using metric and metric_dim then unconditionally divides by count, which raises
ZeroDivisionError when count==0; update both avg_metric blocks (the one using
metric, metric_dim, avg_metric, count in the shown loop and the analogous block
later) to check if count>0 before dividing and otherwise set avg_metric to a
safe default (e.g., 0.0 or float('nan')) or skip division; ensure the check uses
the same variables (metric, metric_dim, count) so the division never occurs when
count==0.
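A minimal standalone sketch of the guard this prompt describes. The flat metric layout (even indices hold per-class dice sums, odd indices the not-NaN sample counts) and the names metric, metric_dim, and count follow the review comment; everything else is a simplification of the surrounding training loop, not the script's actual code.

```python
import math

def average_class_dice(metric, metric_dim):
    """Average per-class dice, counting only classes with a nonzero denominator."""
    total, count = 0.0, 0
    for _c in range(metric_dim):
        denom = metric[2 * _c + 1]  # not-NaN sample count for class _c
        if denom != 0:
            total += metric[2 * _c] / denom
            count += 1
    # Guard the final division; NaN signals "no valid classes this round"
    return total / count if count > 0 else float("nan")

print(average_class_dice([1.6, 2, 0.0, 0], 2))  # class 2 absent: averages class 1 only → 0.8
```

Returning NaN (rather than 0.0) keeps an "empty" validation round distinguishable from a genuinely zero dice score, which matters for the downstream guards discussed below.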
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 825245b2-fc31-4048-b2aa-c447e69ffe4c
📒 Files selected for processing (2)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py
auto3dseg/configs/metadata.json
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py (1)
702-704: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
float('nan') passed to EarlyStopping silently disables it.
When count == 0, avg_metric = float('nan') is passed to early_stopping(val_acc=float('nan')). In EarlyStopping.__call__ (lines 95–106), the comparison val_acc + self.delta < self.best_score evaluates to False for any NaN operand (Python NaN comparisons are always False), so the else branch executes, setting best_score = NaN and resetting counter = 0. Once best_score is NaN, all subsequent rounds also hit the else branch regardless of the metric value: counter is never incremented and early stopping never fires for the rest of training.
🛠 Proposed fix: skip the EarlyStopping update for a NaN avg_metric
```diff
         if es:
-            early_stopping(val_acc=avg_metric)
-            stop_train = torch.tensor(early_stopping.early_stop).to(device)
+            if not math.isnan(avg_metric):
+                early_stopping(val_acc=avg_metric)
+                stop_train = torch.tensor(early_stopping.early_stop).to(device)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py` around lines 702 - 704, avg_metric can be NaN on the first evaluation which breaks EarlyStopping because NaN causes comparisons to always be False and sets best_score to NaN; modify the block that calls early_stopping(val_acc=avg_metric) (the call site around early_stopping and stop_train) to first check for NaN (e.g., using math.isnan or torch.isnan depending on avg_metric type) and skip calling EarlyStopping.__call__ when avg_metric is NaN, leaving early_stopping.best_score and counter untouched and computing stop_train from early_stopping.early_stop as before.
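The NaN comparison behavior described above can be reproduced with a minimal stand-in. This EarlyStopping is a simplified hypothetical mirroring the pattern the review quotes, not MONAI's actual class:

```python
class EarlyStopping:
    """Simplified stand-in: stop after `patience` non-improving validation rounds."""
    def __init__(self, delta=0.0, patience=2):
        self.delta, self.patience = delta, patience
        self.best_score, self.counter, self.early_stop = None, 0, False

    def __call__(self, val_acc):
        if self.best_score is not None and val_acc + self.delta < self.best_score:
            self.counter += 1  # metric got worse
            self.early_stop = self.counter >= self.patience
        else:
            # Any comparison with NaN is False, so a NaN val_acc always lands here
            self.best_score, self.counter = val_acc, 0

es = EarlyStopping()
es(0.9)
es(float("nan"))   # best_score becomes NaN
for _ in range(10):
    es(0.1)        # 0.1 < NaN is False: counter keeps resetting
print(es.counter, es.early_stop)  # → 0 False
```

Even ten consecutive poor rounds never trip the stopper once a single NaN has poisoned best_score, which is why the proposed fix filters NaN before the call rather than inside it.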
♻️ Duplicate comments (1)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py (1)
654-669: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
float('nan') reaches unguarded mlflow.log_metric calls: a potential uncaught crash. Two paths pass NaN to MLflow without a catch:

Fallback path (line 656): When metric[2 * _c + 1] == 0, every mlflow.log_metric call in this loop receives float('nan'). If the primary call (line 651) raises a RestException ("Metric cannot have value NaN", documented for REST-backed tracking servers), the except BaseException on line 653 catches it and falls into the fallback. The fallback on line 656 retries mlflow.log_metric with the same NaN, this time completely unguarded, propagating the exception and crashing the training loop.

avg_metric path (lines 668–669): When count == 0, avg_metric = float('nan'). mlflow.log_metric("val/acc", avg_metric, …) at line 669 is entirely outside any try/except: an unguarded crash point.

The count == 0 ZeroDivisionError flagged in the previous review is now resolved at line 665; this comment concerns the new unguarded NaN-to-MLflow propagation.
🛠 Proposed fix: guard NaN before MLflow calls
```diff
                 for _c in range(metric_dim):
+                    class_metric = metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan')
                     if metric[2 * _c + 1] == 0:
-                        logger.debug(f"Warning: class {_c + 1} has no samples in validation fold, skipping.")
-                    logger.debug(f"Evaluation metric - class {_c + 1}: {metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan')}")
-                    try:
-                        writer.add_scalar(
-                            f"val_class/acc_{class_names[_c]}", metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan'), epoch
-                        )
-                        mlflow.log_metric(
-                            f"val_class/acc_{class_names[_c]}", metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan'), step=epoch
-                        )
-                    except BaseException:
-                        writer.add_scalar(f"val_class/acc_{_c}", metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan'), epoch)
-                        mlflow.log_metric(
-                            f"val_class/acc_{_c}", metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan'), step=epoch
-                        )
+                        logger.warning(f"Class {_c + 1} has no samples in validation fold; logging as NaN.")
+                    logger.debug(f"Evaluation metric - class {_c + 1}: {class_metric}")
+                    if not math.isnan(class_metric):
+                        try:
+                            writer.add_scalar(f"val_class/acc_{class_names[_c]}", class_metric, epoch)
+                            mlflow.log_metric(f"val_class/acc_{class_names[_c]}", class_metric, step=epoch)
+                        except BaseException:
+                            writer.add_scalar(f"val_class/acc_{_c}", class_metric, epoch)
+                            mlflow.log_metric(f"val_class/acc_{_c}", class_metric, step=epoch)

-                writer.add_scalar("val/acc", avg_metric, epoch)
-                mlflow.log_metric("val/acc", avg_metric, step=epoch)
+                if not math.isnan(avg_metric):
+                    writer.add_scalar("val/acc", avg_metric, epoch)
+                    mlflow.log_metric("val/acc", avg_metric, step=epoch)
```

Verify current MLflow NaN rejection behavior: does mlflow.log_metric raise an exception when passed float('nan') on a REST tracking server in recent MLflow versions?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py` around lines 654 - 669, The log calls pass float('nan') into mlflow.log_metric (and writer.add_scalar) unguarded which can raise on REST servers; before calling mlflow.log_metric (and writer.add_scalar) for per-class and avg values in the loop around writer.add_scalar/mlflow.log_metric and for the avg_metric before logging, check for NaN (use math.isnan or numpy.isnan) and skip logging to MLflow (or substitute a safe numeric sentinel) when the value is NaN; update the block that uses metric[2*_c] / metric[2*_c+1] and the avg_metric compute/usage to only call mlflow.log_metric and writer.add_scalar when the computed value is finite, referring to symbols writer.add_scalar, mlflow.log_metric, metric, avg_metric, and the per-class loop to locate and change the code.
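A hedged sketch of the finiteness guard this prompt asks for. Here log_fn stands in for mlflow.log_metric or writer.add_scalar (assumed from the surrounding script); the helper name and return convention are illustrative, not MLflow API:

```python
import math

def log_metric_safe(log_fn, name, value, step):
    """Forward a metric to a logging backend only when the value is finite.

    Returns True if logged, False if skipped; NaN/Inf can be rejected by
    some backends, e.g. REST-backed MLflow tracking servers.
    """
    if isinstance(value, float) and not math.isfinite(value):
        return False
    log_fn(name, value, step)
    return True

logged = []
sink = lambda n, v, s: logged.append((n, v, s))
print(log_metric_safe(sink, "val/acc", 0.87, 3))          # → True
print(log_metric_safe(sink, "val/acc", float("nan"), 3))  # → False
print(len(logged))                                        # → 1
```

Filtering at one choke point avoids sprinkling math.isnan checks across every per-class and aggregate logging call.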
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py`:
- Around line 643-645: The loop incorrectly logs "skipping" with logger.debug
but still falls through and records NaN values; change the check for metric[2 *
_c + 1] == 0 to (1) use logger.warning instead of logger.debug for the
"skipping" message and (2) add an explicit continue immediately after that
warning so the code actually skips further logging and TensorBoard/MLflow
recording for that class (_c). Update the lines referencing metric, _c, and
logger in train.py accordingly.
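The warn-and-continue change the prompt describes, isolated into a runnable sketch. The logger name and the flat metric layout follow the script; sink is a placeholder for the TensorBoard/MLflow recording calls:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("train")

def log_per_class(metric, metric_dim, sink):
    for _c in range(metric_dim):
        if metric[2 * _c + 1] == 0:
            # warning (not debug), then actually skip this class's logging
            logger.warning("class %d has no samples in validation fold, skipping.", _c + 1)
            continue
        sink.append((_c, metric[2 * _c] / metric[2 * _c + 1]))

sink = []
log_per_class([0.5, 1, 0.0, 0, 1.8, 2], 3, sink)
print(sink)  # → [(0, 0.5), (2, 0.9)]
```

The explicit continue is the key part: without it the loop falls through and still records a NaN for the absent class.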
---
Outside diff comments:
In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py`:
- Around line 702-704: avg_metric can be NaN on the first evaluation which
breaks EarlyStopping because NaN causes comparisons to always be False and sets
best_score to NaN; modify the block that calls
early_stopping(val_acc=avg_metric) (the call site around early_stopping and
stop_train) to first check for NaN (e.g., using math.isnan or torch.isnan
depending on avg_metric type) and skip calling EarlyStopping.__call__ when
avg_metric is NaN, leaving early_stopping.best_score and counter untouched and
computing stop_train from early_stopping.early_stop as before.
---
Duplicate comments:
In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py`:
- Around line 654-669: The log calls pass float('nan') into mlflow.log_metric
(and writer.add_scalar) unguarded which can raise on REST servers; before
calling mlflow.log_metric (and writer.add_scalar) for per-class and avg values
in the loop around writer.add_scalar/mlflow.log_metric and for the avg_metric
before logging, check for NaN (use math.isnan or numpy.isnan) and skip logging
to MLflow (or substitute a safe numeric sentinel) when the value is NaN; update
the block that uses metric[2*_c] / metric[2*_c+1] and the avg_metric
compute/usage to only call mlflow.log_metric and writer.add_scalar when the
computed value is finite, referring to symbols writer.add_scalar,
mlflow.log_metric, metric, avg_metric, and the per-class loop to locate and
change the code.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 3ce74817-77a4-40b6-90d3-008e2e302d57
📒 Files selected for processing (1)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py
…ng-class warning level
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py (2)
655-701: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Guard the rest of the validation bookkeeping when avg_metric is NaN.
The new fallback avoids the divide-by-zero, but this block still feeds avg_metric into early_stopping and the run-history files. If a validation round has no valid class metrics, EarlyStopping.__call__ will see nan and can reset its state unexpectedly.
🛠 Proposed fix
```diff
-                if es:
-                    early_stopping(val_acc=avg_metric)
-                    stop_train = torch.tensor(early_stopping.early_stop).to(device)
+                if es and not math.isnan(avg_metric):
+                    early_stopping(val_acc=avg_metric)
+                    stop_train = torch.tensor(early_stopping.early_stop).to(device)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py` around lines 655 - 701, avg_metric can be NaN and you must skip all validation bookkeeping when it is; wrap the remaining validation actions (the writer.add_scalar and mlflow.log_metric calls, the best_metric comparison/save/progress.yaml update, the logger.debug about current epoch, the accuracy_history.csv append, and the early_stopping(val_acc=avg_metric) + stop_train assignment) in an if not math.isnan(avg_metric) guard so none of these use NaN values (refer to avg_metric, writer.add_scalar, mlflow.log_metric, best_metric/best_metric_epoch logic, progress.yaml write, accuracy_history.csv write, and early_stopping).
810-829: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Don't persist a NaN aggregate into progress.yaml.
When every class is invalid for a pass, avg_metric becomes nan, but this block still appends it as best_avg_dice_score. That leaves a non-numeric score in the run artifact and can confuse downstream tooling that reads progress.yaml.
🛠 Proposed fix
```diff
-        dict_file["best_avg_dice_score"] = float(avg_metric)
-        dict_file["best_avg_dice_score_epoch"] = int(progress[-1]["best_avg_dice_score_epoch"])
-        dict_file["best_avg_dice_score_iteration"] = int(progress[-1]["best_avg_dice_score_iteration"])
-        dict_file["inverted_best_validation"] = True
-        with open(os.path.join(ckpt_path, "progress.yaml"), "a") as out_file:
-            yaml.dump([dict_file], stream=out_file)
+        if not math.isnan(avg_metric):
+            dict_file["best_avg_dice_score"] = float(avg_metric)
+            dict_file["best_avg_dice_score_epoch"] = int(progress[-1]["best_avg_dice_score_epoch"])
+            dict_file["best_avg_dice_score_iteration"] = int(progress[-1]["best_avg_dice_score_iteration"])
+            dict_file["inverted_best_validation"] = True
+            with open(os.path.join(ckpt_path, "progress.yaml"), "a") as out_file:
+                yaml.dump([dict_file], stream=out_file)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py` around lines 810 - 829, The code currently computes avg_metric and always appends it to progress.yaml even when it's NaN; change the logic in the block that builds dict_file (around avg_metric, progress, dict_file) so you only persist best_avg_dice_score when it's a valid number (e.g., count>0 and not math.isnan(avg_metric) and math.isfinite(avg_metric)); if the metric is invalid, either skip adding best_avg_dice_score (and related inverted flag) or explicitly set a safe sentinel (e.g., omit the key or use None) before writing with yaml.dump, ensuring you check avg_metric validity using the avg_metric variable and count used to compute it.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py`:
- Around line 655-701: avg_metric can be NaN and you must skip all validation
bookkeeping when it is; wrap the remaining validation actions (the
writer.add_scalar and mlflow.log_metric calls, the best_metric
comparison/save/progress.yaml update, the logger.debug about current epoch, the
accuracy_history.csv append, and the early_stopping(val_acc=avg_metric) +
stop_train assignment) in an if not math.isnan(avg_metric) guard so none of
these use NaN values (refer to avg_metric, writer.add_scalar, mlflow.log_metric,
best_metric/best_metric_epoch logic, progress.yaml write, accuracy_history.csv
write, and early_stopping).
- Around line 810-829: The code currently computes avg_metric and always appends
it to progress.yaml even when it's NaN; change the logic in the block that
builds dict_file (around avg_metric, progress, dict_file) so you only persist
best_avg_dice_score when it's a valid number (e.g., count>0 and not
math.isnan(avg_metric) and math.isfinite(avg_metric)); if the metric is invalid,
either skip adding best_avg_dice_score (and related inverted flag) or explicitly
set a safe sentinel (e.g., omit the key or use None) before writing with
yaml.dump, ensuring you check avg_metric validity using the avg_metric variable
and count used to compute it.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 648daa8f-1296-470d-9f81-d4cf8c7cc926
📒 Files selected for processing (1)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@auto3dseg/algorithm_templates/swinunetr/scripts/train.py`:
- Line 661: The training loop can produce NaN avg_metric and never write
progress or best model, but later the "original-resolution" phase
unconditionally tries to load best_metric_model.pt and progress.yaml; add guards
around that load: check the in-memory progress list and/or the filesystem for
progress.yaml and best_metric_model.pt (and ensure progress is non-empty and
best model path exists) before attempting to open/load them, and if missing skip
the original-resolution re-evaluation (or bail out cleanly with a descriptive
log) so the code that references avg_metric/best model is only executed when
valid progress entries and a saved best model exist; update the code paths that
append to progress and the code that loads progress/best model (look for
variables avg_metric, progress, best_metric_model.pt, progress.yaml and the
original-resolution phase) accordingly.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 2b842389-03fc-496b-aa2b-fc65e24872c1
📒 Files selected for processing (1)
auto3dseg/algorithm_templates/swinunetr/scripts/train.py
Summary
- Uses a count of present classes, rather than metric_dim, as the avg_metric denominator, preventing artificial deflation of reported Dice scores

Root cause
When a segmentation class is entirely absent from a validation fold, metric[2 * _c + 1] (the not-NaN sample count, per the existing code comment) is 0, and the bare division metric[2 * _c] / metric[2 * _c + 1] crashes with a ZeroDivisionError. This is a clinical reality in many segmentation tasks: we encountered this bug when running Auto3DSeg for vertebrae segmentation, where certain cases were missing certain vertebrae.

Test plan
python auto3dseg/tests/test_algo_templates.py passes (Ran 4 tests, OK) on Linux with an NVIDIA RTX 4090

Summary by CodeRabbit
Bug Fixes
Documentation