pytorch-lightning: Bug with isolate_rng "Cannot re-initialize CUDA in forked subprocess"
Bug description
I was using isolate_rng inside a PyTorch Dataset for some augmentations with PyTorch Lightning 1.7.1, but after updating to the latest release, 1.8.6, I now get the following error:
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
How to reproduce the bug
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
import torch
from torch import nn
import torch.nn.functional as F


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


class DummyDataset(Dataset):
    def __init__(self) -> None:
        super().__init__()

    def __getitem__(self, index: int):
        with pl.utilities.seed.isolate_rng():
            return [torch.ones(28, 28), torch.ones(28, 28)]

    def __len__(self):
        return 10


dataset = DummyDataset()
trainer = pl.Trainer(logger=None, accelerator='gpu')
trainer.fit(LitAutoEncoder(), DataLoader(dataset, num_workers=2))
If you remove num_workers=2 the bug goes away, but I need multiple workers.
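As a stopgap I could probably start the worker processes with the spawn start method through the DataLoader's standard multiprocessing_context argument, which should avoid re-initializing CUDA in a forked subprocess, but I have not verified this, and spawn workers are slower and require the dataset to be picklable. A sketch on top of the repro above:

# Possible (unverified) workaround: start the DataLoader workers with 'spawn'
# instead of 'fork', so isolate_rng() collecting the CUDA RNG state does not
# try to re-initialize CUDA in a forked subprocess.
trainer.fit(
    LitAutoEncoder(),
    DataLoader(dataset, num_workers=2, multiprocessing_context="spawn"),
)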
Error messages and logs
Traceback (most recent call last):
File "check_bug_in_pl.py", line 47, in <module>
trainer.fit(LitAutoEncoder(), DataLoader(dataset, num_workers=2))
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
call._call_and_handle_interrupt(
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
results = self._run_stage()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
self._run_train()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1200, in _run_train
self.fit_loop.run()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 188, in advance
batch = next(data_fetcher)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
return self.fetching_function()
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
self._fetch_next_batch(self.dataloader_iter)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
batch = next(iterator)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/supporters.py", line 568, in __next__
return self.request_next_batch(self.loader_iters)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/supporters.py", line 580, in request_next_batch
return apply_to_collection(loader_iters, Iterator, next)
File "/home/jovyan/.local/lib/python3.8/site-packages/lightning_utilities/core/apply_func.py", line 47, in apply_to_collection
return function(data, *args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 530, in __next__
data = self._next_data()
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1224, in _next_data
return self._process_data(data)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 1250, in _process_data
data.reraise()
File "/usr/local/lib/python3.8/dist-packages/torch/_utils.py", line 457, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "check_bug_in_pl.py", line 38, in __getitem__
with pl.utilities.seed.isolate_rng():
File "/usr/lib/python3.8/contextlib.py", line 113, in __enter__
return next(self.gen)
File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/seed.py", line 42, in isolate_rng
states = _collect_rng_states()
File "/usr/local/lib/python3.8/dist-packages/lightning_lite/utilities/seed.py", line 112, in _collect_rng_states
"torch.cuda": torch.cuda.get_rng_state_all(),
File "/usr/local/lib/python3.8/dist-packages/torch/cuda/random.py", line 39, in get_rng_state_all
results.append(get_rng_state(i))
File "/usr/local/lib/python3.8/dist-packages/torch/cuda/random.py", line 22, in get_rng_state
_lazy_init()
File "/usr/local/lib/python3.8/dist-packages/torch/cuda/__init__.py", line 206, in _lazy_init
raise RuntimeError(
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
Epoch 0: 0%| | 0/10 [00:00<?, ?it/s]
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): 1.8.6
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10): 1.11.0+cu102
#- Python version (e.g., 3.9): 3.8.10
#- OS (e.g., Linux): ubuntu
#- CUDA/cuDNN version: 10.2
#- GPU models and configuration: Tesla T4
#- How you installed Lightning: pip
#- Running environment of LightningApp : local
More info
No response
cc @awaelchli
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 17 (14 by maintainers)
Hi folks, just hit this same issue in lightning 1.9.0, but via a different path. I’m using fault tolerant training on a map-style dataset with the “ddp” strategy. The worker loop ends up calling state_dict on CaptureMapDataset, which makes a call to _collect_rng_states() and ends up triggering the same code path as stated above, leading to the same crash. Would you prefer a new bug for this one, or is there already sufficient info on this ticket to propose a way of solving this?
Agree, sounds good. I can make a PR to add the argument.
Maybe this start method could be checked in isolate_rng so it does not result in the error.

Adding an argument to exclude the CUDA seed would be possible, but let me first check what the root cause is and whether we can fix it properly.
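To make the proposal concrete, here is a rough sketch of what such an argument could look like. The names isolate_rng_sketch and include_cuda are made up for illustration and are not the actual Lightning API:

from contextlib import contextmanager
import random

import numpy as np
import torch


@contextmanager
def isolate_rng_sketch(include_cuda: bool = True):
    # Hypothetical variant of isolate_rng with an argument to skip the CUDA RNG.
    # Capture the global RNG states on entry.
    states = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }
    if include_cuda and torch.cuda.is_available():
        states["torch.cuda"] = torch.cuda.get_rng_state_all()
    try:
        yield
    finally:
        # Restore the captured states on exit.
        random.setstate(states["python"])
        np.random.set_state(states["numpy"])
        torch.set_rng_state(states["torch"])
        if "torch.cuda" in states:
            torch.cuda.set_rng_state_all(states["torch.cuda"])

The dataset in the repro would then use with isolate_rng_sketch(include_cuda=False): inside __getitem__, so the forked workers never touch the CUDA RNG.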
Tried it, but the result is the same. It works if you remove collecting the CUDA state from the _collect_rng_states function and skip setting the CUDA state in _set_rng_states. Maybe this collection and setting could be skipped when the trainer's start method is not spawn.
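Something along these lines, i.e. only touching the CUDA RNG when CUDA has already been initialized in the current process. The helper names below are hypothetical and only illustrate the guard, not the actual Lightning code:

import torch


def _cuda_rng_accessible() -> bool:
    # Hypothetical guard: only collect/restore the CUDA RNG state when CUDA has
    # already been initialized in this process. In a forked DataLoader worker this
    # returns False, so torch.cuda.get_rng_state_all() is never called and the
    # "Cannot re-initialize CUDA in forked subprocess" error cannot be triggered.
    return torch.cuda.is_available() and torch.cuda.is_initialized()


def _collect_rng_states_sketch() -> dict:
    states = {"torch": torch.get_rng_state()}
    if _cuda_rng_accessible():
        states["torch.cuda"] = torch.cuda.get_rng_state_all()
    return states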