Skip to content

Commit b344c3c

Browse files
committed
fix: guard avg_metric division against zero count when all classes absent from validation fold
1 parent a3deff0 commit b344c3c

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

  • auto3dseg/algorithm_templates/swinunetr/scripts

auto3dseg/algorithm_templates/swinunetr/scripts/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
662662
if metric[2 * _c + 1] != 0:
663663
avg_metric += metric[2 * _c] / metric[2 * _c + 1]
664664
count +=1
665-
avg_metric = avg_metric / float(count)
665+
avg_metric = avg_metric / float(count) if count > 0 else float('nan')
666666
logger.debug(f"Avg_metric: {avg_metric}")
667667

668668
writer.add_scalar("val/acc", avg_metric, epoch)
@@ -815,7 +815,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
815815
if metric[2 * _c + 1] != 0:
816816
avg_metric += metric[2 * _c] / metric[2 * _c + 1]
817817
count += 1
818-
avg_metric = avg_metric / float(count)
818+
avg_metric = avg_metric / float(count) if count > 0 else float('nan')
819819
logger.debug(f"Avg_metric at original resolution: {avg_metric}")
820820

821821
with open(os.path.join(ckpt_path, "progress.yaml"), "r") as out_file:

0 commit comments

Comments
 (0)