pytorch-lightning: ModelCheckpoint callback doesn't delete the previous -last.ckpt while saving the most recent one
🐛 Bug
The ModelCheckpoint callback doesn't delete the previous -last.ckpt when saving the most recent one. This is a snapshot of the checkpoints directory for save_top_k=1 and save_last=True:
test_last_ckpt_issue-val_loss=1.8490-epoch=13.ckpt
test_last_ckpt_issue-val_loss=1.8490-epoch=13-last.ckpt
test_last_ckpt_issue-val_loss=6.0828-epoch=6-last.ckpt
As can be seen, the stale epoch=6-last.ckpt was not deleted when the newer -last.ckpt was saved.
To Reproduce
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
# from pytorch_lightning.strategies import DDPStrategy


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)


def build_dataset():
    # data
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32, num_workers=20)
    val_loader = DataLoader(mnist_val, batch_size=32, num_workers=20)
    return train_loader, val_loader


if __name__ == '__main__':
    train_loader, val_loader = build_dataset()

    # model
    model = LitAutoEncoder()

    # callback: keep the top-k checkpoints by val_loss and also save a -last.ckpt
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./test_last_ckpt',
        filename='test_last_ckpt-{val_loss:.2f}-{epoch}',
        save_last=True,
        save_top_k=4,
    )
    # give the "last" checkpoint the same naming scheme plus a -last suffix
    ModelCheckpoint.CHECKPOINT_NAME_LAST = 'test_last_ckpt-{val_loss:.2f}-{epoch}' + '-last'

    # training
    trainer = pl.Trainer(gpus=-1, num_nodes=1, max_epochs=5, max_steps=-1,
                         strategy='ddp',
                         logger=False,
                         benchmark=True,
                         callbacks=[checkpoint_callback])
    trainer.fit(model, train_loader, val_loader)
Expected behavior
Only the most recent -last.ckpt should be kept; all previous *-last.ckpt files should be deleted when a new one is saved.
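As a quick sanity check (not part of the original report), the expected behavior can be verified after a run by confirming that exactly one *-last.ckpt remains in the checkpoint directory used by the repro script above:

```python
import glob

# './test_last_ckpt' matches the dirpath used in the repro script above.
last_ckpts = glob.glob('./test_last_ckpt/*-last.ckpt')
print(last_ckpts)
assert len(last_ckpts) == 1, f"expected a single -last.ckpt, found {len(last_ckpts)}"
```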
Environment
- CUDA: - GPU: - Quadro GV100 - available: True - version: 10.2
- Packages: - numpy: 1.20.3 - pyTorch_debug: False - pyTorch_version: 1.10.1+cu102 - pytorch-lightning: 1.5.8 - tqdm: 4.62.3
- System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.7 - version: #45~20.04.1-Ubuntu SMP Wed Nov 10 10:20:10 UTC 2021
cc @tchaton @rohitgr7 @awaelchli @carmocca @ninginthecloud @jjenniferdai
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 21 (21 by maintainers)
Thanks! This was helpful. We can use a temporary variable.
I updated your PR and added a test
Adapting the script you provided:
script:
This gets fixed with the PR I provided, but it has the problem that the checkpoint has already been deleted if an interruption happens at that moment. One way around this: copy last_model_path to a temporary variable, save the new checkpoint first, and only then delete the previous one, as sketched below.
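For illustration only, a minimal standalone sketch of that ordering; update_last_checkpoint and save_fn are invented names standing in for the callback's internals, not the actual ModelCheckpoint API:

```python
import os


def update_last_checkpoint(save_fn, new_path, previous_path):
    """Save the new -last.ckpt before removing the previous one.

    Hypothetical helper: ``save_fn`` stands in for whatever actually writes
    the checkpoint (e.g. trainer.save_checkpoint). Only the ordering matters
    here: write first, delete second, so an interruption in between never
    leaves the directory without any -last.ckpt.
    """
    save_fn(new_path)  # write the new checkpoint first

    # only after the new file exists, remove the stale one
    if previous_path and previous_path != new_path and os.path.exists(previous_path):
        os.remove(previous_path)

    return new_path  # would become the callback's new last_model_path
```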
@nithinraok When you run your script for the second time, you are not passing an existing checkpoint to the trainer to resume: `.fit(..., ckpt_path="your_checkpoint.ckpt")`.
This means that the `ModelCheckpoint` state does not get restored, so the `last_model_path` from the previous run will not get deleted.
Considering that, training runs again from the beginning and new checkpoints are generated (`*-v1.ckpt`). The `epoch=4-last.ckpt` you see after the second run is the last checkpoint of the second run, which overwrote the last checkpoint of the first run.
There's a feature request to add a version suffix to "last" checkpoints: #5030
So from my debugging, this is working as expected.
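For reference, the resume described in the comment above, continuing from the repro script, would look like the snippet below; the checkpoint path is a placeholder for a real file produced by the first run:

```python
# Resume from a checkpoint of the first run so the ModelCheckpoint state
# (including last_model_path) is restored. The path is a placeholder for an
# actual file in ./test_last_ckpt.
trainer.fit(
    model,
    train_loader,
    val_loader,
    ckpt_path="./test_last_ckpt/test_last_ckpt-val_loss=1.85-epoch=4-last.ckpt",
)
```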