Skip to content

Commit 4a824d6

Browse files
committed
fixed re-calibration parameter handling in post-processing
1 parent e60a443 commit 4a824d6

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
### Bug fixes
1919
- fixed incorrect handling of hydra choices in CONTINUE mode
2020
- fixed .env loading for non default system
21+
- fixed re-calibration parameter handling in post-processing
2122

2223
## 1.0.4 (05/02/2025):
2324
This patch transitions the testing strategy from GPU based to purely CPU tests.

src/mml/core/scripts/schedulers/postprocess_scheduler.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sklearn.base
1717
import torch
1818
import torch.nn as nn
19-
from omegaconf import DictConfig
19+
from omegaconf import DictConfig, ListConfig
2020
from psrcal.calibration import AffineCalLogLoss, calibrate
2121
from quapy.method.aggregative import ACC
2222
from torchmetrics import BootStrapper, MetricCollection
@@ -106,11 +106,17 @@ def calibrate_predictions(self, task: str, model_index: int):
106106
# load prediction
107107
all_splits_prediction = torch.load(model.predictions[pred])
108108
# we use the validation split as base for inferring calibration parameters
109-
if DataSplit.VAL.value not in all_splits_prediction:
109+
if DataSplit.VAL.value not in all_splits_prediction or len(all_splits_prediction[DataSplit.VAL.value]) == 0:
110110
raise RuntimeError(
111111
f"No predictions have been made on validation data for model @ {model._stored}"
112112
f"and prediction on {pred} (@ {model.predictions[pred]})."
113113
)
114+
if (len(all_splits_prediction[DataSplit.TEST.value]) == 0) and (len(all_splits_prediction[DataSplit.UNLABELLED.value]) == 0):
115+
warnings.warn(
116+
f"No predictions have been made on unlabeled & test data for model @ {model._stored}"
117+
f"and prediction on {pred} (@ {model.predictions[pred]}). Will omit re-calibration."
118+
)
119+
continue
114120
all_logits = {}
115121
all_labels = {}
116122
for split in [DataSplit.VAL, DataSplit.TEST, DataSplit.UNLABELLED]:
@@ -176,9 +182,9 @@ def predict(self, X):
176182
# to avoid any zero division we assume a minimum probability
177183
prior = np.clip(prior, a_min=1e-8, a_max=None)
178184
prior = prior / np.sum(prior)
179-
elif isinstance(self.cfg.mode.prior, list):
185+
elif isinstance(self.cfg.mode.prior, list) or isinstance(self.cfg.mode.prior, ListConfig):
180186
# use given priors
181-
prior = self.cfg.mode.prior
187+
prior = list(self.cfg.mode.prior)
182188
elif self.cfg.mode.prior == "val":
183189
# infer priors on validation data, psrcal infers them
184190
prior = None
@@ -214,6 +220,7 @@ def predict(self, X):
214220
for case_idx, case_logits in enumerate(logits):
215221
all_splits_prediction[split.value][case_idx]["calibrated"] = case_logits
216222
torch.save(obj=all_splits_prediction, f=model.predictions[pred])
223+
logger.info(f"Updated calibrated predictions for model @ {model._stored} and prediction on {pred}.")
217224

218225
def select_ensemble(self):
219226
pivot_struct = self.get_struct(self.pivot)

0 commit comments

Comments (0)