jax: Cannot use GPU on Ubuntu 16.04, CUDA 11.0

I have a GeForce RTX 3090 on Ubuntu 16.04 with CUDA 11.0 installed, and that installation works fine with TensorFlow. The path /usr/local/cuda points to it.

I installed JAX into my Python 3.8.6 conda environment by running:

pip3 install --upgrade jax jaxlib==0.1.62+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

I can import JAX from Python, but the first operation throws an error:

from jax import numpy
numpy.zeros(4)
2021-03-12 21:26:30.353284: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:191] Falling back to the CUDA driver for PTX compilation; ptxas does not support CC 8.6
2021-03-12 21:26:30.353307: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:194] Used ptxas at /usr/local/cuda-11.0/bin/ptxas
2021-03-12 21:26:30.353808: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:682] failed to get PTX kernel "broadcast_2" from module: CUDA_ERROR_NOT_FOUND: named symbol not found
2021-03-12 21:26:30.353849: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1881] Execution of replica 0 failed: Internal: Could not find the corresponding function
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1181, in __repr__
    s = np.array2string(self._value, prefix=prefix, suffix=',',
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1122, in _value
    self._npy_value = _force(self).device_buffer.to_py()
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1333, in _force
    result = force_fun(x)
  File "/home/holl/miniconda3/envs/phiflow2_tf/lib/python3.8/site-packages/jax/interpreters/xla.py", line 1357, in force_fun
    return compiled.execute([x.device_buffer])[0]
RuntimeError: Internal: Could not find the corresponding function
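
For reference, here is a small check I run in the same environment to see which jax/jaxlib versions and devices are actually picked up (a minimal sketch using only the public API; the expected values in the comments are assumptions based on my setup):

import jax
import jaxlib.version

print(jax.__version__)              # installed jax version
print(jaxlib.version.__version__)   # should match the installed jaxlib wheel (0.1.62 here)
print(jax.devices())                # lists GPU devices if the CUDA backend loaded, otherwise CPU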

Running nvcc --version prints:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Thu_Jun_11_22:26:38_PDT_2020
Cuda compilation tools, release 11.0, V11.0.194
Build cuda_11.0_bu.TC445_37.28540450_0

Is this a bug or am I doing something wrong?

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 2
  • Comments: 15 (2 by maintainers)

Most upvoted comments

In my case the problem was caused by my PATH: the CUDA installation had not been added to it automatically (even though nvidia-smi worked), so I had to add this to my .bashrc:

export PATH=/usr/local/cuda-11/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
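
To verify from inside Python that these paths are actually picked up (a rough sanity check, nothing JAX-specific; the paths in the comments are from my setup):

import os
import shutil

# XLA invokes ptxas as a subprocess, so it needs to be findable (PATH is one place it looks)
print(shutil.which("ptxas"))                   # expect e.g. /usr/local/cuda-11/bin/ptxas
print(os.environ.get("LD_LIBRARY_PATH", ""))   # should contain /usr/local/cuda-11/lib64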

Unfortunately not with CUDA 11.0. If there is a fix, I would be interested, too.

However, after upgrading to CUDA Toolkit 11.2 and using jaxlib==0.1.64+cuda112, everything seems to be working 🤷‍♂️
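
As a quick smoke test after the upgrade (a minimal sketch using only the public API):

import jax
from jax import numpy as jnp

print(jax.devices())    # the RTX 3090 should show up as a GPU device
x = jnp.zeros(4)        # the op that failed before
print(x)                # [0. 0. 0. 0.] once compilation succeeds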

Not sure if it’s related, but building from source fails with the following error:

ERROR: /home/holl/.cache/bazel/_bazel_holl/0b9919be22653e3173d0276d146dc0c1/external/org_tensorflow/tensorflow/stream_executor/gpu/BUILD:226:11: C++ compilation of rule '@org_tensorflow//tensorflow/stream_executor/gpu:asm_compiler' failed (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command
  (cd /home/holl/.cache/bazel/_bazel_holl/0b9919be22653e3173d0276d146dc0c1/execroot/__main__ && \
  exec env - \
    PATH=/home/holl/bin:/home/holl/.local/bin:/home/holl/miniconda3/envs/jax/bin:/home/holl/miniconda3/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/usr/local/cuda/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0 \
    TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/stream_executor/gpu/_objs/asm_compiler/asm_compiler.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/stream_executor/gpu/_objs/asm_compiler/asm_compiler.pic.o' -DTF_USE_SNAPPY -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -DHAVE_SYS_UIO_H -D__CLANG_SUPPORT_DYN_ANNOTATION__ -iquote external/org_tensorflow -iquote bazel-out/k8-opt/bin/external/org_tensorflow -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/k8-opt/bin/external/nsync -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/gif -iquote bazel-out/k8-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/k8-opt/bin/external/libjpeg_turbo -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/com_googlesource_code_re2 -iquote bazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquote external/farmhash_archive -iquote bazel-out/k8-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/k8-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/k8-opt/bin/external/highwayhash -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/snappy -iquote bazel-out/k8-opt/bin/external/snappy -iquote external/local_config_cuda -iquote bazel-out/k8-opt/bin/external/local_config_cuda -iquote external/local_config_tensorrt -iquote bazel-out/k8-opt/bin/external/local_config_tensorrt -Ibazel-out/k8-opt/bin/external/local_config_cuda/cuda/_virtual_includes/cuda_headers_virtual -Ibazel-out/k8-opt/bin/external/local_config_tensorrt/_virtual_includes/tensorrt_headers -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -isystem external/local_config_cuda/cuda -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda -isystem external/local_config_cuda/cuda/cuda/include -isystem bazel-out/k8-opt/bin/external/local_config_cuda/cuda/cuda/include -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' -fPIC -U_FORTIFY_SOURCE '-D_FORTIFY_SOURCE=1' -fstack-protector -Wall -fno-omit-frame-pointer -no-canonical-prefixes -fno-canonical-system-headers -DNDEBUG -g0 -O2 -ffunction-sections -fdata-sections -Wno-sign-compare -Wno-stringop-truncation -mavx '-std=c++14' -DEIGEN_AVOID_STL_ARRAY -Iexternal/gemmlowp -Wno-sign-compare '-ftemplate-depth=900' -fno-exceptions '-DGOOGLE_CUDA=1' '-DTENSORFLOW_USE_NVCC=1' -msse3 -DTENSORFLOW_MONOLITHIC_BUILD -pthread -c external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc -o bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/stream_executor/gpu/_objs/asm_compiler/asm_compiler.pic.o)
Execution platform: @local_execution_config_platform//:platform
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc: In function ‘void stream_executor::LogPtxasTooOld(const string&, int, int)’:
external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:190:62: error: converting to ‘absl::lts_2020_02_25::container_internal::raw_hash_set<absl::lts_2020_02_25::container_internal::FlatHashSetPolicy<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, absl::lts_2020_02_25::hash_internal::Hash<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, std::equal_to<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> >, std::allocator<std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int> > >::init_type {aka std::tuple<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int>}’ from initializer list would use explicit constructor ‘constexpr std::tuple< <template-parameter-1-1> >::tuple(_UElements&& ...) [with _UElements = {const std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&, int&, int&}; <template-parameter-2-2> = void; _Elements = {std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, int, int}]’
   if (already_logged->insert({ptxas_path, cc_major, cc_minor}).second) {
                                                              ^
At global scope:
cc1plus: warning: unrecognized command line option ‘-Wno-stringop-truncation’
Target //build:build_wheel failed to build
INFO: Elapsed time: 1298.922s, Critical Path: 60.78s
INFO: 2938 processes: 806 internal, 2132 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully
Traceback (most recent call last):
  File "build/build.py", line 521, in <module>
    main()
  File "build/build.py", line 516, in main
    shell(command)
  File "build/build.py", line 51, in shell
    output = subprocess.check_output(cmd)
  File "/home/holl/miniconda3/envs/jax/lib/python3.8/subprocess.py", line 415, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/home/holl/miniconda3/envs/jax/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-3.7.2-linux-x86_64', 'run', '--verbose_failures=true', '--config=short_logs', '--config=avx_posix', '--config=mkl_open_source_only', '--config=cuda', '--define=xla_python_enable_gpu=true', ':build_wheel', '--', '--output_path=/home/holl/jax/jax/dist']' returned non-zero exit status 1.