pytorch-forecasting: Multi-GPU runs fail with "ProcessExitedException: process 0 terminated with signal SIGSEGV" for Baseline and TFT models

  • PyTorch-Forecasting version: 1.0.0
  • PyTorch version: 2.0.1+cu117
  • Lightning version: 2.0.4
  • Python version: 3.10.11
  • Operating System: Linux-5.10.0-23-cloud-amd64-x86_64-with-glibc2.31 (Google Cloud)

Expected behavior

I am trying to run the exact code from the Stallion TFT example on a multi-GPU machine, in preparation for training a similar model on my own data in the same environment. The code runs without issue on a single-GPU machine, so I would expect it to run on a multi-GPU machine as well, especially when restricting it to one of the GPUs with devices=1. A similar script using my own data fails in the same way.
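One way to test the single-GPU expectation without touching any `Trainer` arguments is to hide the other GPUs from CUDA with the standard `CUDA_VISIBLE_DEVICES` environment variable, so the machine looks like the single-GPU setup that works. This is a minimal sketch with a hypothetical helper name (`restrict_visible_gpus`), assuming it runs at the very top of the notebook, before `torch` is imported:

```python
import os


def restrict_visible_gpus(gpu_ids):
    """Expose only the given GPU indices to CUDA in this process.

    Hypothetical helper: call it before `torch` is imported. Once CUDA has
    been initialized in the process, changing the variable has no effect.
    """
    value = ",".join(str(int(i)) for i in gpu_ids)
    os.environ["CUDA_VISIBLE_DEVICES"] = value
    return value


# Keep only GPU 0 visible. With a single visible GPU, Lightning's default
# devices="auto" should resolve to one device instead of launching DDP workers.
restrict_visible_gpus([0])
```

Whether this avoids the SIGSEGV on this particular machine is untested; it only prevents Lightning from automatically selecting both GPUs.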

Actual behavior

When I run the same code on a multi-GPU machine, I get the following error both when predicting with the Baseline model and when fitting the TFT model.

Baseline

# calculate baseline mean absolute error, i.e. predict next value as the last available value from the history
baseline_predictions = Baseline().predict(val_dataloader, return_y=True)
MAE()(baseline_predictions.output, baseline_predictions.y)

Output

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
Cell In[6], line 2
      1 # # calculate baseline mean absolute error, i.e. predict next value as the last available value from the history
----> 2 baseline_predictions = Baseline().predict(val_dataloader, return_y=True)
      3 MAE()(baseline_predictions.output, baseline_predictions.y)

File ~/.local/lib/python3.10/site-packages/pytorch_forecasting/models/base_model.py:1423, in BaseModel.predict(self, data, mode, return_index, return_decoder_lengths, batch_size, num_workers, fast_dev_run, return_x, return_y, mode_kwargs, trainer_kwargs, write_interval, output_dir, **kwargs)
   1421 logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
   1422 trainer = Trainer(fast_dev_run=fast_dev_run, **trainer_kwargs)
-> 1423 trainer.predict(self, dataloader)
   1424 logging.getLogger("lightning").setLevel(log_level_lighting)
   1425 logging.getLogger("pytorch_lightning").setLevel(log_level_pytorch_lightning)

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:845, in Trainer.predict(self, model, dataloaders, datamodule, return_predictions, ckpt_path)
    843     model = _maybe_unwrap_optimized(model)
    844     self.strategy._lightning_module = model
--> 845 return call._call_and_handle_interrupt(
    846     self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
    847 )

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:41, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     39 try:
     40     if trainer.strategy.launcher is not None:
---> 41         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     42     return trainer_fn(*args, **kwargs)
     44 except _TunerExitException:

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py:124, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
    116 process_context = mp.start_processes(
    117     self._wrapping_function,
    118     args=process_args,
   (...)
    121     join=False,  # we will join ourselves to get the process references
    122 )
    123 self.procs = process_context.processes
--> 124 while not process_context.join():
    125     pass
    127 worker_output = return_queue.get()

File /opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:140, in ProcessContext.join(self, timeout)
    138 if exitcode < 0:
    139     name = signal.Signals(-exitcode).name
--> 140     raise ProcessExitedException(
    141         "process %d terminated with signal %s" %
    142         (error_index, name),
    143         error_index=error_index,
    144         error_pid=failed_process.pid,
    145         exit_code=exitcode,
    146         signal_name=name
    147     )
    148 else:
    149     raise ProcessExitedException(
    150         "process %d terminated with exit code %d" %
    151         (error_index, exitcode),
   (...)
    154         exit_code=exitcode
    155     )

ProcessExitedException: process 0 terminated with signal SIGSEGV
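The traceback shows that `BaseModel.predict` builds its own `Trainer(fast_dev_run=fast_dev_run, **trainer_kwargs)`, and that with no arguments this trainer picked up both GPUs and launched DDP workers (`MEMBER: 1/2`, `MEMBER: 2/2`). A possible workaround, not verified on this machine, is to pass single-device settings through the `trainer_kwargs` parameter that appears in the `predict` signature above. The helper name `single_device_trainer_kwargs` is hypothetical:

```python
def single_device_trainer_kwargs(use_gpu=True, **overrides):
    """Build Trainer keyword arguments that pin Lightning to a single device.

    With devices=1, Lightning should not need the multiprocessing launcher
    whose join() raised the ProcessExitedException in the traceback.
    """
    kwargs = {"accelerator": "gpu" if use_gpu else "cpu", "devices": 1}
    kwargs.update(overrides)  # caller-supplied keys take precedence
    return kwargs


# Usage in the Stallion notebook (Baseline, MAE and val_dataloader come from there):
# baseline_predictions = Baseline().predict(
#     val_dataloader, return_y=True, trainer_kwargs=single_device_trainer_kwargs()
# )
# MAE()(baseline_predictions.output, baseline_predictions.y)
```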

TFT

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# configure network and trainer
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # log results to TensorBoard

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cuda",  # added vs. example code
    strategy="ddp_notebook",  # added vs. example code
    devices=2,  # added vs. example code
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # comment in for training, running validation every 50 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,  # TimeSeriesDataSet created earlier in the Stallion example
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,  # e.g. 10 to log every 10 batches (also used with the learning rate finder)
    optimizer="Ranger",
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# fit network (run in a separate cell)
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

Output

[rank: 0] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Global seed set to 42
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 1.3 K 
3  | prescalers                         | ModuleDict                      | 256   
4  | static_variable_selection          | VariableSelectionNetwork        | 3.4 K 
5  | encoder_variable_selection         | VariableSelectionNetwork        | 8.0 K 
6  | decoder_variable_selection         | VariableSelectionNetwork        | 2.7 K 
7  | static_context_variable_selection  | GatedResidualNetwork            | 1.1 K 
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 1.1 K 
9  | static_context_initial_cell_lstm   | GatedResidualNetwork            | 1.1 K 
10 | static_context_enrichment          | GatedResidualNetwork            | 1.1 K 
11 | lstm_encoder                       | LSTM                            | 2.2 K 
12 | lstm_decoder                       | LSTM                            | 2.2 K 
13 | post_lstm_gate_encoder             | GatedLinearUnit                 | 544   
14 | post_lstm_add_norm_encoder         | AddNorm                         | 32    
15 | static_enrichment                  | GatedResidualNetwork            | 1.4 K 
16 | multihead_attn                     | InterpretableMultiHeadAttention | 808   
17 | post_attn_gate_norm                | GateAddNorm                     | 576   
18 | pos_wise_ff                        | GatedResidualNetwork            | 1.1 K 
19 | pre_output_gate_norm               | GateAddNorm                     | 576   
20 | output_layer                       | Linear                          | 119   
----------------------------------------------------------------------------------------
29.4 K    Trainable params
0         Non-trainable params
29.4 K    Total params
0.118     Total estimated model params size (MB)

ProcessExitedException                    Traceback (most recent call last)
Cell In[11], line 2
      1 # fit network
----> 2 trainer.fit(
      3     tft,
      4     train_dataloaders=train_dataloader,
      5     val_dataloaders=val_dataloader,
      6 )

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:531, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    529 model = _maybe_unwrap_optimized(model)
    530 self.strategy._lightning_module = model
--> 531 call._call_and_handle_interrupt(
    532     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    533 )

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:41, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     39 try:
     40     if trainer.strategy.launcher is not None:
---> 41         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     42     return trainer_fn(*args, **kwargs)
     44 except _TunerExitException:

File /opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py:124, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
    116 process_context = mp.start_processes(
    117     self._wrapping_function,
    118     args=process_args,
   (...)
    121     join=False,  # we will join ourselves to get the process references
    122 )
    123 self.procs = process_context.processes
--> 124 while not process_context.join():
    125     pass
    127 worker_output = return_queue.get()

File /opt/conda/lib/python3.10/site-packages/torch/multiprocessing/spawn.py:140, in ProcessContext.join(self, timeout)
    138 if exitcode < 0:
    139     name = signal.Signals(-exitcode).name
--> 140     raise ProcessExitedException(
    141         "process %d terminated with signal %s" %
    142         (error_index, name),
    143         error_index=error_index,
    144         error_pid=failed_process.pid,
    145         exit_code=exitcode,
    146         signal_name=name
    147     )
    148 else:
    149     raise ProcessExitedException(
    150         "process %d terminated with exit code %d" %
    151         (error_index, exitcode),
   (...)
    154         exit_code=exitcode
    155     )

ProcessExitedException: process 0 terminated with signal SIGSEGV
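`ProcessExitedException: ... signal SIGSEGV` only says that a worker died from a segmentation fault; the `spawn.py` frame shows PyTorch turning a negative exit code into a signal name. To see where the worker crashed, Python's standard `faulthandler` module can dump the Python stack when a process receives SIGSEGV. The sketch below demonstrates the mechanism on a deliberately crashing subprocess (`ctypes.string_at(0)` reads address 0), assuming Linux. In the notebook one would instead call `faulthandler.enable()` before creating the `Trainer`; forked `ddp_notebook` workers should inherit the handler, and the dump typically goes to the kernel's stderr (the terminal running Jupyter), not to the cell output.

```python
import signal
import subprocess
import sys

# Child program that segfaults on purpose by reading from address 0.
CRASH = "import ctypes; ctypes.string_at(0)"

# -X faulthandler enables faulthandler in the child: on SIGSEGV it prints the
# Python stack to stderr, then lets the signal terminate the process.
result = subprocess.run(
    [sys.executable, "-X", "faulthandler", "-c", CRASH],
    capture_output=True,
    text=True,
)

# As in torch.multiprocessing.spawn, a negative exit code means "killed by signal".
if result.returncode < 0:
    print("terminated with signal", signal.Signals(-result.returncode).name)
print(result.stderr)
```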

Code to reproduce the problem

I copied the code exactly from the Stallion TFT example.

The only changes were the additional multi-GPU parameters in the TFT Trainer call:

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="cuda",  # added vs. example code
    strategy="ddp_notebook",  # added vs. example code
    devices=2,  # added vs. example code
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # comment in for training, running validation every 50 batches
    # fast_dev_run=True,  # comment in to check that network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

Potential Solution

I have spent a couple of weeks trying to resolve this, and it appears to be related, at least in part, to memory sharing between GPUs. I found one possible solution on the Lightning forum, but I am still relatively new to this package and am struggling to implement that fix in a general way: the model is built with the from_dataset() method, and I want to keep it flexible enough to train in CPU, single-GPU and multi-GPU environments.
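On the flexibility question, one option is to leave model construction with `from_dataset()` unchanged and vary only the `Trainer` settings per environment. The helper below is a hypothetical sketch: it takes the GPU count and whether the code runs interactively, and returns kwargs for `pl.Trainer`. It relies on Lightning's strategy names: `"ddp_notebook"` is the fork-based variant meant for notebooks, while `"ddp"` relaunches the script once per extra device and therefore has to be run as a `.py` script rather than from a notebook. Whether `"ddp"` from a script avoids this particular SIGSEGV is untested here.

```python
import sys


def is_interactive():
    """Rough check for a REPL or notebook session (hypothetical heuristic)."""
    return hasattr(sys, "ps1") or bool(sys.flags.interactive) or "ipykernel" in sys.modules


def choose_trainer_kwargs(num_gpus, interactive):
    """Return hypothetical pl.Trainer kwargs for CPU, single-GPU or multi-GPU runs."""
    if num_gpus <= 0:
        return {"accelerator": "cpu", "devices": 1}
    if num_gpus == 1:
        return {"accelerator": "gpu", "devices": 1}
    # Several GPUs: fork-based DDP in notebooks, script-relaunching DDP otherwise.
    strategy = "ddp_notebook" if interactive else "ddp"
    return {"accelerator": "gpu", "devices": num_gpus, "strategy": strategy}


# Usage (assumes `import torch` and `import lightning.pytorch as pl`):
# trainer = pl.Trainer(
#     max_epochs=10,
#     **choose_trainer_kwargs(torch.cuda.device_count(), is_interactive()),
# )
```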

About this issue

  • State: closed
  • Created a year ago
  • Reactions: 1
  • Comments: 16

Most upvoted comments

Same behavior. No luck trying to adapt the solution from the Lightning forum.

For what it's worth, adding

    def train_dataloader(self):
        return train_dataloader

to https://github.com/jdb78/pytorch-forecasting/blob/d8a4462fb12de025f8bef852df1f5b48a7ae5b7c/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py#L29

doesn’t work, perhaps unsurprisingly.