22
33import collections
44import functools
5+ import itertools
56import json
67import logging
78import os
@@ -231,7 +232,11 @@ def summarize_report(
231232
232233 if wandb_logger :
233234 metrics = meter .get_scalar_dict ()
234- wandb_logger .log_metrics ({** metrics , "trainer/global_step" : current_iteration })
235+ wandb_logger .log_metrics ({** metrics , "trainer/global_step" : current_iteration }, commit = False )
236+
237+ # Log the learning rate if available
238+ if wandb_logger and 'lr' in extra .keys ():
239+ wandb_logger .log_metrics ({"train/learning_rate" : float (extra ["lr" ])})
235240
236241 if not should_print :
237242 return
@@ -400,7 +405,6 @@ class WandbLogger:
400405 save_dir: Path where data is saved (./save/logs/wandb/ by default).
401406 project: Display name for the project.
402407 config: Configuration for the run.
403- **init_kwargs: Arguments passed to :func:`wandb.init`.
404408
405409 Raises:
406410 ImportError: If wandb package is not installed.
@@ -413,7 +417,6 @@ def __init__(
413417 save_dir : Optional [str ] = None ,
414418 project : Optional [str ] = None ,
415419 config : Optional [Dict ] = None ,
416- ** init_kwargs ,
417420 ):
418421 try :
419422 import wandb
@@ -429,6 +432,11 @@ def __init__(
429432 entity = entity , name = name , project = project , dir = save_dir , config = config
430433 )
431434
435+ init_kwargs = dict (
436+ itertools .islice (
437+ config .training .wandb .items (), 4 , len (config .training .wandb )
438+ )
439+ )
432440 self ._wandb_init .update (** init_kwargs )
433441
434442 self .setup ()
@@ -459,14 +467,15 @@ def _should_log_wandb(self):
459467 else :
460468 return True
461469
462- def log_metrics (self , metrics : Dict [str , float ]):
470+ def log_metrics (self , metrics : Dict [str , float ], commit = True ):
463471 """
464472 Log the monitored metrics to the wand dashboard.
465473
466474 Args:
467- metrics (Dict[str, float]): [description]
475+ metrics (Dict[str, float]): A dictionary of metrics to log.
476+ commit (bool): Save the metrics dict to the wandb server and increment the step. (default: True)
468477 """
469478 if not self ._should_log_wandb ():
470479 return
471480
472- self ._wandb .log (metrics )
481+ self ._wandb .log (metrics , commit = commit )
0 commit comments