pytorch-lightning: [TPU-Colab] RuntimeError: Cannot replicate if number of devices (1) is different from 8
🐛 Bug
When I run `trainer.test(model)` on a pre-trained model using a Colab TPU instance, the following exception is thrown.
NB:
`trainer.fit(model)` works.
Stack trace
```
Traceback (most recent call last):
  File "run_pl_ged.py", line 217, in <module>
    trainer.test(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 958, in test
    self.fit(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 777, in fit
    xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 182, in spawn
    start_method=start_method)
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 116, in _start_fn
    _setup_replication()
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 109, in _setup_replication
    xm.set_replication(str(device), [str(device)])
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 194, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/usr/local/lib/python3.6/dist-packages/torch_xla/core/xla_model.py", line 181, in xla_replication_devices
    .format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8
```
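For context on where this comes from: `_setup_replication` hands `xm.set_replication` a list containing only the current process's device, but `xla_replication_devices` then sees 8 TPU cores of that kind in the runtime and fails the length check. Below is a minimal sketch of the multi-core pattern `torch_xla` expects (illustrative only, not Lightning's internals), where each spawned process owns one core and `torch_xla` sets up replication itself:

```python
# Minimal sketch (not Lightning's internal code) of the multi-core
# pattern torch_xla expects: one spawned process per TPU core.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each of the eight processes owns exactly one TPU core;
    # replication is configured by torch_xla inside each process.
    device = xm.xla_device()
    print(f"process {index} runs on {device}")

if __name__ == "__main__":
    # nprocs must be either 1 or the full core count (8); any other
    # split trips the length check in xla_replication_devices.
    xmp.spawn(_mp_fn, args=(), nprocs=8)
```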
Code sample
```python
import pytorch_lightning as pl

model = MyLightningModule()
trainer = pl.Trainer(num_tpu_cores=8)
model = model.load_from_checkpoint(checkpoint)
model.prepare_data()  # See https://github.com/PyTorchLightning/pytorch-lightning/issues/1562
trainer.test(model)
```
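A possible stopgap while the multi-core test path is broken: run the test on a single core, which never reaches the replication setup. This is only an unverified workaround sketch, not an official fix; `MyLightningModule` and `checkpoint` are the placeholders from the sample above.

```python
import pytorch_lightning as pl

# Workaround sketch (unverified): a single-core run avoids the
# cross-core replication setup entirely, at the cost of speed.
model = MyLightningModule.load_from_checkpoint(checkpoint)
trainer = pl.Trainer(num_tpu_cores=1)
trainer.test(model)
```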
Environment
Colab TPU instance with XLA 1.5
* CUDA:
- GPU:
- available: False
- version: None
* Packages:
- numpy: 1.18.3
- pyTorch_debug: False
- pyTorch_version: 1.5.0a0+ab660ae
- pytorch-lightning: 0.7.5
- tensorboard: 2.2.1
- tqdm: 4.38.0
* System:
- OS: Linux
- architecture: 64bit
- processor: x86_64
- python: 3.6.9
- version: #1 SMP Wed Feb 19 05:26:34 PST 2020
Possibly related: https://github.com/PyTorchLightning/pytorch-lightning/pull/1019
I'm experiencing the same issue with the [Lightning TPU example notebook](https://colab.research.google.com/github/PytorchLightning/pytorch-lightning/blob/master/notebooks/06-mnist-tpu-training.ipynb), run on Colab.
Both of the single-TPU-core examples work, but when trying to run on 8 cores I get the error:
`RuntimeError: Cannot replicate if number of devices (1) is different from 8`
I am still facing this issue today on Kaggle:
```
training on 8 TPU cores
Exception in device=TPU:0: Cannot replicate if number of devices (1) is different from 8
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 322, in _start_fn
    _setup_replication()
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 315, in _setup_replication
    xm.set_replication(device, [device])
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 317, in set_replication
    replication_devices = xla_replication_devices(devices)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/core/xla_model.py", line 287, in xla_replication_devices
    format(len(local_devices), len(kind_devices)))
RuntimeError: Cannot replicate if number of devices (1) is different from 8
```
The same exception is raised by every process (TPU:0 through TPU:7); the interleaved output has been condensed to one representative traceback.
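A quick sanity check when debugging this on Colab or Kaggle is to ask `torch_xla` how many devices the parent process actually sees before anything is spawned. A sketch using the public `xm` API:

```python
import torch_xla.core.xla_model as xm

# On a healthy 8-core TPU runtime this should list eight devices;
# a single entry here means the replication error is expected.
devices = xm.get_xla_supported_devices()
print(len(devices), devices)
```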
It should be fixed on master; feel free to reopen if needed.
I confirm the problem still exists.
@satpalsr Could you share your reproducible script for the bug? I could take a look.
@lezwon I ran only this:
`trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)`
and also tried restarting the kernel, but it did not work.

Having the same issue with a Kaggle kernel.