pytorch-lightning: OOM when implementing training_epoch_end

Bug description

See the toy example:

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset


class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self):
        # build two large (10240 x 24000) tensors on the GPU for every batch
        batch_size, features = 10240, 24000
        x_hat = torch.randn(batch_size, features, requires_grad=True, device=self.device)
        x = torch.randn(batch_size, features, device=self.device)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        # return the large tensor x together with the loss
        return {
            "loss": loss,
            "x": x,
        }

    def training_step(self, batch, batch_idx):
        return self()

    def training_epoch_end(self, outputs) -> None:
        # do nothing; merely overriding this hook makes Lightning keep every batch's output
        ...


class SomeDataset(Dataset):
    def __len__(self):
        return 100000000

    def __getitem__(self, index):
        return "some sample"


train = SomeDataset()
val = SomeDataset()

model = SomeModel()
trainer = pl.Trainer(devices=1, accelerator="gpu", check_val_every_n_epoch=1)
trainer.fit(model, DataLoader(train), DataLoader(val))

Run it, and you will get an OOM error.

How to reproduce the bug

See the code above.

Error messages and logs

RuntimeError: CUDA out of memory.

Environment

PyTorch Lightning Version: 1.7.7

More info

How to fix it: the bug comes from the advance method of training_epoch_loop. (This turned out to be wrong; see the maintainer comments below.)

~~see https://github.com/Lightning-AI/lightning/blob/1.7.7/src/pytorch_lightning/loops/epoch/training_epoch_loop.py#L213~~

When we save the batch_end_outputs for later use, we should detach every tensor in the output.

cc @carmocca @justusschock @rohitgr7

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 15 (13 by maintainers)

Most upvoted comments

This is expected. When you override training_epoch_end, we store all batch outputs for the hook.

We don’t want to detach because the user might want the grads.

If you don't want this behaviour but still want an epoch-end hook, use on_train_epoch_end, which does not accumulate outputs.
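
For users who still need per-batch values at epoch end, here is a minimal sketch of that suggestion, reusing the imports from the reproducer above. The names ManualOutputsModel, batch_losses, and train_loss_epoch are illustrative, not Lightning API: collect only the small, detached values yourself and read them in on_train_epoch_end.

class ManualOutputsModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 2)
        self.batch_losses = []  # collect only what we need ourselves, detached

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        batch_size, features = 10240, 24000
        x_hat = torch.randn(batch_size, features, requires_grad=True, device=self.device)
        x = torch.randn(batch_size, features, device=self.device)
        loss = F.mse_loss(x_hat, x)
        self.batch_losses.append(loss.detach())  # small detached scalar, no graph kept
        return loss  # training_epoch_end is not overridden, so Lightning accumulates nothing

    def on_train_epoch_end(self):
        # this hook receives no outputs, so nothing is stored per batch
        self.log("train_loss_epoch", torch.stack(self.batch_losses).mean())
        self.batch_losses.clear()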

So what about checking the signature of the overridden training_epoch_end? When users do not need the outputs, skip the storing at line 227.

Actually, I use training_epoch_end to do some logging, as follows:

def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
    self.log("train_acc", self.train_accuracy.compute(), sync_dist=True)
    self.log("train_bleu_1", self.train_bleu_1.compute(), sync_dist=True)
    self.log("train_bleu_2", self.train_bleu_2.compute(), sync_dist=True)

@snknitin With x.detach() we still run out of GPU memory; with x.cpu() training slows down and we eventually run out of CPU memory.

@snknitin Thanks for your attention. The OOM does not come from self.log("train_loss", loss); it comes from the returned value {"loss": loss, "x": x}, and we cannot return loss.item() because PL needs the non-detached loss to optimize the model.

In the return value it obviously needs to be the loss tensor for backprop, but for logging you don't need the whole object: self.log("train_loss", loss.item()) can be used instead.

Also, in the returned {"loss": loss, "x": x}, unless you're doing something with x down the line (like metric calculation), there is no need for it to accumulate across batches on the GPU. You might use x.cpu() or x.detach(). It looks like x is just the labels, and it is probably the main culprit taking up most of the memory.
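
A minimal sketch of that suggestion applied to the reproducer's training_step (the author reports above that neither detach() nor cpu() fully solved their real case):

def training_step(self, batch, batch_idx):
    out = self()
    # detach the large tensor (and optionally move it to CPU) before Lightning stores it
    return {"loss": out["loss"], "x": out["x"].detach().cpu()}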


@awaelchli It is my pleasure to make a contribution to pl. 😃