pytorch-lightning: [TPU-Colab] RuntimeError: Cannot replicate if number of devices (1) is different from 8

๐Ÿ› Bug

When I run trainer.test(model) on a pre-trained model using a Colab TPU instance, the following exception is thrown.

NB: trainer.fit(model) works.

Stack trace

Traceback (most recent call last):
  File "run_pl_ged.py", line 217, in <module>
    trainer.test(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 958, in test
    self.fit(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 777, in fit
    xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 182, in spawn
    start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 116, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 109, in _setup_replication
    xm.set_replication(str(device), [str(device)])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 194, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 181, in xla_replication_devices
    .format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8
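For context, the check that raises here lives in torch_xla's xla_model.py. Below is a simplified paraphrase of that check, not the actual source: the helper name collect_devices_of_same_kind is made up for illustration, and only the error format string is taken verbatim from the trace.

# Simplified paraphrase of xla_replication_devices() from
# torch_xla/core/xla_model.py as referenced in the trace above.
# collect_devices_of_same_kind() is a hypothetical stand-in for the
# real device enumeration; the error string matches the real exception.
def xla_replication_devices(local_devices):
    # On a Colab TPU host this reports all 8 TPU cores.
    kind_devices = collect_devices_of_same_kind(local_devices)
    if len(local_devices) != len(kind_devices):
        raise RuntimeError(
            'Cannot replicate if number of devices ({}) is different from {}'
            .format(len(local_devices), len(kind_devices)))
    # ...otherwise builds and returns the replication device list.

In other words, each spawned process hands set_replication a single device while the host reports 8 TPU cores, so replication is refused.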

Code sample

import pytorch_lightning as pl

trainer = pl.Trainer(num_tpu_cores=8)
# load_from_checkpoint is a classmethod; checkpoint is the path to a saved .ckpt file
model = MyLightningModule.load_from_checkpoint(checkpoint)
model.prepare_data()  # See https://github.com/PyTorchLightning/pytorch-lightning/issues/1562
trainer.test(model)
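
A possible workaround while this is open (an untested sketch, using the same placeholder checkpoint as above): run the test pass on a single TPU core, so the 8-way replication path is never entered.

import pytorch_lightning as pl

model = MyLightningModule.load_from_checkpoint(checkpoint)
model.prepare_data()
# A single core avoids xmp.spawn with nprocs=8 and thus the replication check.
test_trainer = pl.Trainer(num_tpu_cores=1)
test_trainer.test(model)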

Environment

Colab TPU instance with XLA 1.5

* CUDA:
	- GPU:
	- available:         False
	- version:           None
* Packages:
	- numpy:             1.18.3
	- pyTorch_debug:     False
	- pyTorch_version:   1.5.0a0+ab660ae
	- pytorch-lightning: 0.7.5
	- tensorboard:       2.2.1
	- tqdm:              4.38.0
* System:
	- OS:                Linux
	- architecture:
		- 64bit
		- 
	- processor:         x86_64
	- python:            3.6.9
	- version:           #1 SMP Wed Feb 19 05:26:34 PST 2020

Possibly related: https://github.com/PyTorchLightning/pytorch-lightning/pull/1019

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Reactions: 3
  • Comments: 19 (7 by maintainers)

Most upvoted comments

I'm experiencing the same issue with the Lightning TPU example notebook (https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb), run on Colab.

Both of the single-TPU-core examples work, but when trying to run on 8 cores I get the error:

"RuntimeError: Cannot replicate if number of devices (1) is different from 8"

I am still facing this issue today on Kaggle:

training on 8 TPU cores
Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8

(The same exception and traceback are raised, interleaved, on each of devices TPU:0 through TPU:7.)

It should be fixed on master; feel free to reopen if needed 🐰
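
For anyone wanting to verify the fix before the next release, installing from master is the usual route (the exact commit that fixes this isn't named in the thread):

pip install git+https://github.com/PyTorchLightning/pytorch-lightning.git@master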

I confirm the problem still exists.

@satpalsr Could you share your reproducible script for the bug? I could take a look.

@lezwon I only ran trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8) and also tried restarting the kernel, but it did not work.
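
For reference, a minimal reproduction script around that Trainer line might look like the following; the TinyModule model and the random dataset are made up for illustration, and only the Trainer call comes from the comment above.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Random stand-in data; any dataloader reproduces the spawn path.
        data = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
        return DataLoader(data, batch_size=32)

trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)
trainer.fit(TinyModule())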

Having the same issue with a Kaggle kernel.