pytorch-lightning: intermittent SIGSEGV segfault when plotting during `training_step`

๐Ÿ› Bug

If I create certain simple matplotlib plots during training_step, I get a segfault after training for several batches.

Specifically, if I call plot_data() from within training_step, then I get a segfault:

def plot_data():
    """Plot random data."""
    fig, axes = plt.subplots(ncols=4)  # See note 1 below.
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")  # See note 2 below.
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)

Interestingly, I’ve found two ways to stop the segfaults. Either change on its own (or both together) is enough:

  1. Change ncols to one of {2, 3, 5, 6, 7, 9, 10}. Segfaults only seem to appear if ncols is one of {4, 8, 16, 32}.
  2. Convert x from a pd.DatetimeIndex to a numpy array of matplotlib dates by adding x = matplotlib.dates.date2num(x) right after x is created with pd.date_range.
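Workaround 2 applied to plot_data() might look like the minimal sketch below. The function name plot_data_date2num and the Agg backend are assumptions made so the sketch runs standalone and headless; ax.xaxis_date() is an optional extra so tick labels still show dates rather than raw day counts.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

START_DATE = pd.Timestamp("2020-01-01")


def plot_data_date2num(ncols: int = 4) -> np.ndarray:
    """Hypothetical variant of plot_data() with workaround 2 applied."""
    fig, axes = plt.subplots(ncols=ncols)
    n_periods = 16
    x = pd.date_range(START_DATE, periods=n_periods, freq="30min")
    # Workaround 2: give matplotlib plain floats (days since its date epoch)
    # instead of a pd.DatetimeIndex.
    x = mdates.date2num(x)
    y = np.ones(n_periods)
    for ax in axes:
        ax.plot(x, y)
        ax.xaxis_date()  # optional: format the float x values as dates on the tick labels
    plt.close(fig)
    return x  # returned only so the conversion can be checked
```

This keeps ncols=4, so the only difference from the failing version is the date conversion.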

If I run gdb --args python script.py, I get this:

Epoch 0:   1%|█▏| 7/1024 [00:01<02:52,  5.90it/s, loss=0.281, v_num=34]
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x000055555568bb4b in _PyObject_GC_UNTRACK_impl (filename=0x55555587fef0 "/home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Modules/gcmodule.c", lineno=2236, op=0x7fffd05f5e00) at /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h:76
76  /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h: No such file or directory.
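As a lighter-weight alternative to gdb (a suggestion, not something tried in this report), Python's standard-library faulthandler module prints the Python-level traceback when the process receives SIGSEGV, which could show which Python call (for example inside plot_data() or the optimizer) was running at the moment of the crash. Running python -X faulthandler script.py does this without code changes; the helper below is a hypothetical in-script equivalent.

```python
import faulthandler


def enable_crash_tracebacks() -> bool:
    """Hypothetical helper: on a fatal signal (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL),
    write the Python traceback of every thread to stderr before the process dies."""
    faulthandler.enable(all_threads=True)
    return faulthandler.is_enabled()


# Call this once at the top of script.py, before trainer.fit(...).
enable_crash_tracebacks()
```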

The minimal code example below seems to always crash on batch 7 of the first epoch (epoch 0), as in the progress bar above.

In my “real” code, the crash was intermittent.

If ncols is large (e.g. 32), the failure varies randomly between:

  • SIGSEGV segfault
  • AttributeError: 'weakref' object has no attribute 'grad_fn'
  • AttributeError: 'builtin_function_or_method' object has no attribute 'grad_fn'

The two AttributeErrors originate from the same bit of code, line 213 of torch/optim/optimizer.py, which reads if p.grad.grad_fn is not None:. My guess is that the SIGSEGV is the root problem and the AttributeErrors are symptoms of the same underlying fault. I have (once) seen the code complain about a SIGSEGV and an AttributeError at the same time.

To Reproduce

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


START_DATE = pd.Timestamp("2020-01-01")
N_EXAMPLES_PER_BATCH = 32
N_FEATURES = 1


class MyDataset(Dataset):
    def __init__(self, n_batches_per_epoch: int):
        self.n_batches_per_epoch = n_batches_per_epoch

    def __len__(self):
        return self.n_batches_per_epoch

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.rand(N_EXAMPLES_PER_BATCH, N_FEATURES, 1)
        y = torch.rand(N_EXAMPLES_PER_BATCH, 1)
        return x, y


class LitNeuralNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.nn = nn.Linear(in_features=N_FEATURES, out_features=1)

    def forward(self, x):
        x = self.flatten(x)
        return self.nn(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        plot_data()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def plot_data():
    """Plot random data."""
    # ncols of 4, 8, 16 or 32 triggers the segfault; 2, 3, 5, 6, 7, 9 and 10 do not.
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    # The segfaults go away if I do:
    # x = mdates.date2num(x)
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)


dataloader = DataLoader(
    MyDataset(n_batches_per_epoch=1024),
    batch_size=None,
    num_workers=2,
)

model = LitNeuralNetwork()
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloaders=dataloader)

Expected behavior

No segfault 🙂

Environment

* CUDA:
	- GPU:
	- available:         False
	- version:           None
* Packages:
	- numpy:             1.22.2
	- pyTorch_debug:     False
	- pyTorch_version:   1.10.2
	- pytorch-lightning: 1.5.10
	- tqdm:              4.62.3
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- ELF
	- processor:         x86_64
	- python:            3.9.10
	- version:           #31-Ubuntu SMP Thu Jan 13 17:41:06 UTC 2022
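The listing above does not include matplotlib or pandas, even though both are needed to trigger the crash. A minimal sketch for printing their installed versions alongside the others, using the standard library's importlib.metadata (the helper name and package list are illustrative, not part of any Lightning tooling):

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distribution names to look up (illustrative list).
PACKAGES = ["numpy", "torch", "pytorch-lightning", "tqdm", "matplotlib", "pandas"]


def collect_versions(packages=PACKAGES) -> dict:
    """Map each package name to its installed version, or None if it is not installed."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = None
    return info


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name:<20} {ver}")
```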

Here’s my environment.yml file:

name: segfault
channels:
  - conda-forge
  - pytorch
dependencies:
  - matplotlib
  - pandas
  - pytorch
  - cpuonly 
  - pytorch-lightning

Additional context

This may be related to issue #11067.

I have tried and failed to reproduce this problem in two simpler scripts:

  • A script which just uses matplotlib and pandas (no pytorch, no pytorch-lightning).
  • A script which just uses matplotlib, pandas and pytorch (no pytorch-lightning).

These scripts are in a tiny GitHub repo for this bug report: https://github.com/JackKelly/pytorch-lighting-segfault

About this issue

  • State: closed
  • Created 2 years ago
  • Reactions: 1
  • Comments: 20 (8 by maintainers)


Most upvoted comments

@JackKelly I was encountering a similar segfault when trying to log an image with TensorBoard inside the training step (i.e., I call self.logger.experiment.add_image inside training_step).

One thing that solves the segfault is adding a plt.pause(0.1) right after I finish plotting or adding the image:

import matplotlib.pyplot as plt
plt.pause(0.1)

Since I only log the images a few times during the epoch, adding a pause does not impact the performance of my training code.
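The comment does not say exactly where the plt.pause(0.1) call goes. A minimal sketch that combines it with plotting only occasionally, assuming the call comes after the figure is closed: per the matplotlib documentation, plt.pause sleeps for the given interval when no figure is active, so in this position it acts as a short delay. The helper name maybe_plot and the every-100-batches interval are hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import numpy as np

PLOT_EVERY_N_BATCHES = 100  # hypothetical interval; plot only a few times per epoch


def maybe_plot(batch_idx: int, every_n: int = PLOT_EVERY_N_BATCHES) -> bool:
    """Plot on every `every_n`-th batch only; return True if a plot was made."""
    if batch_idx % every_n != 0:
        return False
    fig, ax = plt.subplots()
    ax.plot(np.arange(16), np.ones(16))
    # ... log the figure here (the comment above uses self.logger.experiment.add_image) ...
    plt.close(fig)
    # With no figure left open, plt.pause just sleeps for 0.1 s.
    plt.pause(0.1)
    return True
```

Inside training_step this would be called as maybe_plot(batch_idx).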

Sorry for the delay in responding. I have just posted this issue to the PyTorch issue queue: https://github.com/pytorch/pytorch/issues/74188

Cool, thank you. I’ll close this issue for now. I’ve started a new issue in the matplotlib issue queue: https://github.com/matplotlib/matplotlib/issues/22478

Thanks again for your help!

Yes, I think that would be best as I don’t see any reason why this could be caused by us 😃

Sure, feel free to quote them 😃

Hi again,

I ran your code and for me, neither the pure pytorch, nor the lightning version segfaulted (or had any issue at all).


The two AttributeErrors you mentioned seem to indicate that, somewhere, torch.Tensor objects are being swapped out for weakrefs or builtin functions, which I can’t explain.

I am very sorry, but with your exact environment the code seems to run fine for me, so I am a bit clueless about how to investigate this 😕

@JackKelly Thanks for opening this issue and for the reproducibility examples. I’ll have a look later today and report back if I find something 😃