pytorch-lightning: intermittent SIGSEGV segfault when plotting during `training_step`
🐛 Bug
If I create specific, simple matplotlib plots during `training_step`, then I get a segfault after training for several batches.
Specifically, if I call `plot_data()` from within `training_step`, then I get a segfault:
```python
def plot_data():
    """Plot random data."""
    fig, axes = plt.subplots(ncols=4)  # See note 1 below.
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")  # See note 2 below.
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)
```
Interestingly, I've found two ways to stop the segfaults. You can do one or the other (or both) of these to stop the segfaults:

- Change `ncols` to one of `{2, 3, 5, 6, 7, 9, 10}`. Segfaults only seem to appear if `ncols` is one of `{4, 8, 16, 32}`.
- Convert `x` from a `pd.DatetimeIndex` to a numpy array of matplotlib dates by doing `x = matplotlib.dates.date2num(x)` immediately after creating `x` (sketched below).
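Applied to `plot_data()`, the second workaround looks like this (a minimal sketch; only the `date2num` line differs from the original function):

```python
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

START_DATE = pd.Timestamp("2020-01-01")


def plot_data():
    """Plot random data, with the date2num workaround applied."""
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    x = mdates.date2num(x)  # Workaround: plot matplotlib date floats, not a DatetimeIndex.
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)
```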
If I run `gdb --args python script.py`, I get this:
```
Epoch 0:   1%|▏| 7/1024 [00:01<02:52, 5.90it/s, loss=0.281, v_num=34]

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x000055555568bb4b in _PyObject_GC_UNTRACK_impl (filename=0x55555587fef0
    "/home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Modules/gcmodule.c", lineno=2236,
    op=0x7fffd05f5e00) at /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h:76
76      /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h: No such file or directory.
```
The minimal code example below seems to always crash on batch 7 of the first epoch.
In my "real" code, the crash was intermittent.
If `ncols` is large (e.g. 32) then the crash changes randomly between:

- a `SIGSEGV` segfault
- `AttributeError: 'weakref' object has no attribute 'grad_fn'`
- `AttributeError: 'builtin_function_or_method' object has no attribute 'grad_fn'`
The two `AttributeError`s originate from the same bit of code, line 213 of `torch/optim/optimizer.py`: `if p.grad.grad_fn is not None:`. My guess is that the SIGSEGV is the root problem, and the `AttributeError`s are symptoms of the same underlying memory corruption. I have (once) seen the code complain about a SIGSEGV and an `AttributeError` at the same time.
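To illustrate why corruption would surface as exactly these errors, here is a self-contained sketch (my illustration, not torch code): if memory corruption left a weakref object where the optimizer expects a grad tensor, the `p.grad.grad_fn` check would raise precisely the message seen above.

```python
import weakref


class Dummy:
    pass


obj = Dummy()
p_grad = weakref.ref(obj)  # Stand-in for a corrupted `p.grad`.
try:
    _ = p_grad.grad_fn  # What `if p.grad.grad_fn is not None:` would do.
except AttributeError as e:
    print(e)  # prints: 'weakref' object has no attribute 'grad_fn'
```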
To Reproduce
```python
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

START_DATE = pd.Timestamp("2020-01-01")
N_EXAMPLES_PER_BATCH = 32
N_FEATURES = 1


class MyDataset(Dataset):
    def __init__(self, n_batches_per_epoch: int):
        self.n_batches_per_epoch = n_batches_per_epoch

    def __len__(self):
        return self.n_batches_per_epoch

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Each item is a complete batch of random data.
        x = torch.rand(N_EXAMPLES_PER_BATCH, N_FEATURES, 1)
        y = torch.rand(N_EXAMPLES_PER_BATCH, 1)
        return x, y


class LitNeuralNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.nn = nn.Linear(in_features=N_FEATURES, out_features=1)

    def forward(self, x):
        x = self.flatten(x)
        return self.nn(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        plot_data()  # This call eventually triggers the segfault.
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def plot_data():
    """Plot random data."""
    # ncols needs to be 4 or higher to trigger a segfault.
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    # The segfaults go away if I do:
    # x = mdates.date2num(x)
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)


dataloader = DataLoader(
    MyDataset(n_batches_per_epoch=1024),
    batch_size=None,  # The Dataset already returns complete batches.
    num_workers=2,
)
model = LitNeuralNetwork()
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloader=dataloader)
```
Expected behavior
No segfault 🙂
Environment
* CUDA:
- GPU:
- available: False
- version: None
* Packages:
- numpy: 1.22.2
- pyTorch_debug: False
- pyTorch_version: 1.10.2
- pytorch-lightning: 1.5.10
- tqdm: 4.62.3
* System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.9.10
- version: #31-Ubuntu SMP Thu Jan 13 17:41:06 UTC 2022
Here's my `environment.yml` file:
```yaml
name: segfault
channels:
  - conda-forge
  - pytorch
dependencies:
  - matplotlib
  - pandas
  - pytorch
  - cpuonly
  - pytorch-lightning
```
Additional context
This may be related to issue #11067.
I have tried and failed to reproduce this problem in two simpler scripts:
- A script which just uses matplotlib and pandas (no pytorch, no pytorch-lightning).
- A script which just uses matplotlib, pandas and pytorch (no pytorch-lightning; a rough sketch is below).

These scripts are in a tiny GitHub repo for this bug report: https://github.com/JackKelly/pytorch-lighting-segfault
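For reference, here is a rough sketch of the second script (my reconstruction, not the exact code from that repo): a plain PyTorch training loop that calls the same `plot_data()` every step, without pytorch-lightning.

```python
# A rough sketch (assumptions: plain training loop, reusing plot_data(),
# N_FEATURES and N_EXAMPLES_PER_BATCH from the repro script above) of the
# matplotlib + pandas + pytorch attempt.
import torch
from torch import nn
import torch.nn.functional as F

model = nn.Linear(in_features=N_FEATURES, out_features=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for batch_idx in range(1024):
    x = torch.rand(N_EXAMPLES_PER_BATCH, N_FEATURES)
    y = torch.rand(N_EXAMPLES_PER_BATCH, 1)
    loss = F.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    plot_data()  # Same plot_data() as above; this version never segfaulted.
```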
Comments
@JackKelly I was encountering a similar segfault when trying to log an image with tensorboard inside the training step (i.e., I use `self.logger.experiment.add_image` inside the training step). One thing that solves the segfault is adding a `plt.pause(0.1)` right after I finish plotting or adding the image. Since I only log the images a few times during the epoch, adding a pause does not impact the performance of my training code.
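For example, here is a sketch of where that pause would go, applied to `plot_data()` from this report (my illustration, not the commenter's exact code):

```python
def plot_data():
    """Plot random data, with the plt.pause() workaround added."""
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.pause(0.1)  # Workaround: give the GUI event loop a chance to run.
    plt.close(fig)
```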
Sorry for the delay in responding. I have just posted this issue to the PyTorch issue queue: https://github.com/pytorch/pytorch/issues/74188
Cool, thank you. I'll close this issue for now. I've started a new issue in the matplotlib issue queue: https://github.com/matplotlib/matplotlib/issues/22478
Thanks again for your help!
Yes, I think that would be best, as I don't see any reason why this could be caused by us 🙂
Sure, feel free to quote them 🙂
Hi again,
I ran your code and, for me, neither the pure pytorch version nor the lightning version segfaulted (or had any issue at all).
The two attribute errors you mentioned seem to indicate that somewhere `torch.Tensor`s are swapped out with proxies or builtin functions, for which I don't have any explanation. I am very sorry, but with your exact environment the code seems to run fine for me, so I am a bit clueless on how to investigate this 🙂
@JackKelly Thanks for opening this issue and for the reproducibility examples. I'll have a look later today and report back if I find something 🙂