optimum: Validating ONNX model fails for GPT-J
System Info
Optimum: 1.5.1
Python: 3.10.4
Platform: Windows 10
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
I installed Optimum with pip install optimum[onnxruntime-gpu]. Then I ran python -m optimum.exporters.onnx --task causal-lm-with-past --model EleutherAI/gpt-j-6B gptj_onnx/ to export GPT-J to ONNX. The output of this call is as follows:
Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJModel: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Framework not specified. Using pt to export to ONNX.
Using framework PyTorch: 1.12.1
Overriding 2 configuration item(s)
- use_cache -> True
- pad_token_id -> 0
C:\Users\myUsername\Anaconda3\envs\huggingface\lib\site-packages\transformers\models\gptj\modeling_gptj.py:597: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if batch_size <= 0:
C:\Users\myUsername\Anaconda3\envs\huggingface\lib\site-packages\transformers\models\gptj\modeling_gptj.py:177: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode
(the warning above is emitted 28 times in total; repetitions omitted)
Validating ONNX model...
-[✓] ONNX model output names match reference model (present.22.value, present.15.key, present.15.value, present.25.value, present.9.value, present.26.value, present.8.value, present.13.key, present.27.key, present.6.value, present.7.value, present.12.value, present.24.key, present.1.value, present.4.key, logits, present.10.key, present.9.key, present.16.key, present.0.key, present.19.key, present.21.key, present.4.value, present.23.value, present.3.key, present.17.key, present.6.key, present.21.value, present.22.key, present.18.key, present.11.key, present.10.value, present.14.value, present.0.value, present.13.value, present.14.key, present.5.value, present.2.value, present.16.value, present.24.value, present.25.key, present.27.value, present.8.key, present.7.key, present.19.value, present.20.key, present.26.key, present.18.value, present.23.key, present.11.value, present.2.key, present.5.key, present.3.value, present.1.key, present.20.value, present.17.value, present.12.key)
- Validating ONNX Model output "logits":
-[✓] (2, 16, 50400) matches (2, 16, 50400)
-[x] values not close enough, max diff: 3.2901763916015625e-05 (atol: 1e-05)
- Validating ONNX Model output "present.0.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.0.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.1.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.1.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.2.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.2.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.3.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.3.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.4.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.4.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.5.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.5.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.6.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.6.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.7.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 2.6702880859375e-05 (atol: 1e-05)
- Validating ONNX Model output "present.7.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.8.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.8.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.9.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 2.09808349609375e-05 (atol: 1e-05)
- Validating ONNX Model output "present.9.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 1.3589859008789062e-05 (atol: 1e-05)
- Validating ONNX Model output "present.10.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.10.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 1.2636184692382812e-05 (atol: 1e-05)
- Validating ONNX Model output "present.11.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 2.574920654296875e-05 (atol: 1e-05)
- Validating ONNX Model output "present.11.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.12.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.12.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.13.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 1.71661376953125e-05 (atol: 1e-05)
- Validating ONNX Model output "present.13.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.14.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.14.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.15.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.15.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.16.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.16.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.17.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 1.71661376953125e-05 (atol: 1e-05)
- Validating ONNX Model output "present.17.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.18.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.18.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.19.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.19.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.20.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.20.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.21.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.21.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.22.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.22.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.23.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.23.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.24.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.24.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.25.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.25.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[x] values not close enough, max diff: 1.1682510375976562e-05 (atol: 1e-05)
- Validating ONNX Model output "present.26.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.26.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.27.key":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
- Validating ONNX Model output "present.27.value":
-[✓] (2, 16, 32, 256) matches (2, 16, 32, 256)
-[✓] all values close (atol: 1e-05)
An error occured, but the model was saved at: gptj_onnx/model.onnx
Expected behavior
Validation of the ONNX model should succeed.
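For context, the failing outputs above differ by at most ~3.3e-5, only slightly above the exporter's default atol of 1e-5; mismatches of this size are usually floating-point noise rather than a broken export. A minimal check of the reported maxima against a looser tolerance (1e-4 is an assumed threshold here, not an official one):

```python
# Compare the worst reported differences (copied from the validation log) against the
# exporter's default atol of 1e-5 and an assumed looser tolerance of 1e-4.
reported_max_diffs = {
    "logits": 3.2901763916015625e-05,
    "present.7.key": 2.6702880859375e-05,
    "present.9.key": 2.09808349609375e-05,
    "present.9.value": 1.3589859008789062e-05,
    "present.25.value": 1.1682510375976562e-05,
}
for name, diff in reported_max_diffs.items():
    print(f"{name}: passes atol=1e-5? {diff <= 1e-5}  passes atol=1e-4? {diff <= 1e-4}")
```

If this Optimum version's export CLI exposes an --atol override (current releases document one), re-running the export with a looser value should make validation succeed.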
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 40 (9 by maintainers)
The strategy is to take decoder_model.onnx and decoder_with_past_model.onnx and merge them into a single ONNX model with an If node that dispatches to the without-past or with-past branch depending on a flag passed as input, while the weights are shared between both branches. In the first pass, dummy past key values must be passed (they will simply not be used).
The resulting decoder_model_merged.onnx (which will be the default in the ONNX export) reflects this structure.
To be honest, this is a bit of a hack, and there should be a cleaner solution than this: https://discuss.huggingface.co/t/how-does-the-onnx-exporter-work-for-generationmodel-with-past-key-value/31316/8?u=fxmarty
Oh, it's a trick to generate a fake cache. We only need to generate (with randn) a cache of keys and values for each layer for a past text of length 1. Then we can safely use the attention mask, masking this fake cache with zeros. As far as I have checked, this method really works and does not affect the output of the model when forwarding real tokens.
I do not consider this a correct solution to the problem; I just found out through experiments that it works (tested only on GPT-J).
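A rough sketch of the fake-cache trick described above. The dimensions come from the validation log (28 layers, 16 heads, head size 256, vocab 50400 for GPT-J 6B); the past_key_values.{i}.key/value input names follow the naming used for the present.* outputs and are an assumption here.

```python
import numpy as np

num_layers, num_heads, head_dim = 28, 16, 256
batch_size, fake_past_len = 1, 1

# One random (key, value) pair per layer for a fake "past" of length 1.
onnx_inputs = {}
for i in range(num_layers):
    shape = (batch_size, num_heads, fake_past_len, head_dim)
    onnx_inputs[f"past_key_values.{i}.key"] = np.random.randn(*shape).astype(np.float32)
    onnx_inputs[f"past_key_values.{i}.value"] = np.random.randn(*shape).astype(np.float32)

# The attention mask is 0 over the fake past position and 1 over the real tokens,
# so the random cache is never attended to and the real output is unaffected.
real_len = 8  # arbitrary prompt length for the example
onnx_inputs["attention_mask"] = np.concatenate(
    [np.zeros((batch_size, fake_past_len)), np.ones((batch_size, real_len))], axis=1
).astype(np.int64)
onnx_inputs["input_ids"] = np.random.randint(
    0, 50400, size=(batch_size, real_len)
).astype(np.int64)
# These inputs could then be fed to an onnxruntime.InferenceSession over model.onnx.
```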
Hey @Eichhof, I added support for past key/values in the decoder on my own. My version does not require two models to be loaded into memory, which I think is a terrible idea, since many decoders are very large. I also got rid of many bugs in the implementation of the ORTModelForCausalLM class, which I found while trying to use it.
You can check out my version here; unfortunately, I'm not going to open a PR: https://github.com/hivaze/optimum/blob/main/optimum/onnxruntime/modeling_decoder.py
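For reference, a baseline usage of the stock ORTModelForCausalLM class that the comment refers to might look roughly like this (Optimum ~1.5 API; the from_transformers flag may be named differently in other releases, and this does not include the fixes from the fork above):

```python
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_id = "EleutherAI/gpt-j-6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# from_transformers=True exports the checkpoint to ONNX on the fly.
model = ORTModelForCausalLM.from_pretrained(model_id, from_transformers=True)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```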