pytorch-lightning: Freeze on restore from checkpoint

🐛 Bug

Sometimes restoring from a checkpoint sends the trainer into an infinite loop: 100% GPU usage and no progress.
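A hang with 100% GPU usage and no progress is easier to diagnose when you can see which Python line each process is stuck on. The following is a minimal sketch that uses the standard-library faulthandler module. The helper name enable_hang_watchdog, and the idea of calling it from the repro script, are assumptions and not part of the original report or of PyTorch Lightning.

```python
import faulthandler
import sys
from typing import Callable


def enable_hang_watchdog(timeout_s: float) -> Callable[[], None]:
    """Periodically dump every thread's Python stack (hypothetical helper).

    Returns a function that cancels the watchdog, so a run that finishes
    normally stops printing tracebacks.
    """
    # sys.__stderr__ is the interpreter's original stderr, which has a real
    # file descriptor (faulthandler cannot write to in-memory streams).
    # repeat=True: dump again every timeout_s seconds until cancelled.
    # exit=False: only print the stacks, never terminate the process.
    faulthandler.dump_traceback_later(
        timeout_s, repeat=True, file=sys.__stderr__, exit=False
    )
    return faulthandler.cancel_dump_traceback_later
```

For example, call cancel = enable_hang_watchdog(300) before trainer.fit(...) and cancel() afterwards. With accelerator="ddp", Lightning may start the other ranks as separate processes, so each process needs its own call; placing the call at module level in the script is one way to cover them.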

Please reproduce using the BoringModel

I can’t reproduce this in Colab because it doesn’t provide 2 GPUs, but here is the Python code:

"""
# The Boring Model
Replicate a bug you experience, using this model.

[Remember! we're always available for support on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A)

---
## Setup env
"""

"""---
## Deps
"""
from pathlib import Path
from shutil import rmtree
from uuid import uuid4
import os

import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin

tmpdir = os.getcwd()

"""---
## Data
Random data is best for debugging. If you need special tensor shapes, batch compositions or dataloaders, modify as needed.
"""

# Minimal random dataset, defined locally (no pl_bolts dependency needed)

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

num_samples = 10000

train = RandomDataset(32, num_samples)
train = DataLoader(train, batch_size=32)

val = RandomDataset(32, num_samples)
val = DataLoader(val, batch_size=32)

test = RandomDataset(32, num_samples)
test = DataLoader(test, batch_size=32)


from pytorch_lightning import LightningModule

class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log("loss", loss)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        self.log('fake_test_acc', loss)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

"""---
## Define the test
NOTE: in colab, set progress_bar_refresh_rate high or the screen will freeze because of the rapid tqdm update speed.
"""
ckpt_dir = Path("/tmp") / str(uuid4())
def _create_trainer(remove_checkpoint=True, **kwargs):
    checkpoint_callback = ModelCheckpoint(
        monitor="loss",
        mode="min",
        filename="{step:05d}",
        every_n_train_steps=30,
        save_top_k=1,
        save_last=True,
    )
    if remove_checkpoint:
        rmtree(ckpt_dir, ignore_errors=True)
    last_checkpoint_path = ckpt_dir / "lightning_logs" / "version_0" / "checkpoints" / "last.ckpt"
    return pl.Trainer(
        default_root_dir=ckpt_dir,
        gpus=2,
        accelerator="ddp",
        auto_select_gpus=True,
        max_epochs=1,
        progress_bar_refresh_rate=20,
        sync_batchnorm=True,
        plugins=DDPPlugin(find_unused_parameters=False),
        resume_from_checkpoint=last_checkpoint_path if last_checkpoint_path.exists() else None,
        callbacks=[checkpoint_callback],
        **kwargs
    )

def test_normal():
    # init model
    model = BoringModel()

    # Initialize a trainer
    trainer = _create_trainer(precision=32)

    # Train the model ⚡
    trainer.fit(model, train, val)

    trainer.test(test_dataloaders=test)

def test_another():
    model = BoringModel()

    trainer = _create_trainer(precision=32, remove_checkpoint=False) # If remove_checkpoint is True, it works fine

    # Train the model ⚡
    trainer.fit(model, train, val)
    trainer.test(test_dataloaders=test)

def test_fp16():
    model = BoringModel()
    trainer16 = _create_trainer(precision=16)
    trainer16.fit(model, train, val)

test_normal()
test_another()
test_fp16()

See the comment in test_another to toggle the bug: with remove_checkpoint=True the run completes normally. Also, each test function finishes normally when run on its own.
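Because each test function finishes when run on its own, it can help to let the script run one named test per interpreter. This is a minimal sketch, assuming the three test functions above; run_selected and the test names are hypothetical.

```python
import sys
from typing import Callable, Dict, Sequence


def run_selected(argv: Sequence[str], tests: Dict[str, Callable[[], None]]) -> int:
    """Run only the tests named in argv (all of them if argv is empty).

    Returns 0 on success and 2 if an unknown test name was given.
    """
    names = list(argv) or list(tests)
    unknown = [name for name in names if name not in tests]
    if unknown:
        print(
            f"unknown test(s): {', '.join(unknown)}; choose from {sorted(tests)}",
            file=sys.stderr,
        )
        return 2
    for name in names:
        tests[name]()  # exceptions propagate, so a failing test still fails loudly
    return 0
```

The last three lines of the script could then become sys.exit(run_selected(sys.argv[1:], {"normal": test_normal, "another": test_another, "fp16": test_fp16})), invoked for example as python repro.py another (repro.py is a placeholder file name).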

To Reproduce

See above

Expected behavior

Training should either run to completion or fail with an error, rather than hang.
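That expectation can also be enforced from outside: running the repro in a child interpreter with a time limit turns a silent hang into a detectable failure. A minimal sketch based on the timeout parameter of subprocess.run; run_with_timeout is a hypothetical helper name.

```python
import subprocess
import sys
from typing import Optional, Sequence


def run_with_timeout(argv: Sequence[str], timeout_s: float) -> Optional[int]:
    """Run `python *argv` in a fresh interpreter.

    Returns the child's exit code, or None if it was still running after
    timeout_s seconds (i.e. it looked hung) and had to be killed.
    """
    try:
        completed = subprocess.run([sys.executable, *argv], timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # subprocess.run has already killed and reaped the direct child here.
        return None
    return completed.returncode
```

For example, run_with_timeout(["repro.py", "another"], timeout_s=1800) returns None if that run hangs. On timeout, subprocess.run kills only the direct child, so any DDP worker processes that child started may need to be cleaned up separately.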

Environment


  • IDE: None, terminal, python from conda.

  • Colab Notebook: N/A

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
  • PyTorch Version (e.g., 1.0): 1.8.1+cu102
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.8.10
  • CUDA/cuDNN version: Not sure, since I don’t think I installed it; nvidia-smi reports 11.2
  • GPU models and configuration: 2 * 2080ti
  • Any other relevant information:
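If collect_env_details.py is not at hand, a standard-library-only sketch can record the installed package versions, including the PyTorch Lightning version, which matters most for a Trainer bug. The helper name package_versions is hypothetical, and the sketch assumes the packages expose standard install metadata (as pip installs do) so that importlib.metadata can find them.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable


def package_versions(names: Iterable[str]) -> Dict[str, str]:
    """Map each distribution name to its installed version, or 'not installed'."""
    report = {"python": platform.python_version()}
    for name in names:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    # Distribution names as published on PyPI (underscore form for Lightning).
    for key, value in package_versions(["torch", "pytorch_lightning", "torchvision"]).items():
        print(f"{key}: {value}")
```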

Additional context

I found this while writing tests with pytest, but I was able to reproduce it without pytest.

P.S. The Trainer docstring says that if the file at resume_from_checkpoint doesn’t exist, training will start from scratch, but in practice it raises an error. Either the docstring or the code should be corrected.
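Until the docstring or the behavior is corrected, one option is to pass resume_from_checkpoint only a path that exists, as _create_trainer already does for version_0. A minimal sketch that generalizes that guard, assuming the default lightning_logs/version_N/checkpoints/last.ckpt layout that _create_trainer relies on; find_last_checkpoint is a hypothetical helper.

```python
from pathlib import Path
from typing import Optional


def find_last_checkpoint(root: Path) -> Optional[Path]:
    """Return the newest root/lightning_logs/version_N/checkpoints/last.ckpt, or None."""
    candidates = []
    for ckpt in root.glob("lightning_logs/version_*/checkpoints/last.ckpt"):
        suffix = ckpt.parent.parent.name[len("version_"):]
        if suffix.isdigit():
            candidates.append((int(suffix), ckpt))
    if not candidates:
        return None  # nothing to resume from: start from scratch
    # Compare version numbers numerically, so version_10 beats version_9.
    return max(candidates)[1]
```

It could then replace the hard-coded path, e.g. resume_from_checkpoint=find_last_checkpoint(ckpt_dir).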

P.P.S. The Slack link in the Colab notebook also doesn’t work.

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 21 (4 by maintainers)
