Skip to content

Commit a0decd2

Browse files
committed
cleaned passing of kwargs, added wandb_logger to write validation metrics, log lr
1 parent 1a74a25 commit a0decd2

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

mmf/configs/defaults.yaml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,11 @@ training:
5454
# Experiment/ run name to be used while logging the experiment
5555
# under the project with wandb
5656
wandb_runname: ${training.experiment_name}
57-
# Specify other argument values to be used while logging the experiment
58-
init_kwargs:
59-
job_type: train
57+
# Specify other argument values that you want to pass to wandb.init(). Check out the documentation
58+
# at https://docs.wandb.ai/ref/python/init to see what arguments are available.
59+
# job_type: 'train'
60+
# tags: ['tag1', 'tag2']
61+
6062

6163

6264
# Size of the batch globally. If distributed or data_parallel

mmf/trainers/callbacks/logistics.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,12 @@ def __init__(self, config, trainer):
5858
if env_wandb_logdir:
5959
log_dir = env_wandb_logdir
6060

61-
wandb_init_kwargs = config.training.wandb.init_kwargs
62-
6361
self.wandb_logger = WandbLogger(
6462
entity=config.training.wandb.entity,
6563
project=config.training.wandb.wandb_projectname,
6664
config=config,
6765
name=config.training.wandb.wandb_runname,
6866
save_dir=log_dir,
69-
**wandb_init_kwargs,
7067
)
7168

7269
def on_train_start(self):
@@ -157,6 +154,7 @@ def on_test_end(self, **kwargs):
157154
meter=kwargs["meter"],
158155
should_print=prefix,
159156
tb_writer=self.tb_writer,
157+
wandb_logger=self.wandb_logger,
160158
)
161159
logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")
162160

mmf/utils/logger.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections
44
import functools
5+
import itertools
56
import json
67
import logging
78
import os
@@ -231,7 +232,11 @@ def summarize_report(
231232

232233
if wandb_logger:
233234
metrics = meter.get_scalar_dict()
234-
wandb_logger.log_metrics({**metrics, "trainer/global_step": current_iteration})
235+
wandb_logger.log_metrics({**metrics, "trainer/global_step": current_iteration}, commit=False)
236+
237+
# Log the learning rate if available
238+
if wandb_logger and 'lr' in extra.keys():
239+
wandb_logger.log_metrics({"train/learning_rate": float(extra["lr"])})
235240

236241
if not should_print:
237242
return
@@ -400,7 +405,6 @@ class WandbLogger:
400405
save_dir: Path where data is saved (./save/logs/wandb/ by default).
401406
project: Display name for the project.
402407
config: Configuration for the run.
403-
**init_kwargs: Arguments passed to :func:`wandb.init`.
404408
405409
Raises:
406410
ImportError: If wandb package is not installed.
@@ -413,7 +417,6 @@ def __init__(
413417
save_dir: Optional[str] = None,
414418
project: Optional[str] = None,
415419
config: Optional[Dict] = None,
416-
**init_kwargs,
417420
):
418421
try:
419422
import wandb
@@ -429,6 +432,11 @@ def __init__(
429432
entity=entity, name=name, project=project, dir=save_dir, config=config
430433
)
431434

435+
init_kwargs = dict(
436+
itertools.islice(
437+
config.training.wandb.items(), 4, len(config.training.wandb)
438+
)
439+
)
432440
self._wandb_init.update(**init_kwargs)
433441

434442
self.setup()
@@ -459,14 +467,15 @@ def _should_log_wandb(self):
459467
else:
460468
return True
461469

462-
def log_metrics(self, metrics: Dict[str, float]):
470+
def log_metrics(self, metrics: Dict[str, float], commit=True):
463471
"""
464472
Log the monitored metrics to the wand dashboard.
465473
466474
Args:
467-
metrics (Dict[str, float]): [description]
475+
metrics (Dict[str, float]): A dictionary of metrics to log.
476+
commit (bool): Save the metrics dict to the wandb server and increment the step. (default: True)
468477
"""
469478
if not self._should_log_wandb():
470479
return
471480

472-
self._wandb.log(metrics)
481+
self._wandb.log(metrics, commit=commit)

0 commit comments

Comments
 (0)