stable-baselines3: [Bug]: Cannot use the `fused` flag in the default optimizer of PPO
The default Adam optimizer has a `fused` flag which, according to the PyTorch docs, is significantly faster than the default implementation when used on CUDA. Using it with PPO raises an exception complaining that the parameters are not CUDA tensors.

The `fused` parameter can be passed to PPO via `policy_kwargs = dict(optimizer_kwargs={'fused': True})`.
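For reference, a minimal way to hit the error (a sketch; the CUDA device and the `CartPole-v1` environment are assumptions for illustration, not part of the original report):

```python
from stable_baselines3 import PPO

# Ask SB3 to forward fused=True to the default torch.optim.Adam optimizer.
model = PPO(
    "MlpPolicy",
    "CartPole-v1",
    device="cuda",
    policy_kwargs=dict(optimizer_kwargs={"fused": True}),
)
```

With the current code this fails inside `PPO.__init__` with the `RuntimeError` shown in the traceback below.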
The issue is in the following lines of code:
https://github.com/DLR-RM/stable-baselines3/blob/c8fda060d4bcb283eea6ddd385be5a46a54d3356/stable_baselines3/common/on_policy_algorithm.py#L133-L136
Before line 133 above, the correct device has already been stored in `self.device`. But the `policy_class` is instantiated without it in line 133, so it initialises on the `cpu` device, which in turn creates the optimizer from `cpu` parameters. In line 136 the policy is moved to the correct device, but by then it is too late: the optimizer has already been constructed while the parameters were still on `cpu`.
This only surfaces with the `fused` flag, because the Adam optimiser checks it and then verifies that `self.parameters()` are floating point tensors on a supported device; in my case it complains that they are not on `cuda`.
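The same ordering problem can be reproduced with plain PyTorch (a minimal sketch of the underlying behaviour, not SB3 code; the `nn.Linear` module merely stands in for the policy network):

```python
import torch
import torch.nn as nn

# Mirrors the SB3 order of operations: the module starts on CPU, the
# optimizer is created from those CPU parameters, and only afterwards
# is the module moved to CUDA.
policy = nn.Linear(4, 2)  # parameters live on "cpu" at this point

# Raises: RuntimeError: `fused=True` requires all the params to be
# floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].
optimizer = torch.optim.Adam(policy.parameters(), fused=True)

policy = policy.to("cuda")  # too late: the optimizer was already built
```

Creating the optimizer after `policy.to("cuda")` works, which is exactly the reordering this issue is asking for.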
If the `policy_class` in line 133 above was passed the correct device (i.e. `self.device`) at initialisation in the first place, it could set it correctly before `MlpExtractor` gets initialised (a possible stopgap workaround is sketched after the traceback below). `MlpExtractor` is initialised with the parent class's device in the lines below:
https://github.com/DLR-RM/stable-baselines3/blob/c8fda060d4bcb283eea6ddd385be5a46a54d3356/stable_baselines3/common/policies.py#L568-L581
Here is the traceback I get:
```text
Traceback (most recent call last):
  File "x.py", line 350, in <module>
    main(sys.argv)
  File "x.py", line 254, in main
    model = PPO(
            ^^^^
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 164, in __init__
    self._setup_model()
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 167, in _setup_model
    super()._setup_model()
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 123, in _setup_model
    self.policy = self.policy_class( # type: ignore[assignment]
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 857, in __init__
    super().__init__(
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 507, in __init__
    self._build(lr_schedule)
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 610, in _build
    self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/z/venv/lib64/python3.11/site-packages/torch/optim/adam.py", line 60, in __init__
    raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
RuntimeError: `fused=True` requires all the params to be floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].
```
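Until this ordering is fixed, one hypothetical user-level workaround (an untested sketch; it assumes `model.policy.optimizer`, `model.policy.optimizer_kwargs` and `model.lr_schedule` behave as in current SB3) is to rebuild the optimizer once the policy is already on the GPU:

```python
import torch
from stable_baselines3 import PPO

# Build the model *without* fused=True so construction succeeds.
model = PPO("MlpPolicy", "CartPole-v1", device="cuda")

# Recreate the optimizer now that the policy parameters are on CUDA,
# this time with the fused implementation enabled.
model.policy.optimizer = torch.optim.Adam(
    model.policy.parameters(),
    lr=model.lr_schedule(1),
    **{**model.policy.optimizer_kwargs, "fused": True},
)
```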
About this issue
- Original URL
- State: closed
- Created 7 months ago
- Comments: 15 (6 by maintainers)
Actually my PR is generic to anything that inherits from `OnPolicyAlgorithm`. It doesn't actually set the `fused` flag, it just makes it possible to set it with the existing SB3 API. That is why I have submitted it as a generic bug fix. I will update this issue and the PR with benchmark results when I am able to run them.
Thanks for the suggestion. I will consider sbx in the future, but for my current project I’ll have to stick with SB3.