stable-baselines3: [Bug] Cannot use the `fused` flag in PPO's default optimizer

PyTorch's Adam optimizer (the SB3 default) has a `fused` flag which, according to the PyTorch docs, is significantly faster than the default implementation on CUDA. Using it with PPO raises an exception complaining that the parameters are not on a CUDA device.

The `fused` flag can be passed to PPO via `policy_kwargs = dict(optimizer_kwargs={'fused': True})`. The issue lies in the following lines of code: https://github.com/DLR-RM/stable-baselines3/blob/c8fda060d4bcb283eea6ddd385be5a46a54d3356/stable_baselines3/common/on_policy_algorithm.py#L133-L136
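For reference, a minimal sketch of how the flag is requested through the existing SB3 API (the PPO call itself is commented out because it needs a CUDA-enabled PyTorch build, and `env` stands in for any environment):

```python
# Requesting the fused Adam kernel via SB3's policy_kwargs mechanism.
policy_kwargs = dict(optimizer_kwargs={"fused": True})

# With a CUDA build of PyTorch, this is the call that triggers the
# RuntimeError described below:
# from stable_baselines3 import PPO
# model = PPO("MlpPolicy", env, device="cuda", policy_kwargs=policy_kwargs)
```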

Before line 133 above, the correct device has already been stored in self.device. But the policy_class is initialized without it on line 133, so the policy (and therefore its optimizer) is created on the CPU device. On line 136 the policy is moved to the correct device, but by then it is too late: the optimizer has already been constructed believing the parameters live on the CPU.
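The ordering problem can be sketched with pure-Python stand-ins (these classes are illustrative models, not SB3 or PyTorch code): the optimizer captures the parameters' device at construction time, so moving the policy afterwards cannot fix it.

```python
class FusedAdam:
    """Stand-in for torch.optim.Adam(fused=True): validates device at init."""
    def __init__(self, param_devices):
        if any(d != "cuda" for d in param_devices):
            raise RuntimeError("`fused=True` requires all params on CUDA")


class Policy:
    """Stand-in for an SB3 policy that builds its optimizer in __init__."""
    def __init__(self):
        self.device = "cpu"                        # no device passed in
        self.optimizer = FusedAdam([self.device])  # raises: params still on CPU

    def to(self, device):
        self.device = device                       # too late: optimizer exists
        return self


# Mirrors on_policy_algorithm.py lines 133-136: construct, then move.
try:
    Policy().to("cuda")
except RuntimeError as exc:
    print(exc)
```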

This only surfaces with the `fused` flag because, when it is set, the Adam optimizer validates `self.parameters()` at construction time and raises an error if they are not on a supported device; in my case it complains that they are not on CUDA.

If the policy_class on line 133 above were passed the correct device (i.e. self.device) at initialization in the first place, it could move to that device before MlpExtractor is built. MlpExtractor is initialized on the parent class's device in the lines below: https://github.com/DLR-RM/stable-baselines3/blob/c8fda060d4bcb283eea6ddd385be5a46a54d3356/stable_baselines3/common/policies.py#L568-L581
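Continuing the same illustrative stand-in model (again, not SB3 code), the proposed fix is simply to hand the device to the policy up front, so the parameters are already where the optimizer expects them when it is constructed:

```python
class FusedAdam:
    """Stand-in for torch.optim.Adam(fused=True): validates device at init."""
    def __init__(self, param_devices):
        if any(d != "cuda" for d in param_devices):
            raise RuntimeError("`fused=True` requires all params on CUDA")


class Policy:
    """Stand-in policy that accepts the device at construction, as proposed."""
    def __init__(self, device="cpu"):
        self.device = device                       # correct device from the start
        self.optimizer = FusedAdam([self.device])  # params already on CUDA: OK


policy = Policy(device="cuda")  # no RuntimeError; fused optimizer builds fine
print(policy.device)
```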

Here is the traceback I get:

Traceback (most recent call last):
  File "x.py", line 350, in <module>
    main(sys.argv)
  File "x.py", line 254, in main
    model = PPO(
            ^^^^
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 164, in __init__
    self._setup_model()
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/ppo/ppo.py", line 167, in _setup_model
    super()._setup_model()
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 123, in _setup_model
    self.policy = self.policy_class(  # type: ignore[assignment]
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 857, in __init__
    super().__init__(
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 507, in __init__
    self._build(lr_schedule)
  File "/z/venv/lib64/python3.11/site-packages/stable_baselines3/common/policies.py", line 610, in _build
    self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/z/venv/lib64/python3.11/site-packages/torch/optim/adam.py", line 60, in __init__
    raise RuntimeError("`fused=True` requires all the params to be floating point Tensors of "
RuntimeError: `fused=True` requires all the params to be floating point Tensors of supported devices: ['cuda', 'xpu', 'privateuseone'].

About this issue

  • State: closed
  • Created 7 months ago
  • Comments: 15 (6 by maintainers)

Most upvoted comments

The benchmark is still needed to know whether this is something that should be done for the other algorithms, or whether it should only be mentioned in the docs (with a link to your fork).

Actually, my PR is generic to anything that inherits from OnPolicyAlgorithm. It doesn't set the fused flag itself; it just makes it possible to set the flag through the existing SB3 API. That is why I submitted it as a generic bug fix.

I will update this issue and the PR with benchmark results when I am able to run them.

Btw, as I mentioned before, if you want a real performance boost, you can have a look at https://github.com/araffin/sbx (usually faster even when run on CPU only, as long as you are not using a CNN; see https://arxiv.org/abs/2310.05808).

Thanks for the suggestion. I will consider sbx in the future, but for my current project I’ll have to stick with SB3.