transformers: Error while converting distilbart-mnli-12-1 model to ONNX
After converting distilbart-mnli-12-1 to ONNX, I get this error while testing the ONNX model:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running Expand node. Name:'Expand_74' Status Message: invalid expand shape
After a lot of investigation, I found that the problem is in the shift_tokens_right function in modeling_bart.py.
I changed the function to this:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = torch.full((input_ids.shape[0],), decoder_start_token_id)
    assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids
This solved the problem completely. The issue seems to be in the ONNX converter, which does not behave correctly when broadcasting is involved.
Would it be possible to merge this change into the repository?
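For anyone who wants to check that the edit does not change the result in eager PyTorch, here is a minimal sketch. It compares a variant that assigns a plain scalar to the first column (my assumption about what the code looked like before the edit) with the explicit-tensor variant from this issue. The function names, the `dtype` argument, and the token ids are made up for illustration; this does not test the ONNX export itself.

```python
import torch


def shift_tokens_right_scalar(input_ids, pad_token_id, decoder_start_token_id):
    """Assumed pre-edit form: a Python scalar is broadcast over the first column."""
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id  # scalar broadcast over the batch
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


def shift_tokens_right_explicit(input_ids, pad_token_id, decoder_start_token_id):
    """Form from this issue: the first column is filled from an explicit tensor."""
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    # dtype is passed explicitly here (an addition for this sketch) so the
    # fill tensor always matches input_ids.
    start_column = torch.full(
        (input_ids.shape[0],), decoder_start_token_id, dtype=input_ids.dtype
    )
    shifted[:, 0] = start_column
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


# Made-up token ids; -100 marks ignored label positions.
labels = torch.tensor([[5, 6, -100, 7], [8, -100, 9, 10]])
a = shift_tokens_right_scalar(labels, pad_token_id=1, decoder_start_token_id=2)
b = shift_tokens_right_explicit(labels, pad_token_id=1, decoder_start_token_id=2)

print(b.tolist())  # [[2, 5, 6, 1], [2, 8, 1, 9]]
assert torch.equal(a, b)
```

Both variants give the same tensor in eager mode, so any difference should only show up in the exported graph.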
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 18 (9 by maintainers)
Hi @lewtun Thanks a lot. For sure.
Hi @farzanehnakhaee70, @lewtun. Great catch @farzanehnakhaee70!! I would say that if you have a working solution, you can definitely open a PR!