Commit c82aace

ankitade authored and facebook-github-bot committed
upgrade lightning dependency (#1301)
Summary:
- Upgrade lightning to 1.6 to fix installation errors
- Fix some tests since lightning logic to track global step has changed

Pull Request resolved: #1301
Reviewed By: pikapecan
Differential Revision: D44670777
Pulled By: ankitade
fbshipit-source-id: 0febb0683aeb59fbb82cf1f29bbadf1d3385493d
1 parent 02a55a3 commit c82aace

File tree

7 files changed: 14 additions, 33 deletions

.github/workflows/cpu_test.yaml

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ jobs:
         run: |
           conda activate mmf
           python -m pip install --upgrade pip
-          pip install --upgrade setuptools
+          pip install setuptools==65.6.3
           pip install --progress-bar off pytest
           pip install -r requirements.txt
           python -c 'import torch; print("Torch version:", torch.__version__)'
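
Pinning setuptools to an exact release, rather than always upgrading to the latest, keeps the CI environment reproducible; presumably a newer setuptools release was implicated in the installation errors the summary mentions.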

mmf/trainers/lightning_core/loop_callback.py

Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def _get_iterations_for_logging(self, trainer: Trainer):
         return trainer.fit_loop.batch_idx + 1

     def _get_num_updates_for_logging(self, trainer: Trainer):
-        return trainer.global_step + 1
+        return trainer.global_step

     def _train_log(self, trainer: Trainer, pl_module: LightningModule):
         self.train_combined_report = self.train_combined_report.detach()
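
This one-line change encodes the semantic shift behind the test fixes: in Lightning 1.6, trainer.global_step already counts the update that just finished by the time the batch-end hooks run, so the old + 1 compensation (justified at length in the comment deleted from test_validation.py below) would now double-count. A minimal sketch of the before/after behavior, assuming the pre-1.6 off-by-one that comment describes:

from pytorch_lightning import Trainer

def get_num_updates_for_logging(trainer: Trainer) -> int:
    # Pre-1.6: global_step still held the value from *before* the current
    # update when on_train_batch_end fired, so MMF returned
    # trainer.global_step + 1 to report the true count of completed updates.
    # From 1.6 on, global_step already reflects the finished update and can
    # be returned directly.
    return trainer.global_step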

mmf/trainers/lightning_core/loop_callback_with_torchmetrics.py

Lines changed: 1 addition & 1 deletion

@@ -133,7 +133,7 @@ def _get_iterations_for_logging(self, trainer: Trainer):
         return trainer.fit_loop.batch_idx + 1

     def _get_num_updates_for_logging(self, trainer: Trainer):
-        return trainer.global_step + 1
+        return trainer.global_step

     def _get_train_extra_log(self, trainer: Trainer, pl_module: LightningModule):
         extra = {}

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ datasets==1.2.1
 matplotlib==3.3.4
 pycocotools==2.0.2
 ftfy==5.8
-pytorch-lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning@9b011606f
+pytorch-lightning==1.6.0
 psutil
 pillow==9.3.0
 sentencepiece
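
Switching from a pinned git commit to the released 1.6.0 wheel lets pip install a prebuilt package instead of cloning and building from source, which is what the summary credits with fixing the installation errors. For downstream code that depends on the new global_step semantics, a version guard is one way to fail fast on older installs — a hypothetical check, not part of this PR:

import pytorch_lightning as pl
from packaging import version

# Hypothetical guard: the logging callbacks above assume the 1.6
# completed-update semantics of trainer.global_step.
if version.parse(pl.__version__) < version.parse("1.6.0"):
    raise RuntimeError(f"pytorch-lightning>=1.6.0 required, found {pl.__version__}")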

tests/trainers/lightning/test_checkpoint.py

Lines changed: 6 additions & 8 deletions

@@ -197,6 +197,7 @@ def _get_lightning_trainer(


 class TestLightningCheckpoint(TestLightningCheckpoint):
+    @skip_if_no_network
     def test_load_resume_parity_with_mmf(self):
         # with checkpoint.resume = True, by default it loads "current.ckpt"
         self._load_checkpoint_and_test("current.ckpt", ckpt_config={"resume": True})
@@ -208,6 +209,7 @@ def test_load_resume_best_parity_with_mmf(self):
             "best.ckpt", ckpt_config={"resume": True, "resume_best": True}
         )

+    @skip_if_no_network
     def test_load_resume_ignore_resume_zoo(self):
         # specifying both checkpoint.resume = True and resume_zoo
         # resume zoo should be ignored. It should load the "current.ckpt"
@@ -393,7 +395,8 @@ def test_load_trainer_ckpt_number_of_steps(self):
         )
         self.assertEquals(lightning.trainer.global_step, 12)
         call_args_list = [l[0][4] for l in mock_method.call_args_list]
-        self.assertListEqual(list(range(0, 6)), call_args_list)
+        # in lightning 1.6.0 last batch idx from ckpt is repeated
+        self.assertListEqual(list(range(5, 11)), call_args_list)

     def test_trainer_save_current_parity_with_mmf(self):
         with mock_env_with_temp(
@@ -454,7 +457,7 @@ def test_lightning_checkpoint_interval(self):
             files = os.listdir(os.path.join(tmp_d, "models"))
             self.assertEquals(3, len(files))
             indexes = {int(x[:-5].split("=")[1]) for x in files}
-            self.assertSetEqual({1, 3, 5}, indexes)
+            self.assertSetEqual({2, 4, 6}, indexes)

     def _get_mmf_ckpt(self, filename, ckpt_config=None):
         with mock_env_with_temp(
@@ -508,12 +511,7 @@ def _load_checkpoint_and_test(self, filename, ckpt_config=None):

         # Make sure lightning and mmf parity
         self._assert_same_dict(mmf_ckpt["model"], lightning_ckpt["state_dict"])
-
-        # different case for best checkpoint, see D34398730
-        if "resume_best" in ckpt_config and ckpt_config["resume_best"]:
-            self.assertEquals(mmf_ckpt["current_epoch"], lightning_ckpt["epoch"] + 1)
-        else:
-            self.assertEquals(mmf_ckpt["current_epoch"], lightning_ckpt["epoch"])
+        self.assertEquals(mmf_ckpt["current_epoch"], lightning_ckpt["epoch"] + 1)
         self.assertEquals(mmf_ckpt["num_updates"], lightning_ckpt["global_step"])
         self._assert_same_dict(
             mmf_ckpt["optimizer"], lightning_ckpt["optimizer_states"][0]

tests/trainers/lightning/test_logging.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,7 @@
 from mmf.trainers.callbacks.logistics import LogisticsCallback
 from mmf.trainers.lightning_core.loop_callback import LightningLoopCallback
 from mmf.utils.timer import Timer
+from tests.test_utils import skip_if_no_network
 from tests.trainers.test_utils import (
     get_config_with_defaults,
     get_lightning_trainer,
@@ -19,6 +20,7 @@ def setUp(self):
         self.mmf_tensorboard_logs = []
         self.lightning_tensorboard_logs = []

+    @skip_if_no_network
     @patch("mmf.common.test_reporter.PathManager.mkdirs")
     @patch("mmf.trainers.callbacks.logistics.setup_output_folder", return_value="logs")
     @patch("mmf.trainers.lightning_trainer.setup_output_folder", return_value="logs")

tests/trainers/lightning/test_validation.py

Lines changed: 2 additions & 21 deletions

@@ -92,21 +92,7 @@ def log_values(
         keys = list(gt.keys())
         self.assertListEqual(keys, list(lv.keys()))
         for key in keys:
-            if key == "num_updates" and gt[key] == self.ground_truths[-1][key]:
-                # After training, in the last evaluation run, mmf's num updates is 8
-                # while lightning's num updates is 9, this is due to a hack to
-                # assign the lightning num_updates to be the trainer.global_step+1.
-                #
-                # This is necessary because of a lightning bug: trainer.global_step
-                # is 1 off less than the actual step count. When on_train_batch_end
-                # is called for the first time, the trainer.global_step should be 1,
-                # rather than 0, since 1 update/step has already been done.
-                #
-                # When lightning fixes its bug, we will update this test to remove
-                # the hack. # issue: 6997 in pytorch lightning
-                self.assertAlmostEqual(gt[key], lv[key] - 1, 1)
-            else:
-                self.assertAlmostEqual(gt[key], lv[key], 1)
+            self.assertAlmostEqual(gt[key], lv[key], 1)

     # TODO: update test function with avg_loss
     @patch("mmf.common.test_reporter.PathManager.mkdirs")
@@ -145,12 +131,7 @@ def log_values(
         self.assertEqual(len(self.ground_truths), len(lightning_values))
         for gt, lv in zip(self.ground_truths, lightning_values):
             for key in ["num_updates", "max_updates"]:
-                if key == "num_updates" and gt[key] == self.ground_truths[-1][key]:
-                    # to understand the reason of using lv[key] - 1 (intead of lv[key])
-                    # see comments in test_validation
-                    self.assertAlmostEqual(gt[key], lv[key] - 1, 1)
-                else:
-                    self.assertAlmostEqual(gt[key], lv[key], 1)
+                self.assertAlmostEqual(gt[key], lv[key], 1)

     @patch("mmf.common.test_reporter.PathManager.mkdirs")
     @patch("torch.utils.tensorboard.SummaryWriter")
