jax: On entry to GEMM_EX parameter number {9,12} had an illegal value
When I run my JAX-based training script on the GPU, it prints the following messages several hundred times but continues without raising an exception:
** On entry to GEMM_EX parameter number 12 had an illegal value
** On entry to GEMM_EX parameter number 9 had an illegal value
The same script doesn’t print any errors when run with a CPU-only build. This error occurs with the pip build and a custom build of master.
To reproduce, run python jax_transformer.py alice.txt in https://github.com/joschu/jax-exp/
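Since the messages repeat several hundred times, it can help to tally them by parameter number instead of reading them by eye. Below is a minimal sketch that counts the two message variants in captured output. The helper name count_gemm_ex_warnings and the sample text are made up for illustration, and the pattern assumes the message wording is exactly as quoted above.

```python
import re
from collections import Counter

# Pattern for the warning text quoted in this issue (assumed to be stable).
GEMM_EX_RE = re.compile(
    r"On entry to GEMM_EX parameter number (\d+) had an illegal value"
)


def count_gemm_ex_warnings(log_text):
    """Return a Counter mapping parameter number -> number of occurrences."""
    return Counter(int(m.group(1)) for m in GEMM_EX_RE.finditer(log_text))


# Hypothetical sample; in practice this would be the captured script output.
sample = (
    " ** On entry to GEMM_EX parameter number 12 had an illegal value\n"
    " ** On entry to GEMM_EX parameter number 9 had an illegal value\n"
    " ** On entry to GEMM_EX parameter number 12 had an illegal value\n"
)
counts = count_gemm_ex_warnings(sample)
print(dict(counts))  # {12: 2, 9: 1}
```

Comparing the totals for parameters 9 and 12 across runs is one rough way to see whether a given build still emits the warnings.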
About this issue
- State: closed
- Created 5 years ago
- Comments: 16 (16 by maintainers)
Commits related to this issue
- index_take in terms of gather, delete index_untake (c.f. #304) — committed to google/jax by mattjj 5 years ago
- Update XLA release. Updates XLA to https://github.com/tensorflow/tensorflow/commit/e889ea1dd965c31c391106aa3518fc23d2689954, which fixes #304. — committed to hawkinsp/jax by hawkinsp 5 years ago
The GEMM_EX problem is now tracked at https://github.com/tensorflow/tensorflow/issues/25761
I think this happens during GEMM autotuning. It may be benign because eventually XLA gives up and uses a generic GEMM, which seems to work:
Snippet from a log (with TF_CPP_MIN_VLOG_LEVEL=2 set)
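To collect a similar log, one option is to set TF_CPP_MIN_VLOG_LEVEL in the environment of a child process before launching the script. This is a rough sketch, not something taken from the thread: the helpers vlog_env and run_with_vlog are hypothetical names, and merging stderr into stdout is just a convenience so all output lands in one string.

```python
import os
import subprocess
import sys


def vlog_env(level=2, base=None):
    """Copy an environment and set TF_CPP_MIN_VLOG_LEVEL in the copy."""
    env = dict(os.environ if base is None else base)
    env["TF_CPP_MIN_VLOG_LEVEL"] = str(level)
    return env


def run_with_vlog(argv, level=2):
    """Run argv with verbose logging enabled; return (exit code, combined output)."""
    proc = subprocess.run(
        argv,
        env=vlog_env(level),
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout
        text=True,
    )
    return proc.returncode, proc.stdout


# Example call (not executed here), using the reproduction command from this issue:
# code, log = run_with_vlog([sys.executable, "jax_transformer.py", "alice.txt"])
```

Setting the variable only in the child's environment leaves the parent shell untouched, which makes it easy to compare runs with and without verbose logging.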