stable-baselines3: Current ONNX opset doesn't support StableBaselines3 natively, requires creating a wrapper class.
I am interested in using stable-baselines to train an agent and then export it through ONNX to put it inside the Unity engine via Barracuda. I was hoping to write up the documentation too!
Unfortunately, neither opset 9 nor opset 12 in ONNX seems to support converting trained policies:
```
RuntimeError: Exporting the operator broadcast_tensors to ONNX opset version 9 is not supported.
Please open a bug to request ONNX export support for the missing operator.
```
While `broadcast_tensors` isn’t explicitly called anywhere in the codebase, it might be related to the use of torch.distributions. Unfortunately, this seems to be an open issue since November 2019, so I am pessimistic about it being solved soon.
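As a rough illustration of how `broadcast_tensors` can sneak in without being called directly (a hypothetical minimal sketch, not code from the issue): constructing a `torch.distributions` distribution broadcasts its parameters against each other, which uses `torch.broadcast_tensors` under the hood, so tracing a module like this for ONNX export can hit the unsupported operator.

```python
import torch

class DistModule(torch.nn.Module):
    """Hypothetical repro: a forward pass that builds a torch.distributions
    object, whose parameter broadcasting uses torch.broadcast_tensors."""

    def forward(self, mean: torch.Tensor) -> torch.Tensor:
        # Normal(...) broadcasts mean and scale against each other internally.
        dist = torch.distributions.Normal(mean, torch.ones_like(mean))
        return dist.rsample()
```

Attempting `torch.onnx.export` on such a module is the kind of thing that can raise the `broadcast_tensors` error quoted above, depending on the torch and opset versions.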
While it is very unlikely, do you think there might be a way around this? Either way, I wanted to raise this issue so the team is aware.
Checklist
- I have read the documentation (required)
- I have checked that there is no similar issue in the repo (required)
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Comments: 19 (8 by maintainers)
The “after the fact” instantiate option. I’m still not sure where exactly the broadcast is, but I guess it’s not in these 3 modules :p
Thanks =), but it seems that you are passing a randomly initialized features extractor instead of the trained one: `model.policy.make_features_extractor()` instead of `model.policy.actor.features_extractor`.
@batu @Egiob FYI we would be happy to take PRs on adding comments on exporting models to ONNX, as this seems to be a common thing people do 😃
Yes, I believe the initial `.flatten()` causes the problem. The following works for me.
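The original snippet was not preserved in this extract. A sketch of the kind of workaround meant here (an assumption, not the commenter’s actual code): wrap the trained actor’s sub-modules (here assumed to be `latent_pi` and `mu`, as in SB3’s SAC actor) and perform the flatten explicitly with `torch.flatten`, which exports cleanly, instead of going through the policy’s flatten-based extractor.

```python
import torch

class OnnxableActor(torch.nn.Module):
    """Hypothetical wrapper: skips the policy's initial .flatten() step and
    flattens the observation explicitly before the trained sub-modules."""

    def __init__(self, latent_pi: torch.nn.Module, mu: torch.nn.Module):
        super().__init__()
        self.latent_pi = latent_pi  # trained hidden layers, e.g. actor.latent_pi
        self.mu = mu                # trained output head, e.g. actor.mu

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # Explicit, ONNX-friendly flatten (keeps the batch dimension).
        flat = torch.flatten(observation, start_dim=1)
        return self.mu(self.latent_pi(flat))
```

One would then pass this wrapper, with the trained sub-modules from `model.policy.actor`, to `torch.onnx.export` together with a dummy observation tensor.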
Hello, I’ll re-open this issue as I think it would be a useful addition to the doc (https://stable-baselines3.readthedocs.io/en/master/guide/export.html#export-to-onnx).
I would also be interested in knowing which part exactly breaks the export (probably try deleting variables until it works?). The current solution only works for `MlpPolicy` when no preprocessing is needed (so it works only with a `Box` observation space). For a more complete solution, you should preprocess the input, see https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L110
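As a rough illustration of the preprocessing in question (a simplified stand-in, not SB3’s actual `preprocess_obs`): image observations arrive as `uint8` and need to be cast to float and scaled to [0, 1] before being fed to the network, while plain `Box` observations only need a float cast.

```python
import torch

def preprocess(observation: torch.Tensor, is_image: bool) -> torch.Tensor:
    """Simplified sketch of observation preprocessing.

    is_image is an assumed flag standing in for the observation-space check
    that SB3 performs internally.
    """
    if is_image:
        # uint8 pixel values -> float in [0, 1]
        return observation.float() / 255.0
    return observation.float()
```

In an ONNX-exportable wrapper, this step would have to be baked into the wrapper’s `forward()`, since the exported graph cannot rely on the Python-side logic in `predict()`.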
Another thing not mentioned is that we do a lot of magic in the `predict()` method, notably re-organizing image channels automatically and returning the correct shape (depending on the input shape): https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/policies.py#L238

I’m not sure which parts of the policy implementation make it unsupported by the opset (I think you’re right about torch.distributions), but the `mlp_extractor`, `action_net`, and `value_net` modules of the `ActorCriticPolicy` are all “onnxable”, as they don’t include the `broadcast_tensors` operator. So one can wrap those three modules in a small `OnnxablePolicy` module and use it to do the actual ONNX export. A similar thing can be done for the `ActorCriticCnnPolicy`, keeping in mind that its `features_extractor` only outputs a single hidden tensor, which is consumed by both the value and action nets. It’s a little bit of surgery, but given this workaround, perhaps one of the maintainers has a better approach.
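A sketch of such an `OnnxablePolicy` wrapper (the exact code from the thread is not preserved here; this version assumes the SB3 convention that `mlp_extractor` returns separate latents for the actor and the critic):

```python
import torch

class OnnxablePolicy(torch.nn.Module):
    """Sketch: wraps the three "onnxable" sub-modules of an ActorCriticPolicy
    so the ONNX export never touches broadcast_tensors."""

    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor      # e.g. model.policy.mlp_extractor
        self.action_net = action_net    # e.g. model.policy.action_net
        self.value_net = value_net      # e.g. model.policy.value_net

    def forward(self, observation: torch.Tensor):
        # mlp_extractor returns (actor_latent, critic_latent)
        action_hidden, value_hidden = self.extractor(observation)
        return self.action_net(action_hidden), self.value_net(value_hidden)
```

One would then call something like `torch.onnx.export(OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net), dummy_obs, "model.onnx")` with a dummy observation of the right shape; note this exports the raw action logits/means, not the sampled, clipped actions that `predict()` would return.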