pytorch-lightning: WandbLogger cannot be used with 'ddp'

๐Ÿ› Bug

wandb's init is written so that, once the master process has called init, a child process that calls init gets None back. This seems to cause a bug with ddp: rank zero ends up with experiment = None, which crashes the program.

To Reproduce

This can be reproduced with the basic MNIST GPU template: simply add a WandbLogger and pass 'ddp' as the distributed backend.

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/rmrao/anaconda3/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/rmrao/anaconda3/lib/python3.6/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 331, in ddp_train
    self.run_pretrain_routine(model)
  File "/home/rmrao/anaconda3/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 757, in run_pretrain_routine
    self.logger.log_hyperparams(ref_model.hparams)
  File "/home/rmrao/anaconda3/lib/python3.6/site-packages/pytorch_lightning/logging/base.py", line 14, in wrapped_fn
    fn(self, *args, **kwargs)
  File "/home/rmrao/anaconda3/lib/python3.6/site-packages/pytorch_lightning/logging/wandb.py", line 79, in log_hyperparams
    self.experiment.config.update(params)
AttributeError: 'NoneType' object has no attribute 'config'

This occurs with the latest wandb version and with pytorch-lightning 0.6.

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 20 (14 by maintainers)

Most upvoted comments

This particular problem I think stems from this part of the wandb.init(...) code:

def init(...):
    ...
    # If a thread calls wandb.init() it will get the same Run object as
    # the parent. If a child process with distinct memory space calls
    # wandb.init(), it won't get an error, but it will get a result of
    # None.
    # This check ensures that a child process can safely call wandb.init()
    # after a parent has (only the parent will create the Run object).
    # This doesn't protect against the case where the parent doesn't call
    # wandb.init but two children do.
    if run or os.getenv(env.INITED):
        return run

Child processes end up getting None for the wandb run object, which causes logging to fail. There are probably two reasonable and complementary solutions:

  1. The main process should avoid creating a wandb experiment unless absolutely necessary.

Right now, this is the only part of the logging code that the parent process calls (I assume it's called when the logger is pickled for the child processes):

def __getstate__(self):
    state = self.__dict__.copy()
    # cannot be pickled
    state['_experiment'] = None
    # args needed to reload correct experiment
    state['_id'] = self.experiment.id
    return state
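The side effect is easy to reproduce with a stand-in class. In this sketch LazyLogger and FakeRun are hypothetical; the only assumption carried over from the logger is that experiment is a lazily created property. Pickling an object calls its __getstate__, so serializing the logger to hand it to worker processes is enough to create an experiment in the parent.

```python
class FakeRun:
    """Hypothetical stand-in for a wandb Run object."""

    def __init__(self, run_id):
        self.id = run_id


class LazyLogger:
    """Hypothetical logger whose experiment is created on first access."""

    def __init__(self):
        self._experiment = None
        self._id = None
        self.created = 0  # number of experiments this object has created

    @property
    def experiment(self):
        if self._experiment is None:
            self.created += 1
            self._experiment = FakeRun(run_id="run-%d" % self.created)
        return self._experiment

    def __getstate__(self):
        # Same logic as the original __getstate__ above.
        state = self.__dict__.copy()
        # cannot be pickled
        state['_experiment'] = None
        # args needed to reload correct experiment; this property access
        # creates the experiment if it does not exist yet
        state['_id'] = self.experiment.id
        return state


logger = LazyLogger()
state = logger.__getstate__()  # what pickling the logger calls
print(logger.created)  # 1: the parent now owns an experiment
print(state['_id'])    # run-1
```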

If this is changed to:

def __getstate__(self):
    state = self.__dict__.copy()
    # args needed to reload correct experiment
    if self._experiment is not None:
        state['_id'] = self._experiment.id
    else:
        state['_id'] = None

    # cannot be pickled
    state['_experiment'] = None
    return state

That ensures that, unless the user explicitly logs something or creates the wandb experiment first, the main process will not try to create an experiment. Since subsequent logging and saving code is wrapped by the @rank_zero_only decorator, this will generally solve the issue in the base case.
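Using the same kind of hypothetical stand-ins (FakeRun and LazyLogger, neither of which is the real WandbLogger), the proposed __getstate__ leaves the parent without an experiment. The worker side is simulated by updating a fresh instance's __dict__, which is roughly what unpickling does when a class defines no __setstate__; the worker then creates its experiment lazily on first use.

```python
class FakeRun:
    """Hypothetical stand-in for a wandb Run object."""

    def __init__(self, run_id):
        self.id = run_id


class LazyLogger:
    """Hypothetical logger whose experiment is created on first access."""

    def __init__(self):
        self._experiment = None
        self._id = None
        self.created = 0  # number of experiments this object has created

    @property
    def experiment(self):
        if self._experiment is None:
            self.created += 1
            # reuse a saved id if there is one; otherwise start a new run
            self._experiment = FakeRun(run_id=self._id or "new-run")
        return self._experiment

    def __getstate__(self):
        # Proposed version: never touches the `experiment` property.
        state = self.__dict__.copy()
        # args needed to reload correct experiment
        if self._experiment is not None:
            state['_id'] = self._experiment.id
        else:
            state['_id'] = None
        # cannot be pickled
        state['_experiment'] = None
        return state


parent = LazyLogger()
state = parent.__getstate__()  # what pickling the logger calls
print(parent.created)  # 0: no experiment in the parent

# Roughly what unpickling does in a worker when no __setstate__ is defined.
worker = LazyLogger.__new__(LazyLogger)
worker.__dict__.update(state)
print(worker.experiment.id)  # new-run, created on first use in the worker
print(worker.created)        # 1
```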

It's also possible that other logger properties that access the experiment are called by the master process. Ideally they would be written so that they do not create the experiment unless it already exists (i.e., the experiment should only be created by a function that is wrapped with the @rank_zero_only decorator).
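One way to follow that suggestion is to route every call that could create the experiment through a rank check, while read-only properties only report what already exists. The sketch below is hypothetical: the rank attribute and the rank_zero_only decorator are modeled on the wrapped_fn frame visible in the traceback above, not copied from Lightning, and SketchLogger stands in for the real logger.

```python
from functools import wraps


def rank_zero_only(fn):
    """Call fn only on the process whose logger has rank 0."""

    @wraps(fn)
    def wrapped_fn(self, *args, **kwargs):
        if self.rank == 0:
            return fn(self, *args, **kwargs)
        return None

    return wrapped_fn


class SketchLogger:
    """Hypothetical logger that only creates its experiment on rank 0."""

    def __init__(self, rank):
        self.rank = rank
        self._experiment = None
        self.logged = []

    def _create_experiment(self):
        # Stand-in for wandb.init(...); only reached from rank-zero code.
        if self._experiment is None:
            self._experiment = {"id": "run-0", "config": {}}
        return self._experiment

    @property
    def run_id(self):
        # Read-only access: report the id if it exists, never create one.
        return self._experiment["id"] if self._experiment is not None else None

    @rank_zero_only
    def log_hyperparams(self, params):
        self._create_experiment()["config"].update(params)
        self.logged.append(params)


rank0, rank1 = SketchLogger(rank=0), SketchLogger(rank=1)
for lg in (rank0, rank1):
    lg.log_hyperparams({"lr": 0.01})

print(rank0.run_id)  # run-0
print(rank1.run_id)  # None: rank 1 never created an experiment
```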

  2. If the main process has created an experiment, rank zero should be passed the reinit argument.

wandb does allow you to reinitialize the experiment. I tried to play around with this a little bit and got some errors, but in theory adding this:

wandb.init(..., reinit=dist.is_available() and dist.is_initialized() and dist.get_rank() == 0)

should force rank zero to re-initialize when wandb is already initialized.
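The reinit expression relies on "and" short-circuiting left to right: dist.get_rank() is only evaluated once torch.distributed is available and a process group has been initialized, since calling it before a default process group exists fails. The sketch below wraps the expression in a hypothetical helper, should_reinit, that takes the dist module as a parameter so it can be exercised with SimpleNamespace stand-ins instead of a real process group; with real torch you would pass torch.distributed.

```python
from types import SimpleNamespace


def should_reinit(dist) -> bool:
    """True only on rank 0 of an initialized process group.

    `and` short-circuits, so get_rank() is never called when distributed
    support is missing or the process group is not initialized yet.
    """
    return dist.is_available() and dist.is_initialized() and dist.get_rank() == 0


def _fail():
    raise RuntimeError("get_rank() called before the process group exists")


# Stand-ins for torch.distributed in three situations.
not_initialized = SimpleNamespace(
    is_available=lambda: True, is_initialized=lambda: False, get_rank=_fail
)
rank_zero = SimpleNamespace(
    is_available=lambda: True, is_initialized=lambda: True, get_rank=lambda: 0
)
rank_one = SimpleNamespace(
    is_available=lambda: True, is_initialized=lambda: True, get_rank=lambda: 1
)

for name, d in [("not initialized", not_initialized), ("rank 0", rank_zero), ("rank 1", rank_one)]:
    print(name, should_reinit(d))
# not initialized False
# rank 0 True
# rank 1 False
```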

This is solved in #13166.

Just to clarify @parasj's solution: presumably you also have to pass in a name as a keyword argument when constructing the logger.

Unfortunately, this means you're now responsible for generating unique names for subsequent runs.
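One way to handle that is to build the name once in the parent process, before the trainer spawns workers, so that every rank receives the same string when the logger is pickled. make_run_name below is a hypothetical helper that combines a prefix, a timestamp, and a short random suffix; the WandbLogger(name=...) line in the comment reflects the keyword argument mentioned above and is not a tested invocation.

```python
import time
import uuid


def make_run_name(prefix: str = "mnist-ddp") -> str:
    """Return a run name that is unlikely to repeat across launches."""
    stamp = time.strftime("%Y%m%d-%H%M%S")
    # random suffix makes collisions between launches in the same second unlikely
    suffix = uuid.uuid4().hex[:6]
    return f"{prefix}-{stamp}-{suffix}"


run_name = make_run_name()
print(run_name)  # e.g. mnist-ddp-<date>-<time>-<6 hex chars>
# logger = WandbLogger(name=run_name)  # generated once, in the parent process
```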