|
7 | 7 | import os |
8 | 8 | import sys |
9 | 9 | import time |
| 10 | +from copy import deepcopy |
10 | 11 | from functools import wraps |
11 | 12 | from typing import Any, Callable, Dict, Optional, Union |
12 | 13 |
|
13 | | -import omegaconf |
14 | 14 | import torch |
15 | 15 | from mmf.common.registry import registry |
16 | 16 | from mmf.utils.configuration import get_mmf_env |
17 | 17 | from mmf.utils.distributed import get_rank, is_main, is_xla |
18 | 18 | from mmf.utils.file_io import PathManager |
19 | 19 | from mmf.utils.timer import Timer |
20 | | -from omegaconf import OmegaConf |
21 | 20 | from termcolor import colored |
22 | 21 |
|
23 | 22 |
|
@@ -228,7 +227,7 @@ def summarize_report( |
228 | 227 | return |
229 | 228 |
|
230 | 229 | # Log the learning rate if available |
231 | | - if wandb_logger and "lr" in extra.keys(): |
| 230 | + if wandb_logger and "lr" in extra: |
232 | 231 | wandb_logger.log_metrics( |
233 | 232 | {"train/learning_rate": float(extra["lr"])}, commit=False |
234 | 233 | ) |
@@ -426,17 +425,12 @@ def __init__( |
426 | 425 | ) |
427 | 426 |
|
428 | 427 | self._wandb = wandb |
429 | | - |
430 | | - self._wandb_init = dict(entity=entity, config=config, project=project) |
431 | | - |
432 | | - wandb_params = config.training.wandb |
433 | | - with omegaconf.open_dict(wandb_params): |
434 | | - wandb_params.pop("enabled") |
435 | | - wandb_params.pop("entity") |
436 | | - wandb_params.pop("project") |
437 | | - |
438 | | - init_kwargs = OmegaConf.to_container(wandb_params, resolve=True) |
439 | | - self._wandb_init.update(**init_kwargs) |
| 428 | + self._wandb_init = dict(entity=entity, project=project) |
| 429 | + wandb_kwargs = deepcopy(config.training.wandb) |
| 430 | + wandb_kwargs.pop("enabled") |
| 431 | + wandb_kwargs.pop("entity") |
| 432 | + wandb_kwargs.pop("project") |
| 433 | + self._wandb_init.update(**wandb_kwargs) |
440 | 434 |
|
441 | 435 | self.setup() |
442 | 436 |
|
|
0 commit comments