jax: RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm

Running convolutional layers raises an error saying that JAX could not determine which cuDNN convolution algorithm to use. The error appears to be JAX-specific, as I have replicated the code with TensorFlow and no error occurs.

My jax version is 0.2.24 and my jaxlib version is 0.1.74+cuda11.cudnn82, running on an Nvidia 3080.

The example is taken from the Flax readme (https://github.com/google/flax). The bug appears to be specific to convolutions, as the error does not occur for the MLP example.

I haven’t been able to check whether the error reproduces on other hardware, as I don’t have another GPU to use, but I found a similar issue from someone who uses a 3080 like me (https://github.com/google/jax/issues/7953).

import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)
Traceback (most recent call last):
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
    variables = model.init(jax.random.PRNGKey(0), batch)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 998, in init
    _, v_out = self.init_with_output(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 968, in init_with_output
    return self.apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 936, in apply
    return apply(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 687, in wrapper
    y = fn(root, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1178, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
    y = fun(self, *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 270, in __call__
    y = lax.conv_general_dilated(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 653, in conv_general_dilated
    return conv_general_dilated_p.bind(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 311, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 187, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 180, in cached
    return f(*args, **kwargs)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 334, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 769, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 798, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 87, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options)
  File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 369, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n  batch_group_count=1\n  dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n  feature_group_count=1\n  lhs_dilation=(1, 1)\n  lhs_shape=(32, 64, 64, 10)\n  padding=((1, 1), (1, 1))\n  precision=None\n  preferred_element_type=None\n  rhs_dilation=(1, 1)\n  rhs_shape=(3, 3, 10, 32)\n  window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=270}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 

Convolution performance may be suboptimal.  To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.

Process finished with exit code 1
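The error message above suggests --xla_gpu_strict_conv_algorithm_picker=false as a workaround. A minimal sketch of setting that flag before JAX is imported (it only skips the failing autotuner; it does not address the underlying cause):

import os

# Assumption: XLA_FLAGS must be set before JAX initialises its GPU backend,
# so it is set here before the jax import.
os.environ["XLA_FLAGS"] = "--xla_gpu_strict_conv_algorithm_picker=false"

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.ones((32, 64, 64, 10))  # NHWC input, as in the Flax example above
k = jnp.ones((3, 3, 10, 32))    # HWIO kernel
y = lax.conv_general_dilated(
    x, k, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(y.shape)  # (32, 64, 64, 32)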

About this issue

  • Original URL
  • State: closed
  • Created 3 years ago
  • Comments: 16 (5 by maintainers)

Most upvoted comments

Hi everyone,

I had the exact same issue described above. I am running on WSL2 on Windows 10. I installed CUDA and cuDNN and then installed jax[gpu] via pip. After setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87, my program works perfectly on a 3080, but with 0.9 the same RuntimeError is thrown.

Pre-updated memory fraction nvidia-smi:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.52       Driver Version: 511.79       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   42C    P8    12W /  N/A |   7530MiB /  8192MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      7754      C   /python3.8                      N/A      |
+-----------------------------------------------------------------------------+

Post-updated memory fraction nvidia-smi:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.52       Driver Version: 511.79       CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0 Off |                  N/A |
| N/A   53C    P0    57W /  N/A |   7476MiB /  8192MiB |     48%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      7958      C   /python3.8                      N/A      |
+-----------------------------------------------------------------------------+

I am very new to any sort of collaboration on repositories, so apologies if my etiquette is somewhat off, but I was wondering whether there have been any updates on this? Are there any “best practice” ways to correct it? I am currently setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87 in my bash ~/.profile file and then just running JAX as is.
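For reference, a minimal sketch of an alternative to the ~/.profile approach: setting the fraction per script through os.environ before JAX is imported (the value 0.87 is simply the one that worked above, not a recommendation):

import os

# Set before importing jax so the XLA GPU client picks the value up when it
# first allocates memory; the JAX default is to preallocate 90% of the GPU.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.87"

import jax
import jax.numpy as jnp

print(jax.devices())           # confirm the GPU backend is active
print(jnp.ones((2, 2)).sum())  # first op triggers the (smaller) preallocation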

Also, slightly unrelated, but what would be the best way to keep up with updates to this repository? I will be using Jax pretty religiously to build SVGP models as I love its flexibility, and so would like to keep up-to-date.

Thanks for any help!

Fixed it: it is a memory allocation issue, as suggested below, though with a different value: export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7

I found a previous discussion with a very similar problem to mine: https://github.com/google/jax/discussions/6332

The discussion explains how JAX allocates memory: by default it preallocates 90% of GPU memory on the first JAX operation, which in our case was the convolution. As the GPU also drives my display, I think there isn’t enough free memory for JAX to allocate 90% of it.
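A minimal sketch of the documented environment variables that control this preallocation behaviour; disabling preallocation (rather than lowering the fraction) is an alternative worth trying when the GPU is shared with the display:

import os

# Alternative to lowering XLA_PYTHON_CLIENT_MEM_FRACTION: turn off the upfront
# 90% preallocation entirely and let JAX grow its memory use on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Or keep preallocation but with a smaller share:
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax
import jax.numpy as jnp

x = jnp.ones((32, 64, 64, 10))
print(jax.devices(), x.shape)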

@rems75 does this fix the issue for you? If so, I think we can close the issue

@pseudo-rnd-thoughts: Issue #8302 solved this problem for me when running the Flax ImageNet example (set the environment variable TF_FORCE_GPU_ALLOW_GROWTH before loading the TensorFlow datasets).
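For reference, a sketch of what that looks like in a script that mixes a tf.data / tensorflow_datasets input pipeline with JAX (the mnist dataset here is only an illustration, not the ImageNet example itself):

import os

# Set before TensorFlow initialises its GPU devices, so tf.data does not
# reserve the whole GPU before JAX has a chance to allocate memory.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import tensorflow as tf
import tensorflow_datasets as tfds
import jax

ds = tfds.load("mnist", split="train")  # illustrative dataset only
print(jax.devices())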