pytorch-lightning: DDPShardedPlugin consolidate_state_dict RuntimeError

🐛 Bug

After a (seemingly arbitrary) number of steps/epochs, DDPShardedPlugin::optimizer_state crashes in its consolidate_state_dict call:

  1. PyTorch’s distributed broadcast_object_list executes object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
  2. This raises RuntimeError: Trying to create tensor with negative dimension -5193452289200645882: [-5193452289200645882]

Stacktrace:

Traceback (most recent call last):
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 560, in train
    self.train_loop.run_training_epoch()
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 562, in run_training_epoch
    self.trainer.run_evaluation()
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 667, in run_evaluation
    self.evaluation_loop.on_evaluation_end()
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 110, in on_evaluation_end
    self.trainer.call_hook('on_validation_end', *args, **kwargs)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 924, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-object_sizes_tensorpackages/pytorch_lightning/trainer/callback_hook.py", line 177, in on_validation_end
    callback.on_validation_end(self, self.get_model())
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 203, in on_validation_end
    self.save_checkpoint(trainer, pl_module)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 253, in save_checkpoint
    self._save_last_checkpoint(trainer, pl_module, monitor_candidates)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 567, in _save_last_checkpoint
    self._save_model(last_filepath, trainer, pl_module)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 361, in _save_model
    self.save_function(filepath, self.save_weights_only)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/properties.py", line 257, in save_checkpoint
    self.checkpoint_connector.save_checkpoint(filepath, weights_only)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 392, in save_checkpoint
    checkpoint = self.dump_checkpoint(weights_only)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 283, in dump_checkpoint
    optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 206, in optimizer_state
    return self.ddp_plugin.optimizer_state(optimizer)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/pytorch_lightning/plugins/sharded_plugin.py", line 42, in optimizer_state
    optimizer.consolidate_state_dict()
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/fairscale/optim/oss.py", line 320, in consolidate_state_dict
    self._broadcast_state_dict()
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/fairscale/optim/oss.py", line 349, in _broadcast_state_dict
    dist.broadcast_object_list([0], src=global_rank, group=self.group)
  File "/home/robertsc/.conda/envs/pytorch1.7/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1687, in broadcast_object_list
    object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
RuntimeError: Trying to create tensor with negative dimension -5193452289200645882: [-5193452289200645882]
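
For context, a hedged sketch of the kind of setup that exercises this code path follows; the model, data, and hyperparameters are placeholders and not taken from the original report, and it assumes the plugins="ddp_sharded" string shortcut available in PL 1.1.x. With the sharded plugin the optimizer is wrapped in fairscale's OSS, so a ModelCheckpoint that monitors a validation metric sends every validation end through optimizer_state -> consolidate_state_dict -> broadcast_object_list.

  import torch
  from torch.utils.data import DataLoader, TensorDataset
  from pytorch_lightning import LightningModule, Trainer
  from pytorch_lightning.callbacks import ModelCheckpoint

  class TinyModel(LightningModule):
      """Placeholder module, not the reporter's model."""

      def __init__(self):
          super().__init__()
          self.layer = torch.nn.Linear(32, 2)

      def training_step(self, batch, batch_idx):
          return self.layer(batch[0]).sum()

      def validation_step(self, batch, batch_idx):
          # Logging a validation metric lets ModelCheckpoint monitor it.
          self.log("val_loss", self.layer(batch[0]).sum())

      def configure_optimizers(self):
          # Native PyTorch Adam with amsgrad, as mentioned later in the thread.
          return torch.optim.Adam(self.parameters(), lr=1e-3, amsgrad=True)

  def loader():
      return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

  if __name__ == "__main__":
      trainer = Trainer(
          gpus=2,
          accelerator="ddp",
          plugins="ddp_sharded",  # wraps the optimizer in fairscale's OSS
          callbacks=[ModelCheckpoint(monitor="val_loss", save_last=True)],
          max_epochs=50,
      )
      trainer.fit(TinyModel(), loader(), loader())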

Environment

  • CUDA:
    • GPU:
      • TITAN RTX
    • available: True
    • version: 11.0
  • Packages:
    • numpy: 1.18.1
    • pyTorch_debug: False
    • pyTorch_version: 1.8.0.dev20210122
    • pytorch-lightning: 1.1.4
    • tqdm: 4.48.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor:
    • python: 3.8.3
    • version: #1 SMP Debian 4.19.160-2 (2020-11-28)

cc @awaelchli @rohitgr7 @akihironitta

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 3
  • Comments: 35 (5 by maintainers)

Most upvoted comments

Thanks guys! From what I understand after extensive debugging by @robogast, this seems unrelated to sharded. There are a few de-sync issues on our end between 1.2 and what’s on master, but hopefully it will be resolved next week!

Update: After encountering the same error without enabling ddp_sharded, I did some further digging and found the following:

Currently, on the PL master branch, model_checkpoint.py issues multiple broadcasts on all ranks at every validation_end (because I set up my ModelCheckpoint to track a validation metric) in order to keep the checkpoint filename consistent across ranks. However, the logic that manages this seems a bit smelly to me; my working hypothesis is that there is a circumstance in which model_checkpoint.py doesn’t issue an equal number of broadcasts on all ranks, leading to the strange NCCL behaviour described above (see the sketch below).
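
To illustrate the hypothesis, here is a small, self-contained sketch (illustrative only, not pytorch-lightning code). Collectives pair up strictly by call order, not by intent: if one rank's code path skips a broadcast that another rank performs, every later payload is shifted by one. The sketch keeps the call counts equal for brevity, but rank 1 interprets the filename broadcast as the optimizer-state broadcast; with NCCL and broadcast_object_list, such a shift means unrelated bytes get decoded as the object-sizes tensor, which is consistent with the nonsense negative dimension in the traceback above.

  import os
  import torch.distributed as dist
  import torch.multiprocessing as mp

  def run(rank, world_size):
      os.environ["MASTER_ADDR"] = "127.0.0.1"
      os.environ["MASTER_PORT"] = "29501"
      dist.init_process_group("gloo", rank=rank, world_size=world_size)

      if rank == 0:
          # Rank 0 believes call #1 syncs the checkpoint filename and
          # call #2 broadcasts the optimizer state.
          dist.broadcast_object_list(["epoch=3-last.ckpt"], src=0)
          dist.broadcast_object_list([{"state": "optimizer shard"}], src=0)
      else:
          # Rank 1 believes its first call already delivers the optimizer
          # state, so it silently ends up holding the filename instead.
          expected_state = [None]
          dist.broadcast_object_list(expected_state, src=0)
          print(f"rank {rank} thinks the optimizer state is: {expected_state[0]!r}")
          leftover = [None]
          dist.broadcast_object_list(leftover, src=0)  # the real state, one call late

      dist.destroy_process_group()

  if __name__ == "__main__":
      mp.spawn(run, args=(2,), nprocs=2)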

This logic is removed from model_checkpoint.py on the 1.2-dev branch, and I haven’t observed any issues (with or without ddp_sharded) since I switched to that version. So, as far as I’m concerned, I’ll close this issue once PL 1.2 is officially released.

Many thanks for all this work @robogast, and sorry for not having been more responsive; I had a hard time reproducing, which might explain why. Feel free to pull me in again, and I’m glad that works for you.

Sorry, that might not have been clear: I’m using the nightly version of PyTorch precisely because I ran into the bug where broadcast_object_list was broken by the rank != device mismatch. I’ve spent the day chasing the error, but nothing so far.

PyTorch’s native Adam with amsgrad (wrapped by Lightning, of course).
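
Concretely, that would look something like the following in the LightningModule (a sketch, not the reporter’s actual code; the learning rate is a placeholder):

  def configure_optimizers(self):
      # Native PyTorch Adam with amsgrad; with the ddp_sharded plugin, Lightning
      # re-wraps the returned optimizer in fairscale's OSS, which is why
      # consolidate_state_dict shows up in the traceback at checkpoint time.
      return torch.optim.Adam(self.parameters(), lr=1e-3, amsgrad=True)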

To be clear: those are two different stacktraces. They fail on the exact same line of code, so I assume they’re related.

Oh wow, I never saw that, interesting. In recent fairscale I switched to using PyTorch’s broadcast-object util instead of a dedicated one; it looks like it can fail somehow. One check could be to try setting _torch_broadcast_object to False. I’ll check with Rohan (the author of this part) whether there’s something I’m doing wrong.
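
For anyone wanting to try that check, a minimal sketch is below. It assumes the flag is a module-level switch in fairscale.optim.oss that selects between torch's broadcast_object_list and fairscale's own broadcast helper; verify where it actually lives in your installed fairscale version before relying on it.

  # Hedged sketch: the exact location of this flag depends on the fairscale
  # version, so guard the access instead of assuming it exists.
  import fairscale.optim.oss as oss

  if hasattr(oss, "_torch_broadcast_object"):
      oss._torch_broadcast_object = False  # fall back to fairscale's own broadcast path
  else:
      print("this fairscale version does not expose _torch_broadcast_object")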