
Commit c9ab349

ayulockin authored and facebook-github-bot committed
[feat] Model version control using W&B Artifacts (#1137)
Summary: 🚀 I have extended the `WandbLogger` with the ability to log the `current.pt` checkpoint as W&B Artifacts. Note that this PR is based on top of this [PR](#1129).

### What is W&B Artifacts?

> W&B Artifacts was designed to make it effortless to version your datasets and models, regardless of whether you want to store your files with us or whether you already have a bucket you want us to track. Once you've tracked your dataset or model files, W&B will automatically log each and every modification, giving you a complete and auditable history of changes to your files.

Through this PR, W&B Artifacts can help save and organize machine learning models throughout a project's lifecycle. More details in the documentation [here](https://docs.wandb.ai/guides/artifacts/model-versioning).

### Modification

This PR adds a `log_model_checkpoint` method to the `WandbLogger` class in the `utils/logger.py` file. This method is called in the `utils/checkpoint.py` file.

### Usage

To use this, in `config/defaults.yaml` set `training.wandb.enabled=true` and `training.wandb.log_checkpoint=true`.

### Result

The screenshot shows the `current.pt` checkpoints saved at intervals defined by `training.checkpoint_interval`. You can check out the logged artifacts page [here](https://wandb.ai/ayut/mmf/artifacts/model/run_ey9xextf_model/0dc64164acbdc300fd01/api).

![image](https://user-images.githubusercontent.com/31141479/139390462-d5c8445e-5c20-4fdd-85d0-51ef64846bf0.png)

### Superpowers

With this small addition, one can now easily track different versions of the model, download a checkpoint of interest using the API in the API tab, easily share checkpoints with teammates, etc.

### Requests

This is a draft PR as there are a few more things that can be improved here.

* Is there a better way to access the path to the `current.pt` checkpoint? That is, is the modification made to `utils/checkpoint.py` an acceptable way of approaching this?
* While logging a file as a W&B artifact we can also provide metadata associated with that file. In this case, we can add the current iteration, training metrics, etc. as the metadata. Would love to get suggestions about the different data points that I should log as metadata alongside the checkpoints.
* How to determine if a checkpoint is the best one? If a checkpoint is the best, I can add `best` as an alias for that checkpoint's artifact.

Pull Request resolved: #1137

Test Plan: Imported from GitHub, without a `Test Plan:` line.

**Static Docs Preview: mmf**

* [Full Site](https://our.intern.facebook.com/intern/staticdocs/eph/D32402090/V6/mmf/)
* **Modified Pages**: [docs/notes/logger](https://our.intern.facebook.com/intern/staticdocs/eph/D32402090/V6/mmf/docs/notes/logger/)

Reviewed By: apsdehal

Differential Revision: D32402090

Pulled By: ebsmothers

fbshipit-source-id: 94b881ec55c4197301331d571bc926521e2feecc
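The artifacts page linked above shows the per-run naming scheme (`run_ey9xextf_model`): each run logs its checkpoint into a single artifact of type `model`, which W&B then versions automatically (v0, v1, ...) with the newest upload aliased `latest`. A minimal sketch of that naming convention (the helper function below is mine, not part of the PR):

```python
def model_artifact_name(run_id: str) -> str:
    """Build the per-run W&B artifact name used for checkpoint versioning.

    Each run writes current.pt into one artifact of type "model"; repeated
    uploads become versions v0, v1, ... of that single artifact, with the
    alias "latest" pointing at the newest checkpoint.
    """
    return "run_" + run_id + "_model"


print(model_artifact_name("ey9xextf"))  # run_ey9xextf_model
```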
1 parent b6a5804 commit c9ab349

File tree

5 files changed (+115, −33 lines)


mmf/configs/defaults.yaml

Lines changed: 15 additions & 2 deletions
```diff
@@ -45,11 +45,24 @@ training:
   wandb:
     # Whether to use Weights and Biases Logger, (Default: false)
     enabled: false
+    # An entity is a username or team name where you're sending runs.
+    # This is necessary if you want to log your metrics to a team account. By default
+    # it will log the run to your user account.
+    entity: null
     # Project name to be used while logging the experiment with wandb
-    wandb_projectname: mmf_${oc.env:USER,}
+    project: mmf
     # Experiment/ run name to be used while logging the experiment
     # under the project with wandb
-    wandb_runname: ${training.experiment_name}
+    name: ${training.experiment_name}
+    # You can save your model checkpoints as W&B Artifacts for model versioning.
+    # Set the value to `true` to enable this feature.
+    log_checkpoint: false
+    # Specify other argument values that you want to pass to wandb.init(). Check out the documentation
+    # at https://docs.wandb.ai/ref/python/init to see what arguments are available.
+    # job_type: 'train'
+    # tags: ['tag1', 'tag2']
+

   # Size of the batch globally. If distributed or data_parallel
   # is used, this will be divided equally among GPUs
```

mmf/trainers/callbacks/logistics.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -58,11 +58,10 @@ def __init__(self, config, trainer):
         if env_wandb_logdir:
             log_dir = env_wandb_logdir

-        wandb_projectname = config.training.wandb.wandb_projectname
-        wandb_runname = config.training.wandb.wandb_runname
-
         self.wandb_logger = WandbLogger(
-            name=wandb_runname, save_dir=log_dir, project=wandb_projectname
+            entity=config.training.wandb.entity,
+            config=config,
+            project=config.training.wandb.project,
         )

     def on_train_start(self):
@@ -153,6 +152,7 @@ def on_test_end(self, **kwargs):
             meter=kwargs["meter"],
             should_print=prefix,
             tb_writer=self.tb_writer,
+            wandb_logger=self.wandb_logger,
         )
         logger.info(f"Finished run in {self.total_timer.get_time_since_start()}")
```
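The second hunk threads `wandb_logger` into `summarize_report`, where (see the `mmf/utils/logger.py` diff below) the learning rate is staged with `commit=False` so that it is committed together with the metrics that follow, on the same wandb step. A minimal sketch of that staging behavior, using a stand-in logger so no wandb install is needed (names here are illustrative):

```python
class RecordingLogger:
    """Stand-in for WandbLogger that records log_metrics calls."""

    def __init__(self):
        self.calls = []

    def log_metrics(self, metrics, commit=True):
        self.calls.append((metrics, commit))


def report_lr(extra, wandb_logger):
    # Mirror of the new hook in summarize_report: stage the learning rate
    # without committing a wandb step, so it shares the step of the
    # metrics logged right after it.
    if wandb_logger and "lr" in extra:
        wandb_logger.log_metrics(
            {"train/learning_rate": float(extra["lr"])}, commit=False
        )


wb = RecordingLogger()
report_lr({"lr": 5e-3}, wb)
print(wb.calls)  # [({'train/learning_rate': 0.005}, False)]
```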

mmf/utils/checkpoint.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -522,6 +522,7 @@ def save(self, update, iteration=None, update_best=False):
             best_metric = (
                 self.trainer.early_stop_callback.early_stopping.best_monitored_value
             )
+
         model = self.trainer.model
         data_parallel = registry.get("data_parallel") or registry.get("distributed")
         fp16_scaler = getattr(self.trainer, "scaler", None)
@@ -574,6 +575,15 @@ def save(self, update, iteration=None, update_best=False):
         with open_if_main(current_ckpt_filepath, "wb") as f:
             self.save_func(ckpt, f)

+        # Save the current checkpoint as W&B artifacts for model versioning.
+        if self.config.training.wandb.log_checkpoint:
+            logger.info(
+                "Saving current checkpoint as W&B Artifacts for model versioning"
+            )
+            self.trainer.logistics_callback.wandb_logger.log_model_checkpoint(
+                current_ckpt_filepath
+            )
+
         # Remove old checkpoints if max_to_keep is set
         # In XLA, only delete checkpoint files in main process
         if self.max_to_keep > 0 and is_main():
```
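Since the upload is gated on `training.wandb.log_checkpoint`, every save at `training.checkpoint_interval` either logs `current.pt` as an artifact or is a no-op. A rough sketch of that guard with a stub logger (all names below are illustrative, not MMF's actual API):

```python
class StubWandbLogger:
    """Records checkpoint paths instead of uploading W&B artifacts."""

    def __init__(self):
        self.logged_paths = []

    def log_model_checkpoint(self, model_path):
        self.logged_paths.append(model_path)


def maybe_log_checkpoint(config, wandb_logger, current_ckpt_filepath):
    # Mirror of the new block in Checkpoint.save(): upload the freshly
    # written current.pt only when the feature flag is enabled.
    if config["training"]["wandb"]["log_checkpoint"]:
        wandb_logger.log_model_checkpoint(current_ckpt_filepath)


cfg = {"training": {"wandb": {"log_checkpoint": True}}}
wb = StubWandbLogger()
maybe_log_checkpoint(cfg, wb, "save/current.pt")
print(wb.logged_paths)  # ['save/current.pt']
```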

mmf/utils/logger.py

Lines changed: 40 additions & 14 deletions
```diff
@@ -225,6 +225,12 @@ def summarize_report(
     if not is_main() and not is_xla():
         return

+    # Log the learning rate if available
+    if wandb_logger and "lr" in extra:
+        wandb_logger.log_metrics(
+            {"train/learning_rate": float(extra["lr"])}, commit=False
+        )
+
     if tb_writer:
         scalar_dict = meter.get_scalar_dict()
         tb_writer.add_scalars(scalar_dict, current_iteration)
@@ -395,21 +401,19 @@ class WandbLogger:
     Log using `Weights and Biases`.

     Args:
-        name: Display name for the run.
-        save_dir: Path where data is saved (./save/logs/wandb/ by default).
-        project: Display name for the project.
-        **init_kwargs: Arguments passed to :func:`wandb.init`.
+        entity: An entity is a username or team name where you're sending runs.
+        config: Configuration for the run.
+        project: Name of the W&B project.

     Raises:
         ImportError: If wandb package is not installed.
     """

     def __init__(
         self,
-        name: Optional[str] = None,
-        save_dir: Optional[str] = None,
+        entity: Optional[str] = None,
+        config: Optional[Dict] = None,
         project: Optional[str] = None,
-        **init_kwargs,
     ):
         try:
             import wandb
@@ -420,10 +424,13 @@ def __init__(
             )

         self._wandb = wandb
-
-        self._wandb_init = dict(name=name, project=project, dir=save_dir)
-
-        self._wandb_init.update(**init_kwargs)
+        self._wandb_init = dict(entity=entity, config=config, project=project)
+        wandb_kwargs = dict(config.training.wandb)
+        wandb_kwargs.pop("enabled")
+        wandb_kwargs.pop("entity")
+        wandb_kwargs.pop("project")
+        wandb_kwargs.pop("log_checkpoint")
+        self._wandb_init.update(**wandb_kwargs)

         self.setup()

@@ -453,14 +460,33 @@ def _should_log_wandb(self):
         else:
             return True

-    def log_metrics(self, metrics: Dict[str, float]):
+    def log_metrics(self, metrics: Dict[str, float], commit=True):
         """
         Log the monitored metrics to the wand dashboard.

         Args:
-            metrics (Dict[str, float]): [description]
+            metrics (Dict[str, float]): A dictionary of metrics to log.
+            commit (bool): Save the metrics dict to the wandb server and
+                increment the step. (default: True)
+        """
+        if not self._should_log_wandb():
+            return
+
+        self._wandb.log(metrics, commit=commit)
+
+    def log_model_checkpoint(self, model_path):
+        """
+        Log the model checkpoint to the wandb dashboard.
+
+        Args:
+            model_path (str): Path to the model file.
         """
         if not self._should_log_wandb():
             return

-        self._wandb.log(metrics)
+        model_artifact = self._wandb.Artifact(
+            "run_" + self._wandb.run.id + "_model", type="model"
+        )
+
+        model_artifact.add_file(model_path, name="current.pt")
+        self._wandb.log_artifact(model_artifact, aliases=["latest"])
```
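The constructor change above forwards any extra keys under `training.wandb` straight to `wandb.init()` after stripping the MMF-specific ones. That filtering logic can be sketched with plain dicts, no wandb required (the helper name and dict-based config below are illustrative; in MMF the config is an OmegaConf node, which `dict()` also accepts):

```python
# Keys under training.wandb that are consumed by MMF itself and must not
# be forwarded to wandb.init() (mirrors the pops in the diff above).
MMF_ONLY_KEYS = ("enabled", "entity", "project", "log_checkpoint")


def build_wandb_init_kwargs(wandb_config, full_config=None):
    """Split the training.wandb section into wandb.init() kwargs."""
    init_kwargs = dict(
        entity=wandb_config.get("entity"),
        config=full_config,
        project=wandb_config.get("project"),
    )
    passthrough = dict(wandb_config)
    for key in MMF_ONLY_KEYS:
        passthrough.pop(key, None)
    # Anything left over (name, job_type, tags, ...) goes straight to
    # wandb.init(), which is what makes the "extra arguments" feature work.
    init_kwargs.update(passthrough)
    return init_kwargs


kwargs = build_wandb_init_kwargs(
    {"enabled": True, "entity": None, "project": "mmf",
     "log_checkpoint": False, "name": "my_exp", "tags": ["tag1"]}
)
print(sorted(kwargs))  # ['config', 'entity', 'name', 'project', 'tags']
```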

website/docs/notes/logging.md

Lines changed: 46 additions & 13 deletions
````diff
@@ -1,42 +1,75 @@
 ---
-id: concepts
-title: Terminology and Concepts
-sidebar_label: Terminology and Concepts
+id: logger
+title: Weights and Biases Logging
+sidebar_label: Weights and Biases Logging
 ---

 ## Weights and Biases Logger

-MMF has a `WandbLogger` class which lets the user log their model's progress using [Weights and Biases](https://gitbook-docs.wandb.ai/).
+MMF now has a `WandbLogger` class which lets the user log their model's progress using [Weights and Biases](https://wandb.ai/site). Enable this logger to automatically log the training/validation metrics, system (GPU and CPU) metrics and configuration parameters.
+
+## First time setup

 To set up wandb, run the following:
 ```
 pip install wandb
+```
+In order to log anything to the W&B server you need to authenticate the machine with a W&B **API key**. You can create a new account by going to https://wandb.ai/signup which will generate an API key. If you are an existing user you can retrieve your key from https://wandb.ai/authorize. You only need to supply your key once, and then it is remembered on the same device.
+
+```
 wandb login
 ```

+## W&B config parameters
+
 The following options are available in config to enable and customize the wandb logging:
 ```yaml
 training:
   # Weights and Biases control, by default Weights and Biases (wandb) is disabled
   wandb:
     # Whether to use Weights and Biases Logger, (Default: false)
-    enabled: false
+    enabled: true
+    # An entity is a username or team name where you're sending runs.
+    # This is necessary if you want to log your metrics to a team account. By default
+    # it will log the run to your user account.
+    entity: null
     # Project name to be used while logging the experiment with wandb
-    wandb_projectname: mmf_${oc.env:USER}
+    project: mmf
     # Experiment/ run name to be used while logging the experiment
     # under the project with wandb
-    wandb_runname: ${training.experiment_name}
+    name: ${training.experiment_name}
+    # Specify other argument values that you want to pass to wandb.init(). Check out the documentation
+    # at https://docs.wandb.ai/ref/python/init to see what arguments are available.
+    # job_type: 'train'
+    # tags: ['tag1', 'tag2']
 env:
   wandb_logdir: ${env:MMF_WANDB_LOGDIR,}
 ```
-To enable wandb logger the user needs to change the following option in the config.

-`training.wandb.enabled=True`
+* To enable the wandb logger the user needs to change the following option in the config.
+
+  `training.wandb.enabled=True`
+
+* To set the `entity`, which is the name of the team or the username, the user needs to change the following option in the config. In case no `entity` is provided, the data will be logged to the `entity` set as default in the user's settings.
+
+  `training.wandb.entity=<teamname/username>`
+
+* To give the current experiment a project and run name, the user should add these config options. The default project name is `mmf` and the default run name is `${training.experiment_name}`.
+
+  `training.wandb.project=<ProjectName>` <br />
+  `training.wandb.name=<RunName>`
+
+* To change the path to the directory where wandb metadata would be stored (Default: `env.log_dir`):
+
+  `env.wandb_logdir=<dir_name>`

-To give the current experiment a project and run name, user should add these config options.
+* To provide extra arguments to `wandb.init()`, the user just needs to define them in the config file. Check out the documentation at https://docs.wandb.ai/ref/python/init to see what arguments are available. An example is shown in the config parameters above. Make sure to use the same key name in the config file as defined in the documentation.

-`training.wandb.wandb_projectname=<ProjectName> training.wandb.wandb_runname=<RunName>`
+## Current features

-To change the path to the directory where wandb metadata would be stored (Default: `env.log_dir`):
+The following features are currently supported by the `WandbLogger`:

-`env.wandb_logdir=<dir_name>`
+* Training & Validation metrics
+* Learning Rate over time
+* GPU: Type, GPU Utilization, power, temperature, CUDA memory usage
+* Log configuration parameters
````
