iree: [CPU] Accuracy error on many models with large batch size (ResNet, GPT2, BertLarge)
What happened?
With both the default CPU flags and the data-tiling + ukernel path, we see an accuracy error on GPT2 at batch sizes 48 and 64 (possibly higher as well, but we have not tried). Accuracy is correct at batch sizes 1, 2, 8, 16, and 32.
The same behavior is seen with the VMVX backend.
Steps to reproduce your issue
Download and extract https://storage.googleapis.com/iree-shared-files/jax_models.tar.xz.
- Compile and run with batch size 1.

```shell
MODEL_DIR=/tmp/jax_models/GPT2LMHEAD_FP32_JAX_512XI32_BATCH1
iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=stablehlo --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ${MODEL_DIR}/stablehlo.mlir -o ${MODEL_DIR}/module.vmfb
iree-run-module --function=main --input=@${MODEL_DIR}/inputs_npy/input_0.npy --input=@${MODEL_DIR}/inputs_npy/input_1.npy --module=${MODEL_DIR}/module.vmfb --output=@${MODEL_DIR}/out.npy
python3 /tmp/jax_models/compare_npy.py -a ${MODEL_DIR}/out.npy -b ${MODEL_DIR}/outputs_npy/output_0.npy
```
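The contents of compare_npy.py are not shown in this issue; below is a minimal sketch of what such a comparison script plausibly does, based only on the printed output above. The tolerance values and function names here are assumptions, not the real script:

```python
import numpy as np

def all_close(a: np.ndarray, b: np.ndarray,
              rtol: float = 1e-4, atol: float = 1e-4) -> bool:
    """True when every element of a matches b within the given tolerances.

    The tolerances are assumed; the real compare_npy.py may differ.
    """
    return bool(np.allclose(a, b, rtol=rtol, atol=atol))

def report(a: np.ndarray, b: np.ndarray) -> None:
    """Print a summary shaped like the 'all_close: ... shape: ...' lines above."""
    print(f"all_close: {all_close(a, b)}. shape: {a.shape}")
    # Sample a few elements, matching the a[...]/b[...] lines in the issue.
    for idx in [(0, 0, 0), (0, 1, 0), (0, 0, 1)]:
        print(f"a{list(idx)}: {a[idx]}, b{list(idx)}: {b[idx]}")

# The real script is invoked as:
#   python3 compare_npy.py -a out.npy -b outputs_npy/output_0.npy
# where each argument would be loaded with np.load() before calling report().
```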
You should see output like the following, indicating that accuracy is correct:

```
EXEC @main
all_close: True. shape: (1, 512, 50257)
a[0, 0, 0]: -33.17157745361328, b[0, 0, 0]: -33.17165756225586
a[0, 1, 0]: -84.1168212890625, b[0, 1, 0]: -84.11683654785156
a[0, 0, 1]: -32.55836486816406, b[0, 0, 1]: -32.55845642089844
```
- Compile and run with batch size 48.

```shell
MODEL_DIR=/tmp/jax_models/GPT2LMHEAD_FP32_JAX_512XI32_BATCH48
iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=stablehlo --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ${MODEL_DIR}/stablehlo.mlir -o ${MODEL_DIR}/module.vmfb
iree-run-module --function=main --input=@${MODEL_DIR}/inputs_npy/input_0.npy --input=@${MODEL_DIR}/inputs_npy/input_1.npy --module=${MODEL_DIR}/module.vmfb --output=@${MODEL_DIR}/out.npy
python3 /tmp/jax_models/compare_npy.py -a ${MODEL_DIR}/out.npy -b ${MODEL_DIR}/outputs_npy/output_0.npy
```
The output should look like the following, indicating an accuracy error:

```
all_close: False. shape: (48, 512, 50257)
a[0, 0, 0]: -4.763853549957275, b[0, 0, 0]: -33.17165756225586
a[0, 1, 0]: -22.7552547454834, b[0, 1, 0]: -84.11683654785156
a[0, 0, 1]: 15.67091178894043, b[0, 0, 1]: -32.55845642089844
[0, 0, 0]: -4.763853549957275 != -33.17165756225586
[0, 0, 1]: 15.67091178894043 != -32.55845642089844
[0, 0, 2]: 4.204073429107666 != -35.47325134277344
[0, 0, 3]: -19.93921661376953 != -35.13169479370117
```
The same behavior is seen when using data tiling + ukernel compiler flags.
What component(s) does this issue relate to?
Compiler
Version information
iree-compiler release 20230804.603
Additional context
No response
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 40 (25 by maintainers)
Commits related to this issue
- Fixing implicit casting that caused 4GB fill/copy limits in local-task. Fixes #14601. — committed to iree-org/iree by benvanik 5 months ago
- Fixing implicit casting that caused 4GB fill/copy limits in local-task. (#16364) I'll be reworking some other things in the future that debugging this highlighted (poor distribution, the fact that w... — committed to iree-org/iree by benvanik 5 months ago
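The fix above points at an implicit 32-bit truncation of fill/copy lengths. That is consistent with the observed batch-size threshold: the (batch, 512, 50257) f32 output tensor crosses 4 GiB somewhere between batch 32 and batch 48. A quick sanity check (the arithmetic below is mine; only the tensor shape and element type come from the logs above):

```python
# Output tensor is (batch, 512, 50257) float32, per the shapes printed above.
BYTES_PER_BATCH = 512 * 50257 * 4  # f32 logits per batch element

def output_bytes(batch: int) -> int:
    """Total byte size of the GPT2 output tensor for a given batch size."""
    return batch * BYTES_PER_BATCH

def truncated_to_u32(n: int) -> int:
    """Models an implicit cast of a 64-bit length down to 32 bits."""
    return n & 0xFFFFFFFF

for batch in (1, 32, 48, 64):
    total = output_bytes(batch)
    lossy = truncated_to_u32(total) != total
    print(f"batch {batch:2d}: {total:>13,} bytes, truncation lossy: {lossy}")
```

Batch 32 stays just under 2^32 bytes (about 3.29 GB), while batch 48 is about 4.94 GB, so any length implicitly cast to 32 bits wraps and corrupts the fill/copy, matching the pass/fail split reported in the issue.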
Ooooh the blamed PR #14016 is an anagram of this Issue #14601 !!! It was right under our nose all along!
@benvanik
As mentioned above, the original link in the Issue description is broken in the way you describe. Use instead the gcloud command Hanhan provided in https://github.com/openxla/iree/issues/14601#issuecomment-1927525221
To reproduce the issue:
1. Download artifacts.
2. Compile and run the model.
3. Check the output; note that it takes some time to dump the actual_value and expect_value. The compare_npy.py file can be found at the gist.