ray: [Core][Tune] Ray tune cannot be used with pytorch-lightning 1.7.0 due to processes spawned with fork.
What happened + What you expected to happen
As part of "Add support for DDP fork", included in pytorch-lightning 1.7.0, calls to:
torch.cuda.device_count()
torch.cuda.is_available()
in the pytorch lightning codebase were replaced with new functions:
pytorch_lightning.utilities.device_parser.num_cuda_devices()
pytorch_lightning.utilities.device_parser.is_cuda_available()
These functions internally create a multiprocessing.Pool with fork:
with multiprocessing.get_context("fork").Pool(1) as pool:
return pool.apply(torch.cuda.device_count)
This call waits forever when run inside an Actor.
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/threading.py", line 890, in _bootstrap
(train pid=139, ip=172.22.0.3) self._bootstrap_inner()
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/threading.py", line 932, in _bootstrap_inner
(train pid=139, ip=172.22.0.3) self.run()
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/ray/tune/function_runner.py", line 277, in run
(train pid=139, ip=172.22.0.3) self._entrypoint()
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/ray/tune/function_runner.py", line 349, in entrypoint
(train pid=139, ip=172.22.0.3) return self._trainable_func(
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/ray/util/tracing/tracing_helper.py", line 462, in _resume_span
(train pid=139, ip=172.22.0.3) return method(self, *_args, **_kwargs)
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/ray/tune/function_runner.py", line 645, in _trainable_func
(train pid=139, ip=172.22.0.3) output = fn()
(train pid=139, ip=172.22.0.3) File "test.py", line 9, in train
(train pid=139, ip=172.22.0.3) pl.Trainer(
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/argparse.py", line 345, in insert_env_defaults
(train pid=139, ip=172.22.0.3) return fn(self, **kwargs)
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 537, in __init__
(train pid=139, ip=172.22.0.3) self._setup_on_init(num_sanity_val_steps)
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 618, in _setup_on_init
(train pid=139, ip=172.22.0.3) self._log_device_info()
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/trainer.py", line 1739, in _log_device_info
(train pid=139, ip=172.22.0.3) if CUDAAccelerator.is_available():
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/accelerators/cuda.py", line 91, in is_available
(train pid=139, ip=172.22.0.3) return device_parser.num_cuda_devices() > 0
(train pid=139, ip=172.22.0.3) File "/usr/local/lib/python3.8/dist-packages/pytorch_lightning/utilities/device_parser.py", line 346, in num_cuda_devices
(train pid=139, ip=172.22.0.3) return pool.apply(torch.cuda.device_count)
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/pool.py", line 736, in __exit__
(train pid=139, ip=172.22.0.3) self.terminate()
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/pool.py", line 654, in terminate
(train pid=139, ip=172.22.0.3) self._terminate()
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/util.py", line 224, in __call__
(train pid=139, ip=172.22.0.3) res = self._callback(*self._args, **self._kwargs)
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/pool.py", line 729, in _terminate_pool
(train pid=139, ip=172.22.0.3) p.join()
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/process.py", line 149, in join
(train pid=139, ip=172.22.0.3) res = self._popen.wait(timeout)
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 47, in wait
(train pid=139, ip=172.22.0.3) return self.poll(os.WNOHANG if timeout == 0.0 else 0)
(train pid=139, ip=172.22.0.3) File "/usr/lib/python3.8/multiprocessing/popen_fork.py", line 27, in poll
(train pid=139, ip=172.22.0.3) pid, sts = os.waitpid(self.pid, flag)
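For illustration, here is a minimal sketch (not part of the original report) that calls the fork-based helper directly inside a Ray actor, without going through pl.Trainer at all; based on the behaviour described above it should hang in the same way. The names used are the actual Ray and pytorch-lightning 1.7.0 APIs from the traceback.

# Sketch: reproduce the hang by calling the fork-based device check inside a Ray actor.
import ray
from pytorch_lightning.utilities import device_parser


@ray.remote(num_gpus=1)
class DeviceProbe:
    def count_gpus(self) -> int:
        # Internally runs torch.cuda.device_count() in a forked worker process,
        # which (per the traceback above) never returns inside a Ray worker.
        return device_parser.num_cuda_devices()


if __name__ == "__main__":
    ray.init()
    probe = DeviceProbe.remote()
    print(ray.get(probe.count_gpus.remote()))  # hangs here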
This is a critical breaking change: pytorch_lightning.Trainer calls these functions during construction, so it cannot be used at all inside a Ray Tune trial.
The reproduction script below always hangs. However, during my experimentation I found that creating a minimal reproduction script was difficult: sometimes a script works and then fails when re-run, and sometimes changing a seemingly unrelated line of code makes a working script fail. I haven't dug deep enough into the Ray codebase to understand why this is the case.
For my larger projects, Ray Tune simply cannot be used with pytorch-lightning 1.7.0, as these calls always hang. My current workaround is to monkeypatch torch.multiprocessing.get_all_start_methods:
import torch.multiprocessing

patched_start_methods = [m for m in torch.multiprocessing.get_all_start_methods() if m != "fork"]
torch.multiprocessing.get_all_start_methods = lambda: patched_start_methods
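As far as I can tell, the patch only helps if it runs inside the Ray worker process before pl.Trainer is constructed, since pytorch-lightning 1.7.0 falls back to calling torch.cuda.device_count() directly when "fork" is not among the reported start methods. A sketch of how I apply it inside the trial function (same train signature as the reproduction script below):

import torch.multiprocessing
import pytorch_lightning as pl


def _disable_fork_start_method() -> None:
    # Pretend "fork" is unavailable so pytorch-lightning skips the fork-based
    # Pool and queries torch.cuda.device_count() directly.
    patched = [m for m in torch.multiprocessing.get_all_start_methods() if m != "fork"]
    torch.multiprocessing.get_all_start_methods = lambda: patched


def train(config):
    _disable_fork_start_method()  # must run before constructing the Trainer
    pl.Trainer(accelerator="gpu", devices=1)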
As far as I can tell, it is a known limitation that Ray does not work with forked processes (https://discuss.ray.io/t/best-solution-to-have-multiprocess-working-in-actor/2165/8). However, given that pytorch-lightning is such a widely used library in the ML ecosystem, this issue may be worth looking into.
Versions / Dependencies
- ray-tune 1.13.0
- pytorch 1.12.0
- pytorch-lightning 1.7.0
- python 3.8.10
- OS: Ubuntu 20.04.4 LTS
Reproduction script
import pytorch_lightning as pl
from ray import tune


def train(config):
    pl.Trainer(accelerator="gpu", devices=1)


def run():
    tune.run(
        train,
        resources_per_trial={"cpu": 8, "gpu": 1},
        log_to_file=["stdout.txt", "stderr.txt"],  # For some reason removing this line makes the script work
        config={},
        num_samples=1,
        name="Test",
    )


if __name__ == "__main__":
    run()
Submitted to a ray cluster with
ray job submit --runtime-env-json='{"working_dir": "./"}' -- python test.py
Issue Severity
Medium: It is a significant difficulty but I can work around it.
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Reactions: 1
- Comments: 15 (9 by maintainers)
Nice work, great fix @krfricke @amogkam !
Btw, in the meantime we have worked with PyTorch to remove this hack on the Lightning side. First we proposed some changes on the PyTorch side (https://github.com/pytorch/pytorch/issues/83973); after they landed, we ported the changes back to Lightning (https://github.com/Lightning-AI/lightning/pull/14631). Finally, with PyTorch >= 1.14 (and some future Lightning version), this hack will no longer be necessary (https://github.com/Lightning-AI/lightning/pull/15110), and then eventually Ray can drop this workaround too! ❤️
Another workaround is:
runtime_env={"env_vars": {"PL_DISABLE_FORK": "1"}}

Update:
I hope this will unblock you soon. Thank you.
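A minimal sketch of one way to pass that runtime_env from the driver (assumption: ray.init() is called in the driver script before tune.run(); the same env_vars dict can equivalently go into the --runtime-env-json passed to ray job submit):

import pytorch_lightning as pl
import ray
from ray import tune


def train(config):
    pl.Trainer(accelerator="gpu", devices=1)


# Propagate PL_DISABLE_FORK to every Ray worker so pytorch-lightning skips the
# fork-based device check.
ray.init(runtime_env={"env_vars": {"PL_DISABLE_FORK": "1"}})

tune.run(train, resources_per_trial={"cpu": 8, "gpu": 1}, num_samples=1, name="Test")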
@xwjiang2010, support for pytorch-lightning 1.7 in ray_lightning is work in progress (https://github.com/ray-project/ray_lightning/issues/194)
cc @JiahaoYao