iree: [CPU] Accuracy error on many models with large batch size (ResNet, GPT2, BertLarge)

What happened?

For both the CPU default flags and the data-tiling+ukernel path, we see an accuracy error on GPT2 for batch sizes 48 and 64 (possibly higher, but we haven't tried). Accuracy is correct for batch sizes 1, 2, 8, 16, and 32.

We also see the same behavior with the VMVX backend.

Steps to reproduce your issue

Download and extract https://storage.googleapis.com/iree-shared-files/jax_models.tar.xz.

  1. Compile and run with batch size 1.
MODEL_DIR=/tmp/jax_models/GPT2LMHEAD_FP32_JAX_512XI32_BATCH1
iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=stablehlo --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ${MODEL_DIR}/stablehlo.mlir -o ${MODEL_DIR}/module.vmfb
iree-run-module --function=main --input=@${MODEL_DIR}/inputs_npy/input_0.npy --input=@${MODEL_DIR}/inputs_npy/input_1.npy --module=${MODEL_DIR}/module.vmfb --output=@${MODEL_DIR}/out.npy
python3 /tmp/jax_models/compare_npy.py -a ${MODEL_DIR}/out.npy -b ${MODEL_DIR}/outputs_npy/output_0.npy

You should see output like the following, indicating the accuracy is correct:

EXEC @main
all_close: True. shape: (1, 512, 50257)
a[0, 0, 0]: -33.17157745361328, b[0, 0, 0]: -33.17165756225586
a[0, 1, 0]: -84.1168212890625, b[0, 1, 0]: -84.11683654785156
a[0, 0, 1]: -32.55836486816406, b[0, 0, 1]: -32.55845642089844
  2. Compile and run with batch size 48.
MODEL_DIR=/tmp/jax_models/GPT2LMHEAD_FP32_JAX_512XI32_BATCH48
iree-compile --iree-hal-target-backends=llvm-cpu --iree-input-type=stablehlo --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ${MODEL_DIR}/stablehlo.mlir -o ${MODEL_DIR}/module.vmfb
iree-run-module --function=main --input=@${MODEL_DIR}/inputs_npy/input_0.npy --input=@${MODEL_DIR}/inputs_npy/input_1.npy --module=${MODEL_DIR}/module.vmfb --output=@${MODEL_DIR}/out.npy
python3 /tmp/jax_models/compare_npy.py -a ${MODEL_DIR}/out.npy -b ${MODEL_DIR}/outputs_npy/output_0.npy

The output should look like the following, indicating an accuracy error:

all_close: False. shape: (48, 512, 50257)

a[0, 0, 0]: -4.763853549957275, b[0, 0, 0]: -33.17165756225586
a[0, 1, 0]: -22.7552547454834, b[0, 1, 0]: -84.11683654785156
a[0, 0, 1]: 15.67091178894043, b[0, 0, 1]: -32.55845642089844

[0, 0, 0]: -4.763853549957275 != -33.17165756225586
[0, 0, 1]: 15.67091178894043 != -32.55845642089844
[0, 0, 2]: 4.204073429107666 != -35.47325134277344
[0, 0, 3]: -19.93921661376953 != -35.13169479370117

The same behavior is seen when using data tiling + ukernel compiler flags.
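For context, the check compare_npy.py performs is essentially an elementwise allclose between the IREE output and the JAX reference. A minimal sketch of such a comparison (the tolerances and exact output format of the real script may differ):

```python
# Hedged sketch of an npy comparison along the lines of compare_npy.py.
# The real script's tolerances and output format may differ.
import numpy as np

def compare_npy(a_path, b_path, rtol=1e-4, atol=1e-4):
    """Load two .npy files and report whether they are elementwise close."""
    a = np.load(a_path)
    b = np.load(b_path)
    ok = bool(np.allclose(a, b, rtol=rtol, atol=atol))
    print(f"all_close: {ok}. shape: {a.shape}")
    if not ok:
        # Print the first few mismatching elements, as in the logs above.
        mismatches = np.argwhere(~np.isclose(a, b, rtol=rtol, atol=atol))
        for idx in mismatches[:4]:
            i = tuple(idx)
            print(f"{list(i)}: {a[i]} != {b[i]}")
    return ok
```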

What component(s) does this issue relate to?

Compiler

Version information

iree-compiler release 20230804.603

Additional context

No response

About this issue

  • State: closed
  • Created a year ago
  • Comments: 40 (25 by maintainers)

Most upvoted comments

Ooooh the blamed PR #14016 is an anagram of this Issue #14601 !!! It was right under our nose all along!

@benvanik

Is the model download still valid? I get a 13 MB file that has a truncated tar in it, not whatever 1 GB thing Benoit mentions.

As mentioned above, the original link in the issue description is broken in the way you describe. Instead, use the gcloud command Hanhan provided in https://github.com/openxla/iree/issues/14601#issuecomment-1927525221

To reproduce the issue:

Download artifacts:

gcloud storage cp -r gs://iree-model-artifacts/jax/jax_models_0.4.23_1706594181/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/ ~/
cd ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64
tar -axvf inputs_npy.tgz
tar -axvf outputs_npy.tgz

Compile and run the model:

iree-compile --output-format=vm-bytecode --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-target-cpu=cascadelake --iree-llvmcpu-target-triple=x86_64-unknown-linux-gnu ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64/stablehlo.mlirbc -o /tmp/a.vmfb --iree-llvmcpu-enable-ukernels=all

export MODEL_DIR=~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64
build/tools/iree-run-module --function=main --input=@${MODEL_DIR}/input_0.npy --input=@${MODEL_DIR}/input_1.npy --module=/tmp/a.vmfb --output=@${MODEL_DIR}/out.npy

Check the output; note that it takes some time to dump the actual_value and expect_value. The compare_npy.py script can be found at the gist.

cd ~/GPT2LMHEAD_FP32_JAX_512XI32_BATCH64
python compare_npy.py -a out.npy -b output_0.npy
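Since the error appears only at larger batch sizes, it can also help to localize which batch elements diverge from the reference. A small helper along these lines (not part of the repro artifacts; names and tolerances are illustrative):

```python
# Illustrative helper: report allclose per batch element, to see whether
# the divergence is uniform across the batch or limited to some indices.
import numpy as np

def per_batch_allclose(a, b, rtol=1e-4, atol=1e-4):
    """Return a list of per-batch-element allclose results."""
    assert a.shape == b.shape
    return [bool(np.allclose(a[i], b[i], rtol=rtol, atol=atol))
            for i in range(a.shape[0])]

# Usage against the files from the steps above; mmap_mode avoids loading
# the multi-GB (64, 512, 50257) f32 arrays into memory at once:
# a = np.load("out.npy", mmap_mode="r")        # IREE result
# b = np.load("output_0.npy", mmap_mode="r")   # JAX reference
# print(per_batch_allclose(a, b))
```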