pytorch-lightning: [bug] Resuming From Checkpoint for FP16 failure (Single GPU)
🐛 Bug
Please reproduce using the BoringModel
from typing import Dict, Any
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch.utils.data import DataLoader
import logging
import os
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from torch.optim import AdamW
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


class ToyTask(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()

    def setup(self, stage: str):
        if stage == "test":
            return
        self.model = ToyModel()
        self.optimizer = AdamW(
            self.model.parameters(), lr=0.001, betas=[0.9, 0.999],
            eps=1.0e-08, weight_decay=0, amsgrad=False,
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        targets = self.forward(batch["model_input"])
        loss = self.loss_fn(targets, batch["label"])
        # Log loss results per train step and per epoch
        self.log("loss", loss)
        # Tell Lightning to minimize loss
        return loss

    def configure_optimizers(self):
        return self.optimizer

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.setup("fit")
# setup training
task = ToyTask()
dataset = [
    {"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)
]
train_dataloader = DataLoader(dataset, batch_size=None)
val_dataloader = DataLoader(dataset, batch_size=None)
model_checkpoint = ModelCheckpoint(
    save_last=True,
    every_n_val_epochs=1,
)
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=3,
    progress_bar_refresh_rate=100,
    log_gpu_memory=None,
    reload_dataloaders_every_epoch=True,
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    callbacks=[model_checkpoint],
)
results = trainer.fit(task, train_dataloader)
# resume from checkpoint
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=4,
    reload_dataloaders_every_epoch=True,
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    callbacks=[model_checkpoint],
    resume_from_checkpoint=model_checkpoint.last_model_path,
)
trainer.fit(task, train_dataloader)  # <--- this is where it fails
It breaks at the first training step (https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/native_amp.py#L96) and complains:
/mnt/xarfuse/uid-26337/a9b1f2c7-seed-nspid4026533638_cgpid18793775-ns-4026533627/torch/cuda/amp/grad_scaler.py in step(self, optimizer, *args, **kwargs)
334 self.unscale_(optimizer)
335
--> 336 assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
337
338 retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
AssertionError: No inf checks were recorded for this optimizer.
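For context, this GradScaler assertion fires when scaler.step(optimizer) is reached but none of that optimizer's parameters carry gradients that went through the scaled backward/unscale path. Below is a minimal plain-PyTorch sketch of that condition (it is not necessarily the exact root cause inside Lightning; it only illustrates what the assert guards against, and it needs a GPU to run):

import torch
import torch.nn as nn

device = "cuda"
scaler = torch.cuda.amp.GradScaler()
model = nn.Linear(10, 5).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

with torch.cuda.amp.autocast():
    loss = model(torch.randn(4, 10, device=device)).sum()
scaler.scale(loss).backward()

# Simulate a model/optimizer that were rebuilt after the precision setup:
# the new optimizer's parameters have no gradients, so the scaler records
# no inf checks for it and step() trips the assertion shown above.
fresh_model = nn.Linear(10, 5).to(device)
fresh_optimizer = torch.optim.AdamW(fresh_model.parameters(), lr=1e-3)
scaler.step(fresh_optimizer)  # AssertionError: No inf checks were recorded for this optimizer.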
Expected behavior
Expected to resume training.
Environment
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Reactions: 2
- Comments: 28 (23 by maintainers)
I’m working on #7652 to enable that. This way, you don’t have to call setup() manually in on_load_checkpoint().

No, I don’t think so. The plugins already have many responsibilities. I believe we should aim to find a way to restore model and trainer state that fits all plugins well.
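For illustration only, assuming #7652 lands as described above, the on_load_checkpoint workaround in the repro could be dropped and the task would just build its modules in setup() (a sketch under that assumption, not the merged API):

class ToyTask(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()

    def setup(self, stage: str):
        if stage == "test":
            return
        # Built here so the un-setup module stays picklable; with the restore
        # order from #7652 the checkpoint weights would be loaded after this hook.
        self.model = ToyModel()

    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        loss = self.loss_fn(self.model(batch["model_input"]), batch["label"])
        self.log("loss", loss)
        return loss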
Thanks for testing this!
No, by all standards this is not a good approach, for sure. This workaround will hopefully unblock you, but it means we need to discuss if and how we want to move the restore call to an earlier point in the fit call. This is complicated, because some changes were recently introduced that let the training type plugin call the model hooks for restoring … so care needs to be taken when splitting this up.
Yeah, just saw the documentation shows this use case as well. It seems to work.
Thanks @awaelchli
@hudeven I implemented the changes here #7652. In summary, the restoring will happen like this:
In Trainer.fit:
Thanks!!!
good catch! yes then it should restore after pre_dispatch!
Yes, in theory we can have the training plugin decide to restore before or after (nice idea!). We would however sacrifice a consistent hook call order (#7740), so it depends on whether we are ok with making an exception here. Well, there would be no way around it if we want to allow shifting the layer instantiation.
Wanted to give an update.
In #7652 I’m loading the model weights after the setup hook is called, but before the accelerator setup; the optimizer state then gets restored after the accelerator setup.
Does that make sense?
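Roughly, the order being described is the following (helper names here are illustrative only, not actual Trainer internals):

# Sketch of the restore order proposed in #7652; function names are made up.
def fit(task, ckpt_path):
    task.setup("fit")                        # user hook: build submodules / optimizers
    restore_model_weights(task, ckpt_path)   # after setup, before accelerator setup
    accelerator_setup(task)                  # precision plugin, device placement, optimizers
    restore_training_state(task, ckpt_path)  # optimizer / GradScaler / loop state, after accelerator setup
    run_training_loop(task)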
Q: should call_configure_sharded_model happen before or after the model weights get restored? Based on the comments from @shuyingsunshine21 on this PR it sounds like we want to restore weights before that hook.

What Lightning version are you using here, master?
To simulate that, I tried the following workaround and it seems to work.
Can you check if that works for you too?
In our trainer when calling fit, the sequence is the following:
If you rebuild your model in on_load_checkpoint, the precision setup is skipped/undone.

One solution from our side could be to move restore_weights() to directly after step 2. I’m hesitating, however, because I’m not sure yet whether this could have undesired side effects. Any thoughts?
Sorry, just edited my comment: “During loading from checkpoint, it requires the model object to be created before loading weights. So we call setup(“fit”) to create the model in on_load_checkpoint().”
@awaelchli @tchaton thanks for looking into it! We have to init the model in setup(“fit”) for two reasons:
1. Creating the model in LightningModule.__init__() would make the LightningModule not picklable, and our internal system requires the LightningModule to be pickle-able, so we can’t build it in __init__().
2. During loading from checkpoint, the model object must be created before loading weights, so we call setup(“fit”) to create the model in on_load_checkpoint().
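To illustrate point 2, loading a Lightning checkpoint’s state_dict needs the submodules to exist first; a rough sketch using the ToyTask above (the checkpoint path is illustrative):

import torch

task = ToyTask()                                    # no self.model yet
ckpt = torch.load("last.ckpt", map_location="cpu")  # assumed checkpoint path
task.setup("fit")                                   # creates self.model / self.optimizer
task.load_state_dict(ckpt["state_dict"])            # keys like "model.net1.weight" now have a target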