pytorch-lightning: Unable to train on v3-8 TPUs with Lightning. Training is stuck/deadlocked?

Bug description

The training code simply gets stuck on the TPU.

What version are you seeing the problem on?

master

How to reproduce the bug

I used the following calls to Trainer and fit:

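    import torch
    import pytorch_lightning as pl
    from aim.pytorch_lightning import AimLogger
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    # SiameseEncoder and RetrievalDataModule come from the reporter's own project (siamenc);
    # model_name_or_path, train_path, val_path, the batch sizes, padding_style, workers,
    # accelerator, devices and max_epochs are defined elsewhere in that script.

    pl.seed_everything(7, workers=True)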
    torch.set_float32_matmul_precision("high")
    model = SiameseEncoder(model_name_or_path)
    datamodule = RetrievalDataModule(
        model_name_or_path,
        {"train": str(train_path), "val": str(val_path)},
        {"train": train_batch_size, "val": val_batch_size},
        padding_style,
        workers=workers,
    )
    monitor = "rec@1"

    # TODO Add clipping, control validation intervals etc. Lots of work to be done
    trainer = pl.Trainer(
        logger=AimLogger(experiment="SiameseEncoder"),
        accelerator=accelerator,
        devices=devices,
        deterministic=True,
        max_epochs=max_epochs,
        val_check_interval=0.1,
        gradient_clip_val=1,
        precision="16-mixed",
        callbacks=[
            EarlyStopping(monitor=monitor, mode="max", patience=10),
            ModelCheckpoint(monitor=monitor, mode="max", save_top_k=1),
        ],
    )
    trainer.fit(model, datamodule=datamodule)

I also ran export PJRT_DEVICE=TPU before launching the training script from the CLI.

Error messages and logs

Global seed set to 7
Some weights of the model checkpoint at sentence-transformers/multi-qa-mpnet-base-cos-v1 were not used when initializing MPNetModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing MPNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MPNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
INFO:torch_xla:Letting libtpu.so load fail during _XLAC import. libtpu.so will be loaded from `libtpu` Python package when the ComputationClient is created.
INFO:torch_xla:Using bundled libtpu.so (/home/void/miniconda3/envs/siamenc/lib/python3.8/site-packages/torch_xla/lib/libtpu.so)
/home/void/miniconda3/envs/siamenc/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:487: UserWarning: You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16 is not supported on TPUs. Using `precision='bf16-mixed'` instead.
  rank_zero_warn(
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
WARNING:root:Unsupported nprocs (8), ignoring...
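
Independently of the hang, the log shows Lightning replacing precision="16-mixed" with "bf16-mixed" on TPU. A minimal sketch for making that choice explicit per accelerator follows; the helper name precision_for is hypothetical, and the only rule it encodes is the one stated in the warning above.

```python
def precision_for(accelerator: str, requested: str = "16-mixed") -> str:
    """Return the precision string to pass to Trainer for an accelerator.

    Hypothetical helper: it encodes only what the warning in the log says,
    i.e. fp16 AMP ("16-mixed") is not supported on TPUs, so "bf16-mixed"
    is used there instead. Every other combination is passed through.
    """
    if accelerator == "tpu" and requested == "16-mixed":
        return "bf16-mixed"
    return requested


if __name__ == "__main__":
    for acc in ("tpu", "gpu"):
        print(acc, precision_for(acc))
```

In the script above this would be used as pl.Trainer(accelerator=accelerator, precision=precision_for(accelerator), ...), so the effective precision is visible in the configuration instead of only in a warning.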

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • lightning: 2.1.0.dev0
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • torch: 2.0.1
    • torch-xla: 2.0
    • torchmetrics: 0.11.4
  • Packages:
    • absl-py: 1.4.0
    • aim: 3.17.5
    • aim-ui: 3.17.5
    • aimrecords: 0.0.7
    • aimrocks: 0.4.0
    • aiofiles: 23.1.0
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • alabaster: 0.7.13
    • alembic: 1.11.1
    • annotated-types: 0.5.0
    • anyio: 3.7.1
    • arger: 1.4.8
    • arrow: 1.2.3
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • babel: 2.12.1
    • backoff: 2.2.1
    • base58: 2.0.1
    • beautifulsoup4: 4.12.2
    • blessed: 1.20.0
    • boto3: 1.28.4
    • botocore: 1.31.4
    • build: 0.10.0
    • cachecontrol: 0.12.14
    • cachetools: 5.3.1
    • cattrs: 23.1.2
    • certifi: 2023.5.7
    • cffi: 1.15.1
    • charset-normalizer: 3.2.0
    • cleo: 2.0.1
    • click: 8.1.5
    • cloud-tpu-client: 0.10
    • coverage: 7.2.7
    • crashtest: 0.4.1
    • croniter: 1.4.1
    • cryptography: 41.0.2
    • datasets: 2.13.1
    • dateutils: 0.6.12
    • deepdiff: 6.3.1
    • dill: 0.3.6
    • distlib: 0.3.7
    • docstring-to-markdown: 0.12
    • docutils: 0.20.1
    • dparse: 0.6.3
    • dulwich: 0.21.5
    • exceptiongroup: 1.1.2
    • execnet: 2.0.2
    • fastapi: 0.100.0
    • filelock: 3.12.2
    • frozenlist: 1.4.0
    • fsspec: 2023.6.0
    • gmpy2: 2.1.2
    • google-api-core: 1.16.0
    • google-api-python-client: 1.8.0
    • google-auth: 1.6.3
    • google-auth-httplib2: 0.1.0
    • googleapis-common-protos: 1.59.1
    • greenlet: 2.0.2
    • grpcio: 1.56.0
    • h11: 0.14.0
    • html5lib: 1.1
    • httplib2: 0.22.0
    • huggingface-hub: 0.16.4
    • idna: 3.4
    • imagesize: 1.4.1
    • importlib-metadata: 6.8.0
    • importlib-resources: 6.0.0
    • iniconfig: 2.0.0
    • inquirer: 3.1.3
    • installer: 0.7.0
    • itsdangerous: 2.1.2
    • jaraco.classes: 3.3.0
    • jedi: 0.18.2
    • jeepney: 0.8.0
    • jinja2: 3.1.2
    • jmespath: 1.0.1
    • joblib: 1.3.1
    • jsonschema: 4.18.4
    • jsonschema-specifications: 2023.7.1
    • keyring: 23.13.1
    • lightning: 2.1.0.dev0
    • lightning-cloud: 0.5.37
    • lightning-utilities: 0.9.0
    • lockfile: 0.12.2
    • lsprotocol: 2023.0.0a2
    • m2r2: 0.3.3.post2
    • mako: 1.2.4
    • markdown-it-py: 3.0.0
    • markupsafe: 2.1.3
    • mdurl: 0.1.2
    • mistune: 0.8.4
    • monotonic: 1.6
    • mpmath: 1.3.0
    • msgpack: 1.0.5
    • multidict: 6.0.4
    • multiprocess: 0.70.14
    • mypy: 1.4.1
    • mypy-extensions: 1.0.0
    • networkx: 3.1
    • nltk: 3.8.1
    • numpy: 1.24.4
    • oauth2client: 4.1.3
    • ordered-set: 4.1.0
    • packaging: 23.1
    • pandas: 2.0.3
    • parso: 0.8.3
    • pexpect: 4.8.0
    • pillow: 10.0.0
    • pip: 23.2
    • pkginfo: 1.9.6
    • pkgutil-resolve-name: 1.3.10
    • platformdirs: 3.9.1
    • pluggy: 1.2.0
    • poetry-core: 1.6.1
    • poetry-plugin-export: 1.4.0
    • pprintpp: 0.4.0
    • protobuf: 4.23.4
    • psutil: 5.9.5
    • ptyprocess: 0.7.0
    • py3nvml: 0.2.7
    • pyarrow: 12.0.1
    • pyasn1: 0.5.0
    • pyasn1-modules: 0.3.0
    • pycparser: 2.21
    • pydantic: 2.0.3
    • pydantic-core: 2.3.0
    • pygls: 1.0.2
    • pygments: 2.15.1
    • pyjwt: 2.8.0
    • pyparsing: 3.1.0
    • pyproject-hooks: 1.0.0
    • pytest: 7.4.0
    • pytest-clarity: 1.0.1
    • pytest-cov: 4.1.0
    • pytest-randomly: 3.13.0
    • pytest-sugar: 0.9.7
    • pytest-xdist: 3.3.1
    • python-dateutil: 2.8.2
    • python-editor: 1.0.4
    • python-lsp-jsonrpc: 1.0.0
    • python-lsp-server: 1.7.4
    • python-multipart: 0.0.6
    • pytz: 2023.3
    • pyyaml: 6.0.1
    • rapidfuzz: 2.15.1
    • readchar: 4.0.5
    • referencing: 0.30.0
    • regex: 2023.6.3
    • requests: 2.31.0
    • requests-toolbelt: 1.0.0
    • restrictedpython: 6.1
    • rich: 13.4.2
    • rpds-py: 0.9.2
    • rsa: 4.9
    • ruamel.yaml: 0.17.32
    • ruamel.yaml.clib: 0.2.7
    • ruff: 0.0.278
    • ruff-lsp: 0.0.35
    • s3transfer: 0.6.1
    • safetensors: 0.3.1
    • safety: 2.3.5
    • secretstorage: 3.3.3
    • segment-analytics-python: 2.2.3
    • setuptools: 68.0.0
    • shellingham: 1.5.0.post1
    • siamenc: 2.0.0
    • six: 1.16.0
    • sniffio: 1.3.0
    • snowballstemmer: 2.2.0
    • soupsieve: 2.4.1
    • sphinx: 7.0.1
    • sphinx-autodoc-typehints: 1.23.3
    • sphinxcontrib-applehelp: 1.0.4
    • sphinxcontrib-devhelp: 1.0.2
    • sphinxcontrib-htmlhelp: 2.0.1
    • sphinxcontrib-jsmath: 1.0.1
    • sphinxcontrib-qthelp: 1.0.3
    • sphinxcontrib-serializinghtml: 1.1.5
    • sqlalchemy: 1.4.49
    • starlette: 0.27.0
    • starsessions: 1.3.0
    • sympy: 1.12
    • termcolor: 2.3.0
    • tokenizers: 0.13.3
    • tomli: 2.0.1
    • tomlkit: 0.11.8
    • torch: 2.0.1
    • torch-xla: 2.0
    • torchmetrics: 0.11.4
    • tqdm: 4.65.0
    • traitlets: 5.9.0
    • transformers: 4.30.2
    • trove-classifiers: 2023.7.6
    • typeguard: 3.0.2
    • typing-extensions: 4.7.1
    • tzdata: 2023.3
    • ujson: 5.8.0
    • uritemplate: 3.0.1
    • urllib3: 1.26.16
    • uvicorn: 0.23.1
    • virtualenv: 20.24.0
    • wcwidth: 0.2.6
    • webencodings: 0.5.1
    • websocket-client: 1.6.1
    • websockets: 11.0.3
    • wheel: 0.38.4
    • xmltodict: 0.13.0
    • xxhash: 3.2.0
    • yarl: 1.9.2
    • zipp: 3.16.2
  • System:
    • OS: Linux
    • architecture:
      • 64bit
      • ELF
    • processor: x86_64
    • python: 3.8.17
    • release: 5.13.0-1027-gcp
    • version: #32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022

More info

The trainer simply gets stuck after the message

WARNING:root:Unsupported nprocs (8), ignoring...

Pressing Ctrl-C prints messages such as

Process ForkProcess-4, Process ForkProcess-5, Process ForkProcess-2, Process ForkProcess-3

followed by tracebacks suggesting that the processes are blocked somewhere here:

  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 96, in get
    with self._rlock:
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/synchronize.py", line 95, in __enter__
    return self._semlock.__enter__()
Traceback (most recent call last):
KeyboardInterrupt
KeyboardInterrupt
KeyboardInterrupt
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/queues.py", line 97, in get
    res = self._recv_bytes()
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 414, in _recv_bytes
    buf = self._recv(4)
  File "/home/void/miniconda3/envs/siamenc/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt

etc.
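
The tracebacks above come from concurrent.futures worker processes waiting in call_queue.get, so they may not show where the run is actually blocked. One way to inspect a hang without killing the run is Python's standard faulthandler module. The sketch below (POSIX only; the helper names are hypothetical) registers SIGUSR1 so that kill -USR1 <pid> prints the stack of every thread in that process. Signal handlers are inherited across fork, so forked workers created after registration should respond the same way.

```python
import faulthandler
import signal
import sys


def enable_hang_dumps(file=sys.stderr, timeout_s=None):
    """Make a hung run print the stack of every thread (POSIX only).

    After calling this, `kill -USR1 <pid>` writes all thread stacks of that
    process to `file` without stopping it. If timeout_s is given, the stacks
    are additionally dumped every timeout_s seconds.
    """
    faulthandler.register(signal.SIGUSR1, file=file, all_threads=True)
    if timeout_s is not None:
        faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)


def disable_hang_dumps():
    """Undo enable_hang_dumps()."""
    faulthandler.unregister(signal.SIGUSR1)
    faulthandler.cancel_dump_traceback_later()
```

Calling enable_hang_dumps() at the top of the training script, before trainer.fit(...), is enough; file must be a real file object with a file descriptor, such as sys.stderr or an opened log file.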

cc @carmocca @JackCaoG @steventk-g @Liyang90

About this issue

  • State: open
  • Created a year ago
  • Reactions: 2
  • Comments: 16 (5 by maintainers)

Most upvoted comments

@gkroiz @carmocca In theory, v2 and v3 should be supported by the PJRT runtime.

However, on v2 and v3, PJRT runs each process multithreaded, while on v4 each process runs single-threaded: https://pytorch.org/xla/master/#multithreading-on-tpu-v2v3
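
The process/thread arithmetic behind this can be sketched as a tiny model. It assumes, following the linked torch_xla documentation, that PJRT starts one process per TPU chip and, on v2/v3, one thread per exposed TPU device inside each process; the chip and device counts used for v3-8 and v4-8 are assumptions. Under this model a v3-8 gives 4 processes with 2 threads each, 8 devices in total, which would be consistent with the log reporting 8 TPU cores while nprocs=8 is ignored.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TpuHost:
    name: str
    chips: int
    devices_per_chip: int  # TPU devices exposed per chip (assumed values below)


def pjrt_layout(host: TpuHost) -> Tuple[int, int, int]:
    """Return (processes, threads_per_process, total_devices) under PJRT.

    Simplified model: one process per chip, one thread per device in it.
    """
    processes = host.chips
    threads_per_process = host.devices_per_chip
    return processes, threads_per_process, processes * threads_per_process


# Assumed topologies: v3 exposes 2 devices per chip, v4 exposes 1.
V3_8 = TpuHost("v3-8", chips=4, devices_per_chip=2)
V4_8 = TpuHost("v4-8", chips=4, devices_per_chip=1)

if __name__ == "__main__":
    for host in (V3_8, V4_8):
        print(host.name, pjrt_layout(host))
```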

My guess is that this is the cause: possibly the GIL, or some non-thread-safe code in Lightning that misbehaves when multiprocessing and multithreading are mixed.
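
The general hazard of mixing threads with fork can be illustrated independently of Lightning and torch_xla: if a process forks while another thread holds a lock, the child inherits the lock in its locked state but not the thread that would release it. The sketch below (POSIX only, all names hypothetical) shows this with a plain threading.Lock and uses a timeout so the demo returns instead of hanging. It is a generic illustration, not a reproduction of the code path in this issue.

```python
import multiprocessing as mp
import threading

# A plain lock shared by a background thread and a forked child.
lock = threading.Lock()


def hold_lock(acquired: threading.Event, release: threading.Event) -> None:
    # Background thread: take the lock and keep it until told to release it.
    with lock:
        acquired.set()
        release.wait()


def child(result) -> None:
    # The forked child gets a copy of `lock` in its locked state, but not the
    # thread that owns it, so nothing in the child can ever release it.
    result.put(lock.acquire(timeout=1.0))


def fork_while_locked() -> bool:
    """Return whether the forked child managed to acquire the copied lock."""
    ctx = mp.get_context("fork")  # POSIX only
    acquired, release = threading.Event(), threading.Event()
    holder = threading.Thread(target=hold_lock, args=(acquired, release))
    holder.start()
    acquired.wait()  # make sure the lock is held at the moment we fork

    result = ctx.Queue()
    proc = ctx.Process(target=child, args=(result,))
    proc.start()
    child_got_lock = result.get(timeout=10)
    proc.join()

    release.set()
    holder.join()
    return child_got_lock


if __name__ == "__main__":
    # Without the 1 s timeout in child(), the child would block forever.
    print("child acquired lock:", fork_while_locked())
```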

@vikigenius I also noticed that you are using the PJRT runtime, which, as far as I understand, is not fully supported by Lightning. I was able to run the training with the XRT runtime instead, using export XRT_TPU_CONFIG="localservice;0;localhost:51011".

Also, if you enable the PJRT runtime, the only way to disable it is to restart the node.
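
A minimal sketch of launching the training script under the XRT configuration from the comment above instead of PJRT. It builds a child environment without PJRT_DEVICE and with XRT_TPU_CONFIG set, then starts a fresh interpreter; the script name train.py and the helper names are hypothetical. Per the previous comment, a node that has already used PJRT may still need a restart. The shell equivalent is unset PJRT_DEVICE followed by export XRT_TPU_CONFIG="localservice;0;localhost:51011".

```python
import os
import subprocess
import sys

# Value taken from the comment above.
XRT_TPU_CONFIG = "localservice;0;localhost:51011"


def xrt_env(base=None) -> dict:
    """Copy of `base` (default: os.environ) that selects XRT instead of PJRT."""
    env = dict(os.environ if base is None else base)
    env.pop("PJRT_DEVICE", None)  # do not request the PJRT runtime
    env["XRT_TPU_CONFIG"] = XRT_TPU_CONFIG
    return env


def launch(script: str = "train.py", *args: str) -> int:
    """Run the (hypothetical) training script in a fresh interpreter."""
    completed = subprocess.run([sys.executable, script, *args], env=xrt_env())
    return completed.returncode
```

Calling launch("train.py") from a small wrapper script keeps the runtime selection out of the training code itself.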