@@ -640,24 +640,29 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
640640 metric = metric .tolist ()
641641 if torch .cuda .device_count () == 1 or dist .get_rank () == 0 :
642642 for _c in range (metric_dim ):
643- logger .debug (f"Evaluation metric - class { _c + 1 } : { metric [2 * _c ] / metric [2 * _c + 1 ]} " )
643+ if metric [2 * _c + 1 ] == 0 :
644+ logger .debug (f"Warning: class { _c + 1 } has no samples in validation fold, skipping." )
645+ logger .debug (f"Evaluation metric - class { _c + 1 } : { metric [2 * _c ] / metric [2 * _c + 1 ] if metric [2 * _c + 1 ] != 0 else float ('nan' )} " )
644646 try :
645647 writer .add_scalar (
646- f"val_class/acc_{ class_names [_c ]} " , metric [2 * _c ] / metric [2 * _c + 1 ], epoch
648+ f"val_class/acc_{ class_names [_c ]} " , metric [2 * _c ] / metric [2 * _c + 1 ]if metric [ 2 * _c + 1 ] != 0 else float ( 'nan' ) , epoch
647649 )
648650 mlflow .log_metric (
649- f"val_class/acc_{ class_names [_c ]} " , metric [2 * _c ] / metric [2 * _c + 1 ], step = epoch
651+ f"val_class/acc_{ class_names [_c ]} " , metric [2 * _c ] / metric [2 * _c + 1 ]if metric [ 2 * _c + 1 ] != 0 else float ( 'nan' ) , step = epoch
650652 )
651653 except BaseException :
652- writer .add_scalar (f"val_class/acc_{ _c } " , metric [2 * _c ] / metric [2 * _c + 1 ], epoch )
654+ writer .add_scalar (f"val_class/acc_{ _c } " , metric [2 * _c ] / metric [2 * _c + 1 ]if metric [ 2 * _c + 1 ] != 0 else float ( 'nan' ) , epoch )
653655 mlflow .log_metric (
654- f"val_class/acc_{ _c } " , metric [2 * _c ] / metric [2 * _c + 1 ], step = epoch
656+ f"val_class/acc_{ _c } " , metric [2 * _c ] / metric [2 * _c + 1 ]if metric [ 2 * _c + 1 ] != 0 else float ( 'nan' ) , step = epoch
655657 )
656658
657659 avg_metric = 0
660+ count = 0
658661 for _c in range (metric_dim ):
659- avg_metric += metric [2 * _c ] / metric [2 * _c + 1 ]
660- avg_metric = avg_metric / float (metric_dim )
662+ if metric [2 * _c + 1 ] != 0 :
663+ avg_metric += metric [2 * _c ] / metric [2 * _c + 1 ]
664+ count += 1
665+ avg_metric = avg_metric / float (count )
661666 logger .debug (f"Avg_metric: { avg_metric } " )
662667
663668 writer .add_scalar ("val/acc" , avg_metric , epoch )
@@ -801,13 +806,16 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
801806 if torch .cuda .device_count () == 1 or dist .get_rank () == 0 :
802807 for _c in range (metric_dim ):
803808 logger .debug (
804- f"Evaluation metric at original resolution - class { _c + 1 } : { metric [2 * _c ] / metric [2 * _c + 1 ]} "
809+ f"Evaluation metric at original resolution - class { _c + 1 } : { metric [2 * _c ] / metric [2 * _c + 1 ] if metric [ 2 * _c + 1 ] != 0 else float ( 'nan' ) } "
805810 )
806811
807812 avg_metric = 0
813+ count = 0
808814 for _c in range (metric_dim ):
809- avg_metric += metric [2 * _c ] / metric [2 * _c + 1 ]
810- avg_metric = avg_metric / float (metric_dim )
815+ if metric [2 * _c + 1 ] != 0 :
816+ avg_metric += metric [2 * _c ] / metric [2 * _c + 1 ]
817+ count += 1
818+ avg_metric = avg_metric / float (count )
811819 logger .debug (f"Avg_metric at original resolution: { avg_metric } " )
812820
813821 with open (os .path .join (ckpt_path , "progress.yaml" ), "r" ) as out_file :
0 commit comments