Skip to content

Commit a3deff0

Browse files
committed
fix: prevent ZeroDivisionError in SwinUNETR training when validation fold has classes with no samples
1 parent 21ed8e5 commit a3deff0

2 files changed

Lines changed: 21 additions & 12 deletions

File tree

auto3dseg/algorithm_templates/swinunetr/scripts/train.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -640,24 +640,29 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
640640
metric = metric.tolist()
641641
if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
642642
for _c in range(metric_dim):
643-
logger.debug(f"Evaluation metric - class {_c + 1}: {metric[2 * _c] / metric[2 * _c + 1]}")
643+
if metric[2 * _c +1] == 0:
644+
logger.debug(f"Warning: class {_c + 1} has no samples in validation fold, skipping.")
645+
logger.debug(f"Evaluation metric - class {_c + 1}: {metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan')}")
644646
try:
645647
writer.add_scalar(
646-
f"val_class/acc_{class_names[_c]}", metric[2 * _c] / metric[2 * _c + 1], epoch
648+
f"val_class/acc_{class_names[_c]}", metric[2 * _c] / metric[2 * _c + 1]if metric[2 * _c + 1] != 0 else float('nan'), epoch
647649
)
648650
mlflow.log_metric(
649-
f"val_class/acc_{class_names[_c]}", metric[2 * _c] / metric[2 * _c + 1], step=epoch
651+
f"val_class/acc_{class_names[_c]}", metric[2 * _c] / metric[2 * _c + 1]if metric[2 * _c + 1] != 0 else float('nan'), step=epoch
650652
)
651653
except BaseException:
652-
writer.add_scalar(f"val_class/acc_{_c}", metric[2 * _c] / metric[2 * _c + 1], epoch)
654+
writer.add_scalar(f"val_class/acc_{_c}", metric[2 * _c] / metric[2 * _c + 1]if metric[2 * _c + 1] != 0 else float('nan'), epoch)
653655
mlflow.log_metric(
654-
f"val_class/acc_{_c}", metric[2 * _c] / metric[2 * _c + 1], step=epoch
656+
f"val_class/acc_{_c}", metric[2 * _c] / metric[2 * _c + 1]if metric[2 * _c + 1] != 0 else float('nan'), step=epoch
655657
)
656658

657659
avg_metric = 0
660+
count = 0
658661
for _c in range(metric_dim):
659-
avg_metric += metric[2 * _c] / metric[2 * _c + 1]
660-
avg_metric = avg_metric / float(metric_dim)
662+
if metric[2 * _c + 1] != 0:
663+
avg_metric += metric[2 * _c] / metric[2 * _c + 1]
664+
count +=1
665+
avg_metric = avg_metric / float(count)
661666
logger.debug(f"Avg_metric: {avg_metric}")
662667

663668
writer.add_scalar("val/acc", avg_metric, epoch)
@@ -801,13 +806,16 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
801806
if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
802807
for _c in range(metric_dim):
803808
logger.debug(
804-
f"Evaluation metric at original resolution - class {_c + 1}: {metric[2 * _c] / metric[2 * _c + 1]}"
809+
f"Evaluation metric at original resolution - class {_c + 1}: {metric[2 * _c] / metric[2 * _c + 1] if metric[2 * _c + 1] != 0 else float('nan')}"
805810
)
806811

807812
avg_metric = 0
813+
count = 0
808814
for _c in range(metric_dim):
809-
avg_metric += metric[2 * _c] / metric[2 * _c + 1]
810-
avg_metric = avg_metric / float(metric_dim)
815+
if metric[2 * _c + 1] != 0:
816+
avg_metric += metric[2 * _c] / metric[2 * _c + 1]
817+
count += 1
818+
avg_metric = avg_metric / float(count)
811819
logger.debug(f"Avg_metric at original resolution: {avg_metric}")
812820

813821
with open(os.path.join(ckpt_path, "progress.yaml"), "r") as out_file:

auto3dseg/configs/metadata.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{
2-
"version": "0.0.8",
2+
"version": "0.0.9",
33
"changelog": {
4+
"0.0.9": "Fix ZeroDivisionError in swinunetr training script for missing classes in validation fold.",
45
"0.0.8": "Update swin unetr pretrained weights link",
56
"0.0.7": "Add support for MLFlow experiment name.",
67
"0.0.6": "Move metadata.json under 'configs' to be consistent with bundles.",
@@ -10,4 +11,4 @@
1011
"0.0.2": "update hyper-parameter naming in dints algorithm template.",
1112
"0.0.1": "this version is based on commit 03a6d4effb9223670f439c3a29198ef34938922f."
1213
}
13-
}
14+
}

0 commit comments

Comments
 (0)