pytorch-lightning: Unable to train on v3-8 TPUs with Lightning; training is stuck/deadlocked
Bug description
The training code simply gets stuck on the TPU.
What version are you seeing the problem on?
master
How to reproduce the bug
I used the following calls to the trainer and fit:
pl.seed_everything(7, workers=True)
torch.set_float32_matmul_precision("high")

model = SiameseEncoder(model_name_or_path)
datamodule = RetrievalDataModule(
    model_name_or_path,
    {"train": str(train_path), "val": str(val_path)},
    {"train": train_batch_size, "val": val_batch_size},
    padding_style,
    workers=workers,
)

monitor = "rec@1"
# TODO Add clipping, control validation intervals etc. Lots of work to be done
trainer = pl.Trainer(
    logger=AimLogger(experiment="SiameseEncoder"),
    accelerator=accelerator,
    devices=devices,
    deterministic=True,
    max_epochs=max_epochs,
    val_check_interval=0.1,
    gradient_clip_val=1,
    precision="16-mixed",
    callbacks=[
        EarlyStopping(monitor=monitor, mode="max", patience=10),
        ModelCheckpoint(monitor=monitor, mode="max", save_top_k=1),
    ],
)
trainer.fit(model, datamodule=datamodule)
I also ran `export PJRT_DEVICE=TPU` before calling the trainer code from the CLI.
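For reference, a minimal shell sketch of the launch sequence (the `train.py` entry-point name is a placeholder, not taken from the issue); the variable must be exported in the same shell so that the worker processes inherit it:

```shell
# Select the PJRT runtime before starting the trainer; the variable is
# inherited by the processes spawned by Lightning/torch_xla.
export PJRT_DEVICE=TPU

# Sanity-check that the setting is visible to a child Python process.
python3 -c "import os; print(os.environ.get('PJRT_DEVICE'))"

# Then launch the training script (placeholder name):
# python3 train.py
```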
Error messages and logs
Global seed set to 7
Some weights of the model checkpoint at sentence-transformers/multi-qa-mpnet-base-cos-v1 were not used when initializing MPNetModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing MPNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MPNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:torch_xla:Letting libtpu.so load fail during _XLAC import. libtpu.so will be loaded from `libtpu` Python package when the ComputationClient is created.
INFO:torch_xla:Using bundled libtpu.so (/home/void/miniconda3/envs/siamenc/lib/python3.8/site-packages/torch_xla/lib/libtpu.so)
/home/void/miniconda3/envs/siamenc/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:487: UserWarning: You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16 is not supported on TPUs. Using `precision='bf16-mixed'` instead.
  rank_zero_warn(
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
WARNING:root:Unsupported nprocs (8), ignoring...
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning: 2.1.0.dev0
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- torch: 2.0.1
- torch-xla: 2.0
- torchmetrics: 0.11.4
- Packages:
- absl-py: 1.4.0
- aim: 3.17.5
- aim-ui: 3.17.5
- aimrecords: 0.0.7
- aimrocks: 0.4.0
- aiofiles: 23.1.0
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- alabaster: 0.7.13
- alembic: 1.11.1
- annotated-types: 0.5.0
- anyio: 3.7.1
- arger: 1.4.8
- arrow: 1.2.3
- async-timeout: 4.0.2
- attrs: 23.1.0
- babel: 2.12.1
- backoff: 2.2.1
- base58: 2.0.1
- beautifulsoup4: 4.12.2
- blessed: 1.20.0
- boto3: 1.28.4
- botocore: 1.31.4
- build: 0.10.0
- cachecontrol: 0.12.14
- cachetools: 5.3.1
- cattrs: 23.1.2
- certifi: 2023.5.7
- cffi: 1.15.1
- charset-normalizer: 3.2.0
- cleo: 2.0.1
- click: 8.1.5
- cloud-tpu-client: 0.10
- coverage: 7.2.7
- crashtest: 0.4.1
- croniter: 1.4.1
- cryptography: 41.0.2
- datasets: 2.13.1
- dateutils: 0.6.12
- deepdiff: 6.3.1
- dill: 0.3.6
- distlib: 0.3.7
- docstring-to-markdown: 0.12
- docutils: 0.20.1
- dparse: 0.6.3
- dulwich: 0.21.5
- exceptiongroup: 1.1.2
- execnet: 2.0.2
- fastapi: 0.100.0
- filelock: 3.12.2
- frozenlist: 1.4.0
- fsspec: 2023.6.0
- gmpy2: 2.1.2
- google-api-core: 1.16.0
- google-api-python-client: 1.8.0
- google-auth: 1.6.3
- google-auth-httplib2: 0.1.0
- googleapis-common-protos: 1.59.1
- greenlet: 2.0.2
- grpcio: 1.56.0
- h11: 0.14.0
- html5lib: 1.1
- httplib2: 0.22.0
- huggingface-hub: 0.16.4
- idna: 3.4
- imagesize: 1.4.1
- importlib-metadata: 6.8.0
- importlib-resources: 6.0.0
- iniconfig: 2.0.0
- inquirer: 3.1.3
- installer: 0.7.0
- itsdangerous: 2.1.2
- jaraco.classes: 3.3.0
- jedi: 0.18.2
- jeepney: 0.8.0
- jinja2: 3.1.2
- jmespath: 1.0.1
- joblib: 1.3.1
- jsonschema: 4.18.4
- jsonschema-specifications: 2023.7.1
- keyring: 23.13.1
- lightning: 2.1.0.dev0
- lightning-cloud: 0.5.37
- lightning-utilities: 0.9.0
- lockfile: 0.12.2
- lsprotocol: 2023.0.0a2
- m2r2: 0.3.3.post2
- mako: 1.2.4
- markdown-it-py: 3.0.0
- markupsafe: 2.1.3
- mdurl: 0.1.2
- mistune: 0.8.4
- monotonic: 1.6
- mpmath: 1.3.0
- msgpack: 1.0.5
- multidict: 6.0.4
- multiprocess: 0.70.14
- mypy: 1.4.1
- mypy-extensions: 1.0.0
- networkx: 3.1
- nltk: 3.8.1
- numpy: 1.24.4
- oauth2client: 4.1.3
- ordered-set: 4.1.0
- packaging: 23.1
- pandas: 2.0.3
- parso: 0.8.3
- pexpect: 4.8.0
- pillow: 10.0.0
- pip: 23.2
- pkginfo: 1.9.6
- pkgutil-resolve-name: 1.3.10
- platformdirs: 3.9.1
- pluggy: 1.2.0
- poetry-core: 1.6.1
- poetry-plugin-export: 1.4.0
- pprintpp: 0.4.0
- protobuf: 4.23.4
- psutil: 5.9.5
- ptyprocess: 0.7.0
- py3nvml: 0.2.7
- pyarrow: 12.0.1
- pyasn1: 0.5.0
- pyasn1-modules: 0.3.0
- pycparser: 2.21
- pydantic: 2.0.3
- pydantic-core: 2.3.0
- pygls: 1.0.2
- pygments: 2.15.1
- pyjwt: 2.8.0
- pyparsing: 3.1.0
- pyproject-hooks: 1.0.0
- pytest: 7.4.0
- pytest-clarity: 1.0.1
- pytest-cov: 4.1.0
- pytest-randomly: 3.13.0
- pytest-sugar: 0.9.7
- pytest-xdist: 3.3.1
- python-dateutil: 2.8.2
- python-editor: 1.0.4
- python-lsp-jsonrpc: 1.0.0
- python-lsp-server: 1.7.4
- python-multipart: 0.0.6
- pytz: 2023.3
- pyyaml: 6.0.1
- rapidfuzz: 2.15.1
- readchar: 4.0.5
- referencing: 0.30.0
- regex: 2023.6.3
- requests: 2.31.0
- requests-toolbelt: 1.0.0
- restrictedpython: 6.1
- rich: 13.4.2
- rpds-py: 0.9.2
- rsa: 4.9
- ruamel.yaml: 0.17.32
- ruamel.yaml.clib: 0.2.7
- ruff: 0.0.278
- ruff-lsp: 0.0.35
- s3transfer: 0.6.1
- safetensors: 0.3.1
- safety: 2.3.5
- secretstorage: 3.3.3
- segment-analytics-python: 2.2.3
- setuptools: 68.0.0
- shellingham: 1.5.0.post1
- siamenc: 2.0.0
- six: 1.16.0
- sniffio: 1.3.0
- snowballstemmer: 2.2.0
- soupsieve: 2.4.1
- sphinx: 7.0.1
- sphinx-autodoc-typehints: 1.23.3
- sphinxcontrib-applehelp: 1.0.4
- sphinxcontrib-devhelp: 1.0.2
- sphinxcontrib-htmlhelp: 2.0.1
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-qthelp: 1.0.3
- sphinxcontrib-serializinghtml: 1.1.5
- sqlalchemy: 1.4.49
- starlette: 0.27.0
- starsessions: 1.3.0
- sympy: 1.12
- termcolor: 2.3.0
- tokenizers: 0.13.3
- tomli: 2.0.1
- tomlkit: 0.11.8
- torch: 2.0.1
- torch-xla: 2.0
- torchmetrics: 0.11.4
- tqdm: 4.65.0
- traitlets: 5.9.0
- transformers: 4.30.2
- trove-classifiers: 2023.7.6
- typeguard: 3.0.2
- typing-extensions: 4.7.1
- tzdata: 2023.3
- ujson: 5.8.0
- uritemplate: 3.0.1
- urllib3: 1.26.16
- uvicorn: 0.23.1
- virtualenv: 20.24.0
- wcwidth: 0.2.6
- webencodings: 0.5.1
- websocket-client: 1.6.1
- websockets: 11.0.3
- wheel: 0.38.4
- xmltodict: 0.13.0
- xxhash: 3.2.0
- yarl: 1.9.2
- zipp: 3.16.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.17
- release: 5.13.0-1027-gcp
- version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022
More info
The trainer simply gets stuck after the message
WARNING:root:Unsupported nprocs (8), ignoring...
Pressing Ctrl-C prints messages like Process ForkProcess-4, Process ForkProcess-5, Process ForkProcess-2, Process ForkProcess-3, and the workers appear to be stuck somewhere around:
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
with self._rlock:
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
with self._rlock:
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
with self._rlock:
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
return self._semlock.__enter__()
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
return self._semlock.__enter__()
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
return self._semlock.__enter__()
Traceback (most recent call last):
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
call_item = call_queue.get(block=True)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 97, in get
res = self._recv_bytes()
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
buf = self._recv(4)
File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
KeyboardInterrupt
etc.
About this issue
- Original URL
- State: open
- Created a year ago
- Reactions: 2
- Comments: 16 (5 by maintainers)
@gkroiz @carmocca It seems like, in theory, v2 and v3 should be supported in the PJRT runtime.
However, on v2 and v3 PJRT runs each process multi-threaded, whereas on v4 each process runs single-threaded: https://pytorch.org/xla/master/#multithreading-on-tpu-v2v3
My guess is that this is the issue: the GIL plus some non-thread-safe code in Lightning causes problems when multiprocessing and multithreading are mixed.
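To illustrate the general hazard (a generic sketch, not Lightning internals): fork-based worker pools copy the parent's lock state, so a lock held by another thread at fork time can leave a child blocked forever, which is consistent with the hung `call_queue.get` stacks above. The usual mitigation is the "spawn" start method:

```python
# Generic sketch: "spawn" starts each worker from a fresh interpreter
# instead of fork()ing the (possibly multi-threaded) parent process,
# so no held locks are inherited by the workers.
import multiprocessing as mp


def square(x):
    return x * x


if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # fresh interpreter per worker
    with ctx.Pool(2) as pool:
        print(pool.map(square, [1, 2, 3]))  # [1, 4, 9]
```

Whether Lightning's internal pools can be switched this way is a separate question; the sketch only shows why fork plus threads is a known deadlock pattern.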
@vikigenius I also noticed that you are using the PJRT runtime, which, as far as I understand, is not fully supported by Lightning. I was able to run the training using
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
Also note that once you enable the PJRT runtime, the only way to disable it is to restart the node.
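The workaround above can be scripted as follows (a sketch; the port 51011 is the value given in the comment):

```shell
# Fall back to the XRT runtime: drop the PJRT selector (per the comment,
# only fully effective after a node restart) and point torch_xla at the
# XRT local service instead.
unset PJRT_DEVICE
export XRT_TPU_CONFIG="localservice;0;localhost:51011"

# Confirm the runtime selection seen by child processes.
python3 -c "import os; print(os.environ.get('XRT_TPU_CONFIG'))"
```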