pytorch-lightning: ModelCheckpointCallback doesn't delete previous -last.ckpt while saving recent one

šŸ› Bug

The ModelCheckpoint callback doesn't delete the previous -last.ckpt when saving a new one. This is a snapshot of the checkpoint directory for save_top_k=1 and save_last=True:

test_last_ckpt_issue-val_loss=1.8490-epoch=13.ckpt
test_last_ckpt_issue-val_loss=1.8490-epoch=13-last.ckpt
test_last_ckpt_issue-val_loss=6.0828-epoch=6-last.ckpt

As can be seen, before saving the epoch=7 -last.ckpt it didn't delete the previous *epoch=6-last.ckpt.

To Reproduce


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.strategies import DDPStrategy


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)


def build_dataset():
    # data
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32, num_workers=20)
    val_loader = DataLoader(mnist_val, batch_size=32, num_workers=20)

    return train_loader, val_loader


if __name__ == '__main__':
    train_loader, val_loader = build_dataset()

    # model
    model = LitAutoEncoder()

    # callback: keep the top-k checkpoints by val_loss plus a "last" checkpoint
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./test_last_ckpt',
        filename='test_last_ckpt-{val_loss:.2f}-{epoch}',
        save_last=True,
        save_top_k=4,
    )
    # give the "last" checkpoint a dynamic name instead of the default "last"
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'test_last_ckpt-{val_loss:.2f}-{epoch}' + '-last'

    # training
    trainer = pl.Trainer(
        gpus=-1, num_nodes=1, max_epochs=5, max_steps=-1,
        strategy='ddp',
        logger=False,
        benchmark=True,
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model, train_loader, val_loader)

Expected behavior

It should keep only the latest -last.ckpt, deleting all previous *-last.ckpt files.
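
A quick way to state that expectation as a check (illustrative only; the directory path matches the reproduction script above):

import glob

# only the most recent -last checkpoint should remain after training
last_ckpts = glob.glob('./test_last_ckpt/*-last.ckpt')
assert len(last_ckpts) == 1, f'expected a single -last checkpoint, found {last_ckpts}'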

Environment

  • CUDA:
    - GPU: Quadro GV100
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.20.3
    - pyTorch_debug: False
    - pyTorch_version: 1.10.1+cu102
    - pytorch-lightning: 1.5.8
    - tqdm: 4.62.3
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.9.7
    - version: #45~20.04.1-Ubuntu SMP Wed Nov 10 10:20:10 UTC 2021

cc @tchaton @rohitgr7 @awaelchli @carmocca @ninginthecloud @jjenniferdai

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 21 (21 by maintainers)

Most upvoted comments

Thanks! This was helpful. We can use a temporary variable.

I updated your PR and added a test.

Adapting the script you provided:

import os
import shutil

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

class ModifiedModel(BoringModel):
    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        # simulate an interruption partway through training
        if self.global_step == 80:
            raise SystemExit(1)
        return loss

def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = ModifiedModel()

    # remove data
    dir = f"{os.getcwd()}/debug"
    shutil.rmtree(dir, ignore_errors=True)

    # configure model checkpoint
    model_checkpoint = ModelCheckpoint(
        dirpath=dir, filename="{epoch}", monitor="valid_loss", save_last=True, save_top_k=2
    )
    model_checkpoint.CHECKPOINT_NAME_LAST = "{epoch}-last"

    # first run
    trainer = Trainer(
        max_epochs=6, callbacks=model_checkpoint, enable_model_summary=False, enable_progress_bar=False, logger=False
    )
    try:
        trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    except SystemExit:
        print("system exit caught", os.listdir(dir))

    actual = set(os.listdir(dir))
    assert actual == {
        "epoch=1.ckpt",
        "epoch=0.ckpt",
        "epoch=1-last.ckpt",
    }

    model = BoringModel()
    # second run
    trainer = Trainer(
        max_epochs=6, callbacks=model_checkpoint, enable_model_summary=False, enable_progress_bar=False, logger=False
    )

    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data, ckpt_path=f"{dir}/epoch=1-last.ckpt")

    actual = set(os.listdir(dir))
    assert actual == {
        "epoch=4.ckpt",
        "epoch=5.ckpt",
        "epoch=5-last.ckpt",
    }


if __name__ == "__main__":
    run()

This gets fixed with the PR I provided, but it has the problem that a checkpoint could be lost if the interruption happens at that exact moment. One option is to copy last_model_path to a temporary variable, save the new checkpoint first, and only then delete the previous one.
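
A minimal sketch of that ordering (not the actual ModelCheckpoint internals, which go through the callback's own save/remove helpers and support remote filesystems; trainer.save_checkpoint is the public Trainer API, the rest of the names here are illustrative):

import os

def save_last(trainer, new_last_path, previous_last_path):
    # save the new -last checkpoint first, so an interruption at this point
    # never leaves us without a usable last checkpoint
    trainer.save_checkpoint(new_last_path)

    # only after the new file exists, delete the stale one
    if previous_last_path and previous_last_path != new_last_path and os.path.exists(previous_last_path):
        os.remove(previous_last_path)

    return new_last_path  # becomes the callback's new last_model_path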

@nithinraok When you run your script for the second time, you are not passing an existing checkpoint to the trainer to resume: .fit(..., ckpt_path="your_checkpoint.ckpt")

This means that the ModelCheckpoint state does not get restored, so the last_model_path from the previous run will not get deleted.

Considering that, training gets run again from the beginning and new checkpoints are generated:

  • In the case of normal checkpoints, a version suffix gets appended (*-v1.ckpt).
  • For last checkpoints, they get overwritten. The epoch=4-last.ckpt you see after the second run is the last checkpoint of the second run which overwrote the last checkpoint of the first run.

There's a feature request to add a version suffix to "last" checkpoints: #5030

So from my debugging, this is working as expected.
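
For completeness, a minimal sketch of resuming the second run with ckpt_path so the ModelCheckpoint state (and therefore last_model_path) is restored; it reuses BoringModel and RandomDataset from the script above, and the checkpoint path is illustrative:

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()

model_checkpoint = ModelCheckpoint(
    dirpath="./debug", filename="{epoch}", monitor="valid_loss", save_last=True, save_top_k=2
)
model_checkpoint.CHECKPOINT_NAME_LAST = "{epoch}-last"

trainer = Trainer(max_epochs=6, callbacks=model_checkpoint, logger=False)

# ckpt_path restores the ModelCheckpoint state, so the -last checkpoint from
# the interrupted run is tracked and replaced instead of being left behind
trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data,
            ckpt_path="./debug/epoch=1-last.ckpt")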