iree: [spirv] Incorrect int4 vicuna model output
What happened?
The Vicuna int4 model's outputs differ from those of the fake_int8 model (int4 weights in an int8 container). The results begin to diverge at dispatch 5.
Build is based on https://github.com/openxla/iree/commit/afc8705fd527680905c9b84c5e06f8cca8051377
Steps to reproduce your issue
- Download the model second_vicuna_int4.mlir.
- Commands to reproduce:
./iree-compile --iree-input-type=none --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-cpu-features=host --iree-stream-resource-index-bits=64 --iree-vm-target-index-bits=64 --iree-vm-bytecode-module-strip-source-map=false --iree-util-zero-fill-elided-attrs --iree-vm-target-truncate-unsupported-floats --iree-codegen-check-ir-before-llvm-conversion=false --iree-vulkan-target-triple=rdna3-unknown-linux --iree-opt-const-expr-hoisting=false --iree-consteval-jit-debug=true ~/Downloads/second_vicuna_int4.mlir -o vicuna_vulkan_i4.vmfb --iree-stream-resource-max-allocation-size=3221225472 --iree-flow-break-dispatch=@forward:5
./iree-run-module --device_allocator=caching --vulkan_vma_allocator=false --module=vicuna_vulkan_i4.vmfb --device=vulkan --function=forward --input=1x1xi64 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --input=1x32x1x128xf32 --output=@vulkan_dispatch5.npy
- The results I got for dispatch 5 from the above commands: https://storage.googleapis.com/shark-public/vivian/vicuna_i4/vulkan_dispatch5.npy
- Golden results for comparison: https://storage.googleapis.com/shark-public/vivian/vicuna_i4/vulkan_dispatch5_golden.npy
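To quantify how far the dispatch 5 output is from the golden one, something like the following could be used. It is only a sketch: the helper names `summarize_mismatch` and `compare_files` and the tolerance values are my own assumptions, not part of IREE, and it assumes both `.npy` files have been downloaded locally and hold numeric arrays of the same shape.

```python
import numpy as np


def summarize_mismatch(actual, golden, rtol=1e-3, atol=1e-3):
    """Return (num_mismatched, max_abs_diff, first_bad_index) for two arrays.

    rtol/atol are illustrative defaults (an assumption), not values from the issue.
    """
    actual = np.asarray(actual, dtype=np.float64)
    golden = np.asarray(golden, dtype=np.float64)
    if actual.shape != golden.shape:
        raise ValueError(f"shape mismatch: {actual.shape} vs {golden.shape}")
    # Elements that agree within tolerance; NaNs in the same position count as equal.
    close = np.isclose(actual, golden, rtol=rtol, atol=atol, equal_nan=True)
    bad = np.argwhere(~close)
    max_abs_diff = float(np.nanmax(np.abs(actual - golden))) if actual.size else 0.0
    first_bad = tuple(int(i) for i in bad[0]) if len(bad) else None
    return int(len(bad)), max_abs_diff, first_bad


def compare_files(actual_path, golden_path, **tolerances):
    """Load two .npy dumps from disk and summarize their differences."""
    return summarize_mismatch(np.load(actual_path), np.load(golden_path), **tolerances)
```

For example, `compare_files("vulkan_dispatch5.npy", "vulkan_dispatch5_golden.npy")` would report the number of mismatching elements, the largest absolute difference, and the index of the first mismatch.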
About this issue
- State: closed
- Created 10 months ago
- Comments: 39 (37 by maintainers)
Commits related to this issue
- [spirv] Fix element extraction order when breaking down vectors (#14755) Assuming little-endian style encoding of the sub-byte elements, for 8xi4 [A, B, C, D, E, F, G, H], they are stored in memory ... — committed to iree-org/iree by antiagainst 10 months ago
- [spirv] Fix element extraction order when breaking down vectors (#14755) Assuming little-endian style encoding of the sub-byte elements, for 8xi4 [A, B, C, D, E, F, G, H], they are stored in memory ... — committed to dcaballe/iree by antiagainst 10 months ago
- Cherry-pick MLIR commits to fix SPIR-V 64-bit index conversion (#14771) This commit cherry-picks: * llvm/llvm-project@22a28d89937333d581103af87d463d7b255d3e3e * llvm/llvm-project@4ffc63ab71e501eb... — committed to iree-org/iree by antiagainst 10 months ago
Okay, I think I found it: it's a 32-bit index out-of-range issue. We are currently compiling with 32-bit indices for SPIR-V. Try picking up https://github.com/openxla/iree/pull/14771 and adding
--iree-spirv-index-bits=64
when compiling. It gives me the correct numbers on my side. @yzhang93 and @MaheshRavishankar, could you double-check?
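As a rough illustration of why 32 bits can be too narrow here: the compile command above allows allocations of up to 3221225472 bytes, which does not fit in a signed 32-bit integer. The sketch below only demonstrates that arithmetic. The thread does not say which index actually went out of range or whether it is treated as signed, so the signed wrap and the helper names are assumptions for illustration.

```python
INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def fits_in_int32(value):
    """True if value is representable as a signed 32-bit integer."""
    return INT32_MIN <= value <= INT32_MAX


def wrap_to_int32(value):
    """Reinterpret an integer as signed 32-bit (two's complement wraparound)."""
    value &= 0xFFFFFFFF
    return value - 2**32 if value > INT32_MAX else value


# Value taken from --iree-stream-resource-max-allocation-size in the repro command.
max_allocation = 3221225472

print(fits_in_int32(max_allocation))   # False
print(wrap_to_int32(max_allocation))   # -1073741824
```

An offset near the top of such a buffer would wrap to a negative value under this assumption, which is the kind of failure that 64-bit indices avoid.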