Skip to content

Commit fdcc46e

Browse files
authored
Fix keras model initialization after restart (#1602)
* Address case when keras model has not been evaluated and returns empty metrics Signed-off-by: Patrick Foley <psfoley@gmail.com> * Address review comments Signed-off-by: Patrick Foley <psfoley@gmail.com> * Fix lint formatting issue Signed-off-by: Patrick Foley <psfoley@gmail.com> --------- Signed-off-by: Patrick Foley <psfoley@gmail.com>
1 parent 00f240c commit fdcc46e

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

openfl/federated/task/runner_keras.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def train_task(
169169

170170
return global_tensor_dict, local_tensor_dict
171171

172+
def _initialize_metrics_result(self, batch_size):
173+
# evaluation needed before metrics can be resolved
174+
self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1)
175+
return self.model.get_metrics_result()
176+
172177
def train_(self, batch_generator, metrics: list = None, **kwargs):
173178
"""Train single epoch. Override this function for custom training.
174179
@@ -187,15 +192,14 @@ def train_(self, batch_generator, metrics: list = None, **kwargs):
187192
# initialization (build_model).
188193
# If metrics are added (i.e. not a subset of what was originally
189194
# defined) then the model must be recompiled.
195+
196+
batch_size = kwargs.get("batch_size", 1)
190197
try:
191198
results = self.model.get_metrics_result()
199+
if len(results) == 0:
200+
results = self._initialize_metrics_result(batch_size)
192201
except ValueError:
193-
if "batch_size" in kwargs:
194-
batch_size = kwargs["batch_size"]
195-
else:
196-
batch_size = 1
197-
# evaluation needed before metrics can be resolved
198-
self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1)
202+
self._initialize_metrics_result(batch_size)
199203
results = self.model.get_metrics_result()
200204

201205
# TODO if there are new metrics in the flplan that were not included
@@ -230,10 +234,7 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
230234
These correspond to acc, precision, f1_score, etc.
231235
dict: Empty dictionary.
232236
"""
233-
if "batch_size" in kwargs:
234-
batch_size = kwargs["batch_size"]
235-
else:
236-
batch_size = 1
237+
batch_size = kwargs.get("batch_size", 1)
237238

238239
self.rebuild_model(round_num, input_tensor_dict, validation=True)
239240
param_metrics = kwargs["metrics"]

0 commit comments

Comments
 (0)