pytorch-lightning: [bug] Resuming from checkpoint fails for FP16 (single GPU)

🐛 Bug

Please reproduce using the BoringModel

from typing import Any, Dict

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

class ToyTask(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.loss_fn = nn.MSELoss()
    
    def setup(self, stage: str):
        if stage == "test":
            return
    
        self.model = ToyModel()
        self.optimizer = AdamW(self.model.parameters(), lr=0.001, betas=[0.9, 0.999], eps=1.0e-08, weight_decay=0, amsgrad=False)

    def forward(self, x):
        return self.model(x)
        
    def training_step(self, batch, batch_idx):
        targets = self.forward(batch["model_input"])
        loss = self.loss_fn(targets, batch["label"]) 

        # Log loss results per train step and per epoch
        self.log("loss", loss)

        # Tell Lightning to minimize loss
        return loss
            
    def configure_optimizers(self):
        return self.optimizer
    
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.setup("fit")

Set up training

task = ToyTask()

dataset = [
    {"model_input": torch.randn(20, 10), "label": torch.randn(20, 5)} for _ in range(10)
]

train_dataloader = DataLoader(dataset, batch_size=None)
val_dataloader = DataLoader(dataset, batch_size=None)

    
model_checkpoint = ModelCheckpoint(
   save_last=True,
   every_n_val_epochs=1,
)

trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=3,
    progress_bar_refresh_rate=100,
    log_gpu_memory=None,
    reload_dataloaders_every_epoch=True,
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    callbacks=[model_checkpoint],
)

results = trainer.fit(task, train_dataloader)


Resume from checkpoint

trainer = pl.Trainer(
    gpus=1,
    precision=16,
    max_epochs=4,
    reload_dataloaders_every_epoch=True,
    limit_train_batches=10,
    limit_val_batches=10,
    limit_test_batches=10,
    callbacks=[model_checkpoint],
    resume_from_checkpoint=model_checkpoint.last_model_path,
)
trainer.fit(task, train_dataloader)  # <--- this is where it will fail

It breaks at the first training step, in https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/native_amp.py#L96, with the following error:

/mnt/xarfuse/uid-26337/a9b1f2c7-seed-nspid4026533638_cgpid18793775-ns-4026533627/torch/cuda/amp/grad_scaler.py in step(self, optimizer, *args, **kwargs)
    334             self.unscale_(optimizer)
    335 
--> 336         assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
    337 
    338         retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)

AssertionError: No inf checks were recorded for this optimizer.
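
For context, the assertion lives in torch.cuda.amp.GradScaler.step(): before stepping, the scaler unscales the optimizer's gradients and records an "inf check" per device, and it asserts that at least one such check exists. That only happens if the optimizer's parameters actually carry gradients from a scaler.scale(loss).backward() call. Below is a minimal, Lightning-free sketch (illustrative only, requires a CUDA GPU) that triggers the same error by stepping an optimizer whose parameters never received gradients, which is analogous to rebuilding the model and optimizer inside on_load_checkpoint:

import torch
import torch.nn as nn

model_a = nn.Linear(10, 5).cuda()   # parameters that actually receive gradients
model_b = nn.Linear(10, 5).cuda()   # a freshly rebuilt copy, as in on_load_checkpoint

# the optimizer still points at model_b's parameters, but the forward/backward
# pass runs on model_a
optimizer = torch.optim.AdamW(model_b.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model_a(torch.randn(4, 10, device="cuda")).sum()
scaler.scale(loss).backward()

# model_b's parameters have no gradients, so unscale_ records no inf checks and
# scaler.step() raises: AssertionError: No inf checks were recorded for this optimizer.
scaler.step(optimizer)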

Expected behavior

Expected to resume training.


About this issue

  • Original URL
  • State: closed
  • Created 3 years ago
  • Reactions: 2
  • Comments: 28 (23 by maintainers)

Most upvoted comments

  1. split loading states into more granular stages (model states, optimizer states, etc.; they do not necessarily get restored at the same time)

and

But there is a use case where we would like to restore the model weights directly after setup and then call configure_sharded_model; this is controlled by the specific plugin. For example, for the ongoing FSDP work, we would like to load the model states into the unwrapped model, which is later wrapped in FSDP at the configure-sharded-model stage. The current restore flow does not allow us to do this.

I’m working on #7652 to enable that. This way, you don’t have to call setup() manually in on_load_checkpoint()

  2. should the plugin control more of the loading logic?

No, I don’t think so. The plugins already have many responsibilities. I believe we should aim to find a way to restore model and trainer state that fits all plugins well.

Thanks for testing this!

No, by all standards this is certainly not a good approach. This workaround will hopefully unblock you, but it means we need to discuss if and how we want to move the restore call to an earlier point in the fit call. This is complicated, because some changes were recently introduced that let the training type plugin call the model hooks for restoring … so care needs to be taken when splitting this up.

@dave-epstein the model can be built in the setup hook. The order is the following:

  1. init
  2. setup hook
  3. restore weights
  4. configure sharded hook
  5. train

This way the weights can be loaded before the model gets wrapped. Does this help?
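
For illustration, a minimal sketch of that order from the LightningModule side (this assumes the configure_sharded_model hook available in recent Lightning versions; the class and layer choices are placeholders, not code from the issue; forward/training_step omitted for brevity):

import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import AdamW

class DeferredBuildTask(pl.LightningModule):
    def setup(self, stage: str):
        # 2. build the unwrapped model here instead of in __init__
        if stage == "fit" and not hasattr(self, "model"):
            self.model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5))

    # 3. in the proposed flow, the trainer restores the checkpointed weights
    #    between setup and configure_sharded_model

    def configure_sharded_model(self):
        # 4. wrapping/sharding (e.g. FSDP) happens only after the weights are loaded
        pass

    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=1e-3)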

Yeah, just saw the documentation shows this use case as well. It seems to work.

@hudeven I implemented the changes here #7652. In summary, the restoring will happen like this:

In Trainer.fit:

model.setup("fit")

# restore model weights
checkpoint_connector.restore_model()

model.configure_sharded_model()
...
accelerator.setup()
...
pre_dispatch()

# restore optimizers, loop, etc.
checkpoint_connector.restore_trainer_state()

Thanks!!!

I think this might need to be postponed until after pre-dispatch (in case the optimizer is set up in the pre-dispatch stage).

Good catch! Yes, then it should restore after pre_dispatch.

I think this is dependent on TrainingTypePlugin (whether the model is instantiated before configure_sharded_model or on configure_sharded_model).

Yes, in theory we can have the training plugin decide to restore before or after (nice idea!). However, we would sacrifice a consistent hook call order (#7740), so it depends on whether we are okay with making an exception here. Then again, there would be no way around it if we want to allow shifting the layer instantiation.

Wanted to give an update.

In #7652 I’m loading the model weights in this order:

model.setup("fit")  # trainer calls setup hook

# model weights get restored as soon as model is setup
restore_model() # also calls model.on_load_checkpoint()

call_configure_sharded_model(model)
accelerator.setup(model) 

restore_training_state()  # restore optimizer, precision, loop progress etc.

So the model weights are restored after the setup hook is called but before the accelerator setup, and the optimizer gets restored after the accelerator setup.

Does that make sense?

Q: should call_configure_sharded_model happen before or after model weights get restored? Based on the comments from @shuyingsunshine21 on this PR it sounds like we want to restore weights before that hook.

What Lightning version are you using here, master?

It might not be ideal, but it is possible to re-do the setup after the LightningModule has restored its weights; we could call that re-setup here, before loading the rest of the pieces.

To simulate that, I tried the following workaround and it seems to work.

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        self.setup("fit")
        self.trainer.accelerator.setup(self.trainer, self)

Can you check if that works for you too?
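
For reference, a sketch of that workaround dropped into the ToyTask from the reproduction above (it relies on the trainer.accelerator.setup(trainer, model) signature used in the snippet, which is version-dependent):

from typing import Any, Dict

class ToyTaskWithWorkaround(ToyTask):
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # rebuild the model and optimizer ...
        self.setup("fit")
        # ... then re-run the accelerator setup so the precision plugin (AMP)
        # is reconnected to the freshly created parameters and optimizer
        self.trainer.accelerator.setup(self.trainer, self)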

In our trainer when calling fit, the sequence is the following:

  1. init ddp connection
  2. setup() hook called -> you build your model here
  3. model gets wrapped with DDP
  4. precision plugin sets up model and optimizer (amp conversion)
  5. restore_model() -> calls on_load_checkpoint -> You rebuild model once again and erase weights
  6. weights get restored

If you rebuild your model in on_load_checkpoint, the precision setup is skipped/undone.

One solution from our side could be to move restore_weights() to directly after step 2. I'm hesitating, however, because I'm not sure yet whether this could have undesired side effects. Any thoughts?

Sorry, just edited my comment: “When loading from a checkpoint, the model object must be created before loading the weights, so we call setup(“fit”) to create the model in on_load_checkpoint().”

@awaelchli @tchaton thanks for looking into it! We have to init the model in setup(“fit”) because:

  1. Our model depends on the transform output (e.g. reading the training data and passing num_classes as the model's out_dim), and the transform is not picklable. Putting it in LightningModule.__init__() would make the LightningModule unpicklable, and our internal system requires the LightningModule to be picklable.
  2. Some pre-trained model weights are very large and need to be downloaded from remote sources, so we want to defer the model loading instead of doing it in __init__().

When loading from a checkpoint, the model object must be created before the weights can be loaded, so we call setup(“fit”) to create the model in on_load_checkpoint().
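
To illustrate the deferred-initialization pattern being described (not code from the issue): only picklable configuration lives in __init__, and the unpicklable transform plus the model construction are deferred to setup("fit"). The build_transform helper, the hard-coded dimensions, and the class name are stand-ins:

import pytorch_lightning as pl
import torch.nn as nn
from torch.optim import AdamW

def build_transform(data_path: str):
    # stand-in for an unpicklable transform that inspects the training data
    # and exposes e.g. num_classes
    class Transform:
        num_classes = 5
    return Transform()

class DeferredInitTask(pl.LightningModule):
    def __init__(self, data_path: str = "/tmp/data"):
        super().__init__()
        # only lightweight, picklable configuration is stored in __init__
        self.data_path = data_path
        self.loss_fn = nn.MSELoss()

    def setup(self, stage: str):
        if stage != "fit":
            return
        # heavy or unpicklable work is deferred: build the transform, derive the
        # output dimension, and only now construct the model (downloading large
        # pretrained weights would also happen here)
        transform = build_transform(self.data_path)
        self.model = nn.Sequential(
            nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, transform.num_classes)
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self.loss_fn(self(batch["model_input"]), batch["label"])
        self.log("loss", loss)
        return loss

    def configure_optimizers(self):
        return AdamW(self.model.parameters(), lr=1e-3)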