Skip to content

Commit 812f557

Browse files
committed
fix: skip early stopping and progress.yaml write when avg_metric is NaN
1 parent d4a7af5 commit 812f557

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

  • auto3dseg/algorithm_templates/swinunetr/scripts

auto3dseg/algorithm_templates/swinunetr/scripts/train.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
696696
)
697697
)
698698

699-
if es:
699+
if es and not math.isnan(avg_metric):
700700
early_stopping(val_acc=avg_metric)
701701
stop_train = torch.tensor(early_stopping.early_stop).to(device)
702702

@@ -819,13 +819,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
819819
with open(os.path.join(ckpt_path, "progress.yaml"), "r") as out_file:
820820
progress = yaml.safe_load(out_file)
821821

822-
dict_file = {}
823-
dict_file["best_avg_dice_score"] = float(avg_metric)
824-
dict_file["best_avg_dice_score_epoch"] = int(progress[-1]["best_avg_dice_score_epoch"])
825-
dict_file["best_avg_dice_score_iteration"] = int(progress[-1]["best_avg_dice_score_iteration"])
826-
dict_file["inverted_best_validation"] = True
827-
with open(os.path.join(ckpt_path, "progress.yaml"), "a") as out_file:
828-
yaml.dump([dict_file], stream=out_file)
822+
if not math.isnan(avg_metric):
823+
dict_file = {}
824+
dict_file["best_avg_dice_score"] = float(avg_metric)
825+
dict_file["best_avg_dice_score_epoch"] = int(progress[-1]["best_avg_dice_score_epoch"])
826+
dict_file["best_avg_dice_score_iteration"] = int(progress[-1]["best_avg_dice_score_iteration"])
827+
dict_file["inverted_best_validation"] = True
828+
with open(os.path.join(ckpt_path, "progress.yaml"), "a") as out_file:
829+
yaml.dump([dict_file], stream=out_file)
829830

830831
if torch.cuda.device_count() > 1:
831832
dist.barrier()

0 commit comments

Comments
 (0)