@@ -696,7 +696,7 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
696696 )
697697 )
698698
699- if es :
699+ if es and not math . isnan ( avg_metric ) :
700700 early_stopping (val_acc = avg_metric )
701701 stop_train = torch .tensor (early_stopping .early_stop ).to (device )
702702
@@ -819,13 +819,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
819819 with open (os .path .join (ckpt_path , "progress.yaml" ), "r" ) as out_file :
820820 progress = yaml .safe_load (out_file )
821821
822- dict_file = {}
823- dict_file ["best_avg_dice_score" ] = float (avg_metric )
824- dict_file ["best_avg_dice_score_epoch" ] = int (progress [- 1 ]["best_avg_dice_score_epoch" ])
825- dict_file ["best_avg_dice_score_iteration" ] = int (progress [- 1 ]["best_avg_dice_score_iteration" ])
826- dict_file ["inverted_best_validation" ] = True
827- with open (os .path .join (ckpt_path , "progress.yaml" ), "a" ) as out_file :
828- yaml .dump ([dict_file ], stream = out_file )
822+ if not math .isnan (avg_metric ):
823+ dict_file = {}
824+ dict_file ["best_avg_dice_score" ] = float (avg_metric )
825+ dict_file ["best_avg_dice_score_epoch" ] = int (progress [- 1 ]["best_avg_dice_score_epoch" ])
826+ dict_file ["best_avg_dice_score_iteration" ] = int (progress [- 1 ]["best_avg_dice_score_iteration" ])
827+ dict_file ["inverted_best_validation" ] = True
828+ with open (os .path .join (ckpt_path , "progress.yaml" ), "a" ) as out_file :
829+ yaml .dump ([dict_file ], stream = out_file )
829830
830831 if torch .cuda .device_count () > 1 :
831832 dist .barrier ()
0 commit comments