pytorch-lightning: Bug with isolate_rng "Cannot re-initialize CUDA in forked subprocess"

Bug description

I was using isolate_rng inside a PyTorch Dataset for some augmentations with PyTorch Lightning 1.7.1, but after updating to the latest release, 1.8.6, I now get the following error:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

How to reproduce the bug

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl

class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()
        
    def __getitem__(self, index: int):
        with pl.utilities.seed.isolate_rng():
            return [torch.ones(28, 28), torch.ones(28, 28)]
    
    def __len__(self):
        return 10

dataset = DummyDataset()

trainer = pl.Trainer(logger=None, accelerator='gpu')
trainer.fit(LitAutoEncoder(), DataLoader(dataset, num_workers=2))

If you remove num_workers=2, the error goes away, but I need multiple workers.
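A possible workaround, not verified in this thread: PyTorch's DataLoader accepts a multiprocessing_context argument, and passing "spawn" instead of relying on the default "fork" start method on Linux should avoid the re-initialization error, at the cost of slower worker startup and a separate CUDA context per worker if CUDA gets initialized there. A minimal sketch of that variant of the reproduction:

# Untested workaround sketch: spawn the DataLoader workers instead of forking them,
# so that isolate_rng() may initialize CUDA inside the worker without raising.
# Spawned workers require the usual __main__ guard and picklable objects.
if __name__ == "__main__":
    dataset = DummyDataset()
    trainer = pl.Trainer(logger=None, accelerator="gpu")
    trainer.fit(
        LitAutoEncoder(),
        DataLoader(dataset, num_workers=2, multiprocessing_context="spawn"),
    )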

Error messages and logs

Traceback (most recent call last):
  File "check_bug_in_pl.py", line 47, in <module>
    trainer.fit(LitAutoEncoder(), DataLoader(dataset, num_workers=2))
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1200, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 188, in advance
    batch = next(data_fetcher)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
    return self.fetching_function()
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
    self._fetch_next_batch(self.dataloader_iter)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
    batch = next(iterator)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/supporters.py", line 568, in __next__
    return self.request_next_batch(self.loader_iters)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/supporters.py", line 580, in request_next_batch
    return apply_to_collection(loader_iters, Iterator, next)
  File "/home/jovyan/.local/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 47, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "check_bug_in_pl.py", line 38, in __getitem__
    with pl.utilities.seed.isolate_rng():
  File "/usr/lib/python3.8/contextlib.py", line 113, in __enter__
    return next(self.gen)
  File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/seed.py", line 42, in isolate_rng
    states = _collect_rng_states()
  File "/usr/local/lib/python3.8/dist-packages/lightning_lite/utilities/seed.py", line 112, in _collect_rng_states
    "torch.cuda": torch.cuda.get_rng_state_all(),
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/random.py", line 39, in get_rng_state_all
    results.append(get_rng_state(i))
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/random.py", line 22, in get_rng_state
    _lazy_init()
  File "/usr/local/lib/python3.8/dist-packages/torch/cuda/__init__.py", line 206, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s]   

Environment

Current environment
- PyTorch Lightning Version: 1.8.6
- PyTorch Version: 1.11.0+cu102
- Python version: 3.8.10
- OS: Ubuntu
- CUDA/cuDNN version: 10.2
- GPU models and configuration: Tesla T4
- How you installed Lightning: pip
- Running environment of LightningApp: local

More info

No response

cc @awaelchli

About this issue

  • State: closed
  • Created a year ago
  • Comments: 17 (14 by maintainers)

Most upvoted comments

Hi folks, I just hit this same issue in Lightning 1.9.0, but via a different path. I'm using fault-tolerant training on a map-style dataset with the "ddp" strategy. The worker loop ends up calling state_dict on CaptureMapDataset, which calls _collect_rng_states() and triggers the same code path as above, leading to the same crash.

Original Traceback (most recent call last):
  File ".../torch/utils/data/_utils/worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File ".../torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File ".../pytorch_lightning/utilities/auto_restart.py", line 418, in _capture_metadata_collate
    metadata = dataset.state_dict()
  File ".../pytorch_lightning/utilities/auto_restart.py", line 275, in state_dict
    return {self.worker_id: {"rng_states": _collect_rng_states()}}
  File ".../lightning_fabric/utilities/seed.py", line 111, in _collect_rng_states
    "torch.cuda": torch.cuda.get_rng_state_all(),
  File ".../torch/cuda/random.py", line 39, in get_rng_state_all
    results.append(get_rng_state(i))
  File ".../torch/cuda/random.py", line 22, in get_rng_state
    _lazy_init()
  File ".../torch/cuda/__init__.py", line 207, in _lazy_init
    raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Would you prefer a new bug for this one, or is there already sufficient info on this ticket to propose a way of solving this?

@sergeevii123 I think it would be fine to add an argument to the function, e.g. include_cuda=True|False.

One could detect the start method, yes, but that alone doesn't tell us whether the function is running in a forked worker process or in the main process, and I don't know how to detect that reliably. So, if it's between the two, I would prefer the optional argument.

Agree, sounds good. I can make a PR to add the argument.
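To make the proposal concrete, here is a rough sketch of what such an argument could look like. This is illustrative only: the helper names mirror the ones in the tracebacks above, but the defaults and exact structure of the real Lightning implementation may differ.

import random
from contextlib import contextmanager
from typing import Any, Dict, Generator

import numpy as np
import torch


def _collect_rng_states(include_cuda: bool = True) -> Dict[str, Any]:
    # Collect the global RNG states; skipping CUDA lets this run inside a
    # forked DataLoader worker without lazily initializing CUDA there.
    states = {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": random.getstate(),
    }
    if include_cuda:
        states["torch.cuda"] = torch.cuda.get_rng_state_all()
    return states


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    # Restore whatever was collected; the CUDA entry is only present if it was saved.
    torch.set_rng_state(rng_state_dict["torch"])
    np.random.set_state(rng_state_dict["numpy"])
    random.setstate(rng_state_dict["python"])
    if "torch.cuda" in rng_state_dict:
        torch.cuda.set_rng_state_all(rng_state_dict["torch.cuda"])


@contextmanager
def isolate_rng(include_cuda: bool = True) -> Generator[None, None, None]:
    # Pass include_cuda=False when calling this from a forked worker process.
    states = _collect_rng_states(include_cuda)
    try:
        yield
    finally:
        _set_rng_states(states)

With such a signature, the Dataset in the reproduction above could opt out of the CUDA state inside __getitem__ via isolate_rng(include_cuda=False).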

Maybe the start method could be checked inside isolate_rng so that it does not result in this error.

Adding an argument to exclude the cuda seed would be possible, but let me first check what the root cause is and whether we can fix it properly.

@sergeevii123 Thanks for the reproducible code. Did you try wrapping the code with this guard?

if __name__ == "__main__": 
    dataset = DummyDataset()
    trainer = pl.Trainer(logger=None, accelerator='gpu')
    trainer.fit(LitAutoEncoder(), DataLoader(dataset, num_workers=2))

I will take a closer look otherwise.

Tried, but the result is the same. It does work if you remove the CUDA state collection from the _collect_rng_states function and skip setting the CUDA state in _set_rng_states. Maybe this collection and restoration could be skipped when the trainer's start method is not spawn.
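For illustration, a guard along those lines might look like the sketch below. It is a hypothetical helper, not Lightning code; it simply skips the CUDA RNG state whenever CUDA has not been initialized in the current process, which is the situation inside a forked DataLoader worker.

import torch
from torch.utils.data import get_worker_info


def _cuda_rng_state_is_accessible() -> bool:
    # Hypothetical check: torch.cuda.get_rng_state_all() lazily initializes CUDA,
    # which raises inside a forked worker. Only report the state as accessible
    # when CUDA is already initialized or we are not running in a DataLoader worker.
    if not torch.cuda.is_available():
        return False
    in_worker = get_worker_info() is not None
    return torch.cuda.is_initialized() or not in_worker


# Hypothetical usage inside a _collect_rng_states-style helper:
# if _cuda_rng_state_is_accessible():
#     states["torch.cuda"] = torch.cuda.get_rng_state_all()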