stable-baselines3: Current ONNX opset doesn't support StableBaselines3 natively, requires creating a wrapper class.

I am interested in using stable-baselines3 to train an agent and then exporting it through ONNX to run it inside the Unity engine via Barracuda. I was hoping to write up documentation for this too!

Unfortunately, neither ONNX opset 9 nor opset 12 seems to support converting trained policies:

RuntimeError: Exporting the operator broadcast_tensors to ONNX opset version 9 is not supported.
Please open a bug to request ONNX export support for the missing operator.

While broadcast_tensors isn’t explicitly called anywhere in the codebase, it might be related to the use of torch.distributions. Unfortunately, this appears to have been an open issue since November 2019, so I am pessimistic about it being solved soon.
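
For reference, a minimal sketch of the kind of export call that triggers this (the environment, observation size, and file name are just placeholders):

import torch
from stable_baselines3 import PPO

# Tracing the full ActorCriticPolicy goes through the action distribution,
# which is likely where broadcast_tensors comes from.
model = PPO("MlpPolicy", "Pendulum-v1").learn(1000)
dummy_input = torch.randn(1, 3)  # Pendulum-v1 observations have 3 dimensions
torch.onnx.export(model.policy, dummy_input, "policy.onnx", opset_version=9)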

Even if it’s very unlikely, do you think there might be a way around this? Either way, I wanted to raise this issue so the team is aware.

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)

About this issue

  • Original URL
  • State: closed
  • Created 3 years ago
  • Comments: 19 (8 by maintainers)

Most upvoted comments

The “after the fact” instantiation option. I’m still not sure where exactly the broadcast is, but I guess it’s not in these 3 modules :p

I made an example using a MultiInputPolicy with SAC.

Thanks =), but it seems that you are passing a randomly initialized features extractor instead of the trained one: model.policy.make_features_extractor() instead of model.policy.actor.features_extractor.
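
For reference, a rough sketch of a wrapper along those lines that reuses the trained extractor (the module names features_extractor, latent_pi, and mu come from SB3's SAC actor; exporting dict observations may need extra care):

import torch


class SACMultiInputOnnxablePolicy(torch.nn.Module):
    # Hypothetical wrapper: reuse the *trained* features extractor from the
    # actor (model.policy.actor.features_extractor) rather than creating a
    # fresh one with model.policy.make_features_extractor().
    def __init__(self, actor):
        super().__init__()
        self.features_extractor = actor.features_extractor
        self.latent_pi = actor.latent_pi
        self.mu = actor.mu

    def forward(self, observation):
        features = self.features_extractor(observation)
        return self.mu(self.latent_pi(features))

onnxable_model = SACMultiInputOnnxablePolicy(model.policy.actor)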

@batu @Egiob FYI we would be happy to take PRs on adding comments on exporting models to ONNX, as this seems to be a common thing people do 😃

Yes, I believe the initial .flatten() causes the problem. The following works for me:

import torch


class SACOnnxablePolicy(torch.nn.Module):
    def __init__(self, actor):
        super().__init__()
        # Skip the initial flatten layer, which cannot be exported to ONNX
        self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu)

    def forward(self, observation):
        return self.actor(observation)

new_model = SACOnnxablePolicy(loaded_model.policy.actor)
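
A sketch of the actual export call for that wrapper (file, input, and output names are arbitrary; the dummy observation assumes a flat Box observation space):

import torch

dummy_input = torch.randn(1, *loaded_model.observation_space.shape)
torch.onnx.export(
    new_model,
    dummy_input,
    "sac_policy.onnx",
    opset_version=9,
    input_names=["input"],
    output_names=["action"],
)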

Hello, I’ll re-open this issue as I think it would be a useful addition to the docs (https://stable-baselines3.readthedocs.io/en/master/guide/export.html#export-to-onnx).

I would also be interested in knowing which part exactly breaks the export (probably by deleting variables until it works?). The current solution only works for MlpPolicy when no preprocessing is needed (so only with a Box observation space).

For a more complete solution, you should preprocess the input, see https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L110

Another thing not mentioned is that we do a lot of magic in the predict() method, notably re-organizing image channels automatically and returning the correct shape (depending on the input shape): https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L238
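
As an illustration only, a rough sketch of doing that preprocessing on the inference side for image observations (onnxruntime, the file name, and the input name are assumptions here; this only approximates what predict() does):

import numpy as np
import onnxruntime as ort


def preprocess_image_obs(obs: np.ndarray) -> np.ndarray:
    # Roughly mirror SB3's image handling: normalize uint8 pixels and
    # re-order channels from HWC (as returned by the env) to CHW.
    obs = obs.astype(np.float32) / 255.0
    obs = np.transpose(obs, (2, 0, 1))
    return obs[None, ...]  # add a batch dimension


# Dummy 84x84 RGB observation, just to show the call
dummy_obs = np.random.randint(0, 256, size=(84, 84, 3), dtype=np.uint8)
session = ort.InferenceSession("policy.onnx")
action = session.run(None, {"input": preprocess_image_obs(dummy_obs)})[0]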

I’m not sure which parts of the policy implementation make it unsupported by the opset (I think you’re right about torch.distributions), but the mlp_extractor, action_net, and value_net modules of the ActorCriticPolicy are all “onnxable”, as they don’t include the broadcast_tensors operator. So one can do the following:

import torch


class OnnxablePolicy(torch.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        # mlp_extractor returns separate latents for the policy and value heads
        action_hidden, value_hidden = self.extractor(observation)
        return self.action_net(action_hidden), self.value_net(value_hidden)

Then one can use the OnnxablePolicy to do the actual ONNX export, as sketched below. A similar thing can be done for the ActorCriticCnnPolicy, keeping in mind that its features_extractor outputs a single hidden tensor which is consumed by both the value net and the action net. It’s a little bit of surgery, but given this workaround, perhaps one of the maintainers has a better approach.
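
For completeness, a rough sketch of that export step (environment, observation size, and file name are placeholders):

import torch
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "Pendulum-v1")
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net
)
dummy_input = torch.randn(1, 3)  # Pendulum-v1 observations are 3-dimensional
torch.onnx.export(onnxable_model, dummy_input, "ppo_policy.onnx", opset_version=9)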