|
16 | 16 | import sklearn.base |
17 | 17 | import torch |
18 | 18 | import torch.nn as nn |
19 | | -from omegaconf import DictConfig |
| 19 | +from omegaconf import DictConfig, ListConfig |
20 | 20 | from psrcal.calibration import AffineCalLogLoss, calibrate |
21 | 21 | from quapy.method.aggregative import ACC |
22 | 22 | from torchmetrics import BootStrapper, MetricCollection |
@@ -106,11 +106,17 @@ def calibrate_predictions(self, task: str, model_index: int): |
106 | 106 | # load prediction |
107 | 107 | all_splits_prediction = torch.load(model.predictions[pred]) |
108 | 108 | # we use the validation split as base for inferring calibration parameters |
109 | | - if DataSplit.VAL.value not in all_splits_prediction: |
| 109 | + if DataSplit.VAL.value not in all_splits_prediction or len(all_splits_prediction[DataSplit.VAL.value]) == 0: |
110 | 110 | raise RuntimeError( |
111 | 111 | f"No predictions have been made on validation data for model @ {model._stored}" |
112 | 112 | f"and prediction on {pred} (@ {model.predictions[pred]})." |
113 | 113 | ) |
| 114 | + if (len(all_splits_prediction[DataSplit.TEST.value]) == 0) and (len(all_splits_prediction[DataSplit.UNLABELLED.value]) == 0): |
| 115 | + warnings.warn( |
| 116 | + f"No predictions have been made on unlabeled & test data for model @ {model._stored}" |
| 117 | + f"and prediction on {pred} (@ {model.predictions[pred]}). Will omit re-calibration." |
| 118 | + ) |
| 119 | + continue |
114 | 120 | all_logits = {} |
115 | 121 | all_labels = {} |
116 | 122 | for split in [DataSplit.VAL, DataSplit.TEST, DataSplit.UNLABELLED]: |
@@ -176,9 +182,9 @@ def predict(self, X): |
176 | 182 | # to avoid any zero division we assume a minimum probability |
177 | 183 | prior = np.clip(prior, a_min=1e-8, a_max=None) |
178 | 184 | prior = prior / np.sum(prior) |
179 | | - elif isinstance(self.cfg.mode.prior, list): |
| 185 | + elif isinstance(self.cfg.mode.prior, list) or isinstance(self.cfg.mode.prior, ListConfig): |
180 | 186 | # use given priors |
181 | | - prior = self.cfg.mode.prior |
| 187 | + prior = list(self.cfg.mode.prior) |
182 | 188 | elif self.cfg.mode.prior == "val": |
183 | 189 | # infer priors on validation data, psrcal infers them |
184 | 190 | prior = None |
@@ -214,6 +220,7 @@ def predict(self, X): |
214 | 220 | for case_idx, case_logits in enumerate(logits): |
215 | 221 | all_splits_prediction[split.value][case_idx]["calibrated"] = case_logits |
216 | 222 | torch.save(obj=all_splits_prediction, f=model.predictions[pred]) |
| 223 | + logger.info(f"Updated calibrated predictions for model @ {model._stored} and prediction on {pred}.") |
217 | 224 |
|
218 | 225 | def select_ensemble(self): |
219 | 226 | pivot_struct = self.get_struct(self.pivot) |
|
0 commit comments