pytorch-lightning: OOM when implementing training_epoch_end
Bug description
See the toy example:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset


class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self):
        batch_size, features = 10240, 24000
        x_hat = torch.randn(batch_size, features, requires_grad=True, device=self.device)
        x = torch.randn(batch_size, features, device=self.device)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return {
            "loss": loss,
            "x": x,
        }

    def training_step(self, batch, batch_idx):
        return self()

    def training_epoch_end(self, outputs) -> None:
        # do nothing
        ...


class SomeDataset(Dataset):
    def __len__(self):
        return 100000000

    def __getitem__(self, index):
        return "some sample"


train = SomeDataset()
val = SomeDataset()
model = SomeModel()
trainer = pl.Trainer(devices=1, accelerator="gpu", check_val_every_n_epoch=1)
trainer.fit(model, DataLoader(train), DataLoader(val))
Run it and you will get an OOM error.
How to reproduce the bug
See the code above.
Error messages and logs
RuntimeError: CUDA out of memory.
Environment
PyTorch Lightning Version: 1.7.7
More info
How to fix it:
The bug comes from the advance method of the training epoch loop.
When we save the batch_end_outputs for later use, we should detach every tensor in the output.
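A minimal sketch of what such a fix could look like (illustrative only; the helper name detach_outputs is made up and this is not the actual code of the training epoch loop):

import torch


def detach_outputs(obj):
    """Recursively detach every tensor in a (possibly nested) output structure."""
    if isinstance(obj, torch.Tensor):
        return obj.detach()
    if isinstance(obj, dict):
        return {key: detach_outputs(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(detach_outputs(value) for value in obj)
    return obj

# Applied to batch_end_outputs before they are stored for training_epoch_end,
# this would let the autograd graph of each batch be freed.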
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 15 (13 by maintainers)
This is expected. When you override training_epoch_end, we store all batch outputs for the hook. We don't want to detach because the user might want the grads. If you don't want to do this, but want an epoch-end hook, use on_train_epoch_end, which does not have this behaviour.
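For reference, a sketch of that alternative applied to the toy example, assuming the PL 1.7 hook signature on_train_epoch_end(self) (the class name EpochEndHookModel is made up):

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn


class EpochEndHookModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        batch_size, features = 10240, 24000
        x_hat = torch.randn(batch_size, features, requires_grad=True, device=self.device)
        x = torch.randn(batch_size, features, device=self.device)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        # Return only the loss; with training_epoch_end not overridden,
        # nothing is accumulated across batches.
        return loss

    def on_train_epoch_end(self):
        # Runs at the end of every training epoch and takes no outputs argument,
        # so the loop does not keep per-batch results alive.
        ...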
So what about checking the signature of the overridden training_epoch_end? When users do not need the outputs, skip the L227.
Actually, I use training_epoch_end to do some logging, as follows:

@snknitin Using x.detach(), we will get OOM of GPU memory. Using x.cpu(), the training speed is slowed down and we will get OOM of CPU memory.

So in the return value it obviously needs to be the loss for backprop. But for the log you don't need the whole object; self.log("train_loss", loss.item()) can be done.
return {"loss": loss, "x": x,}unless you’re doing something with x down the line, like metrics calculation, there is probably no need for it to be accumulating across batches on GPU. You might dox.cpu()orx.detach(). Looks like x is just the labels. This is probably the main culprit taking up most of the memory@snknitin Thanks for your attention. But the OOM issue does not come from the
self.log("train_loss", loss). It comes from the returned value{"loss": loss, "x": x,}, and we cannot returnloss.item()because pl needs non-detached loss to optimize models.@awaelchli It is my pleasure to make a contribution to pl. 😃
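For completeness, a sketch of the suggestion above applied to the toy example (the class name DetachedOutputsModel is made up); note that the issue author reports that in their real workload x.detach() alone still exhausts GPU memory, and x.cpu() slows training and exhausts host memory:

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn


class DetachedOutputsModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        batch_size, features = 10240, 24000
        x_hat = torch.randn(batch_size, features, requires_grad=True, device=self.device)
        x = torch.randn(batch_size, features, device=self.device)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return {
            "loss": loss,           # must stay attached to the graph for the optimizer
            "x": x.detach().cpu(),  # drop the autograd graph and GPU copy before it is stored
        }

    def training_epoch_end(self, outputs) -> None:
        # outputs holds the per-batch dicts; the "x" entries are detached CPU tensors
        ...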