pytorch-lightning: WandbLogger crashes when used with TPU VM
Bug description
On a TPU VM, using WandbLogger causes training to crash. I am using the nightly build, which I know comes with "no guarantees", so apologies in advance if this is already being worked on (I was not able to find any relevant issues or PRs). I am also unsure why this error occurs, and whether it is an issue with Lightning or with WandB.
What version are you seeing the problem on?
master
How to reproduce the bug
import lightning.pytorch as pl
import lightning.pytorch.loggers
import torch
import torch.backends.cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class LinearDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        w = torch.randn([128])
        eps = 0.01 * torch.randn([2400, 1])
        self.X = torch.randn([2400, 128])
        self.Y = torch.sum(w * self.X, dim=-1, keepdim=True) + eps

    def loader(self):
        return DataLoader(
            TensorDataset(self.X, self.Y),
            batch_size=100,
            num_workers=4,
            shuffle=True,
        )

    def train_dataloader(self):
        return self.loader()

    def val_dataloader(self):
        return self.loader()


class LinearRegression(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(128, 1)

    def step(self, batch, split):
        X, y = batch
        loss = F.mse_loss(self.proj(X), y)
        self.log(f"{split}/loss", loss, sync_dist=(split != "train"), prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, "val")

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def train():
    pl.seed_everything(100, workers=True)
    data = LinearDataModule()
    model = LinearRegression()
    trainer = pl.Trainer(
        accelerator="tpu",
        devices=8,
        enable_checkpointing=False,
        precision="bf16-mixed",
        logger=pl.loggers.WandbLogger(project="tpu_debug"),
        max_epochs=100,
        enable_progress_bar=True,
    )
    trainer.fit(model=model, datamodule=data)


if __name__ == "__main__":
    train()
The above code was saved to a file train.py and run with:
PJRT_DEVICE=TPU python3 -m train
Error messages and logs
Global seed set to 100
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: .... Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.15.5
wandb: Run data is saved locally in ./wandb/run-20230710_212512-3vghnzj8
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run fine-music-19
wandb: ⭐️ View project at https://wandb.ai/...
wandb: 🚀 View run at https://wandb.ai/.../runs/3vghnzj8
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
self.run()
File "/.../.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 112, in run
shandler(sreq)
File "/.../.local/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 152, in server_inform_attach
dict(self._mux._streams[stream_id]._settings),
KeyError: '3vghnzj8'
(the identical SockSrvRdThr exception and KeyError are printed seven more times)
wandb: ERROR Unable to attach to run 3vghnzj8
concurrent.futures.process._RemoteTraceback:
"""
Traceback (most recent call last):
File "/usr/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
return [fn(*args) for args in chunk]
File "/usr/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
return [fn(*args) for args in chunk]
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 245, in _run_thread_per_device
replica_results = list(
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 238, in _thread_fn
return fn()
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 341, in __call__
self.fn(global_ordinal(), *self.args, **self.kwargs)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 128, in _wrapping_function
trainer, function, args, kwargs = copy.deepcopy((trainer, function, args, kwargs))
File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.8/copy.py", line 210, in _deepcopy_tuple
y = [deepcopy(a, memo) for a in x]
File "/usr/lib/python3.8/copy.py", line 210, in <listcomp>
y = [deepcopy(a, memo) for a in x]
File "/usr/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/usr/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/usr/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/usr/lib/python3.8/copy.py", line 205, in _deepcopy_list
append(deepcopy(a, memo))
File "/usr/lib/python3.8/copy.py", line 161, in deepcopy
rv = reductor(4)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 356, in __getstate__
_ = self.experiment
File "/.../.local/lib/python3.8/site-packages/lightning/fabric/loggers/logger.py", line 114, in experiment
return fn(self)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/loggers/wandb.py", line 398, in experiment
self._experiment = wandb._attach(attach_id)
File "/.../.local/lib/python3.8/site-packages/wandb/sdk/wandb_init.py", line 877, in _attach
raise UsageError(f"Unable to attach to run {attach_id}")
wandb.errors.UsageError: Unable to attach to run 3vghnzj8
"""
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/...", line 76, in <module>
train()
File "/...", line 72, in train
trainer.fit(model=model, datamodule=data)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 536, in fit
call._call_and_handle_interrupt(
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 41, in _call_and_handle_interrupt
return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
File "/.../.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/xla.py", line 88, in launch
process_context = xmp.spawn(
File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 386, in spawn
return pjrt.spawn(fn, nprocs, start_method, args)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 365, in spawn
_run_multiprocess(spawn_fn, start_method=start_method)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 92, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 322, in _run_multiprocess
replica_results = list(
File "/usr/local/lib/python3.8/dist-packages/torch_xla/experimental/pjrt.py", line 323, in <genexpr>
itertools.chain.from_iterable(
File "/usr/lib/python3.8/concurrent/futures/process.py", line 484, in _chain_from_iterable_of_lists
for element in iterable:
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
wandb.errors.UsageError: Unable to attach to run 3vghnzj8
(the identical _RemoteTraceback ending in "wandb.errors.UsageError: Unable to attach to run 3vghnzj8" is printed a second time)
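For whatever it is worth, my (unverified) reading of the traceback above: the XLA launcher deep-copies the Trainer inside each spawned TPU process, deep-copying the WandbLogger invokes its __getstate__, __getstate__ touches self.experiment, and in the child process that property ends up calling wandb._attach(attach_id), which is the call that fails. Below is a rough sketch of that chain; the bodies are simplified placeholders for illustration only, not the actual Lightning or WandB source.

import copy

# Sketch of the call chain shown in the traceback (simplified placeholders,
# not the real implementations in lightning/pytorch/strategies/launchers/xla.py
# and lightning/pytorch/loggers/wandb.py).
class WandbLoggerSketch:
    def __init__(self, attach_id):
        self._attach_id = attach_id
        self._experiment = None

    @property
    def experiment(self):
        # wandb.py line 398: in a spawned worker the logger re-attaches to the
        # run created in the main process; on the TPU VM this raises
        # "Unable to attach to run ...".
        if self._experiment is None:
            import wandb
            self._experiment = wandb._attach(self._attach_id)
        return self._experiment

    def __getstate__(self):
        # wandb.py line 356: pickling/deep-copying forces the experiment to exist
        _ = self.experiment
        return self.__dict__


def _wrapping_function_sketch(trainer, function, args, kwargs):
    # xla.py line 128: the launcher deep-copies everything, including the
    # WandbLogger held by the trainer, inside each spawned TPU process.
    return copy.deepcopy((trainer, function, args, kwargs))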
Environment
- TPU type: v3-8
- TPU software version: tpu-vm-pt-2.0
- Packages: other than the packages that ship with the TPU VM image, I installed Lightning and WandB. The exact versions installed are listed below:
Package Version
------------------------ --------------------
absl-py 1.4.0
anyio 3.7.1
appdirs 1.4.4
arrow 1.2.3
attrs 19.3.0
Automat 0.8.0
backoff 2.2.1
beautifulsoup4 4.12.2
blessed 1.20.0
blinker 1.4
cachetools 5.3.0
certifi 2019.11.28
chardet 3.0.4
charset-normalizer 2.0.12
click 8.1.4
cloud-init 22.1
cloud-tpu-client 0.10
cmake 3.26.0
colorama 0.4.3
command-not-found 0.3
configobj 5.0.6
constantly 15.1.0
croniter 1.4.1
cryptography 2.8
Cython 0.29.14
dateutils 0.6.12
dbus-python 1.2.16
deepdiff 6.3.1
distlib 0.3.4
distro 1.4.0
distro-info 0.23ubuntu1
docker-pycreds 0.4.0
entrypoints 0.3
exceptiongroup 1.1.2
fastapi 0.100.0
filelock 3.7.1
fsspec 2023.6.0
gitdb 4.0.10
GitPython 3.1.32
google-api-core 1.34.0
google-api-python-client 1.8.0
google-auth 2.16.2
google-auth-httplib2 0.1.0
googleapis-common-protos 1.58.0
h11 0.14.0
httplib2 0.14.0
hyperlink 19.0.0
idna 2.8
importlib-metadata 1.5.0
incremental 16.10.1
inquirer 3.1.3
intel-openmp 2022.1.0
itsdangerous 2.1.2
Jinja2 2.10.1
jsonpatch 1.22
jsonpointer 2.0
jsonschema 3.2.0
keyring 18.0.1
language-selector 0.1
launchpadlib 1.10.13
lazr.restfulclient 0.14.2
lazr.uri 1.0.3
libtpu-nightly 0.1.dev20230213
lightning 2.1.0.dev0
lightning-cloud 0.5.37
lightning-utilities 0.9.0
lit 15.0.7
markdown-it-py 3.0.0
MarkupSafe 1.1.0
mdurl 0.1.2
mkl 2022.1.0
mkl-include 2022.1.0
more-itertools 4.2.0
mpmath 1.3.0
netifaces 0.10.4
networkx 3.0
numpy 1.24.2
nvidia-cublas-cu11 11.10.3.66
nvidia-cuda-cupti-cu11 11.7.101
nvidia-cuda-nvrtc-cu11 11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11 8.5.0.96
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.2.10.91
nvidia-cusolver-cu11 11.4.0.1
nvidia-cusparse-cu11 11.7.4.91
nvidia-nccl-cu11 2.14.3
nvidia-nvtx-cu11 11.7.91
oauth2client 4.1.3
oauthlib 3.1.0
ordered-set 4.1.0
packaging 20.3
pathtools 0.1.2
pexpect 4.6.0
Pillow 9.4.0
pip 20.0.2
platformdirs 2.5.2
protobuf 3.20.3
psutil 5.9.5
pyasn1 0.4.2
pyasn1-modules 0.2.1
pydantic 1.10.11
Pygments 2.15.1
PyGObject 3.36.0
PyHamcrest 1.9.0
PyJWT 1.7.1
pymacaroons 0.13.0
PyNaCl 1.3.0
pyOpenSSL 19.0.0
pyparsing 2.4.6
pyrsistent 0.15.5
pyserial 3.4
python-apt 2.0.0+ubuntu0.20.4.7
python-dateutil 2.8.2
python-debian 0.1.36ubuntu1
python-editor 1.0.4
python-multipart 0.0.6
pytorch-lightning 2.0.5
pytz 2023.3
PyYAML 5.4.1
readchar 4.0.5
requests 2.27.1
requests-unixsocket 0.2.0
rich 13.4.2
rsa 4.9
SecretStorage 2.3.1
sentry-sdk 1.28.0
service-identity 18.1.0
setproctitle 1.3.2
setuptools 62.3.2
simplejson 3.16.0
six 1.14.0
smmap 5.0.0
sniffio 1.3.0
sos 4.3
soupsieve 2.4.1
ssh-import-id 5.10
starlette 0.27.0
starsessions 1.3.0
sympy 1.11.1
systemd-python 234
tbb 2021.6.0
torch 2.0.0
torch-xla 2.0
torchmetrics 1.0.0
torchvision 0.15.1
tqdm 4.65.0
traitlets 5.9.0
triton 2.0.0
Twisted 18.9.0
typing-extensions 4.5.0
ubuntu-advantage-tools 27.8
ufw 0.36
unattended-upgrades 0.1
uritemplate 3.0.1
urllib3 1.26.16
uvicorn 0.22.0
virtualenv 20.14.1
wadllib 1.3.3
wandb 0.15.5
wcwidth 0.2.6
websocket-client 1.6.1
websockets 11.0.3
wheel 0.34.2
zipp 1.0.0
zope.interface 4.7.1
More info
If I instead train without a logger, no error occurs and the script proceeds normally.
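For reference, one way to run the same script without a logger is to pass logger=False to the Trainer; a minimal sketch of the only change relative to train.py:

import lightning.pytorch as pl

# Control variant of the Trainer from train.py: identical settings except that
# logging is disabled entirely via logger=False. With no WandbLogger involved,
# training on the TPU VM proceeds normally.
trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,
    enable_checkpointing=False,
    precision="bf16-mixed",
    logger=False,  # instead of pl.loggers.WandbLogger(project="tpu_debug")
    max_epochs=100,
    enable_progress_bar=True,
)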
cc @carmocca @JackCaoG @steventk-g @Liyang90 @awaelchli @morganmcg1 @borisdayma @scottire @parambharat