jax: RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm
Running convolutional layers causes an error where JAX cannot determine the best cuDNN convolution algorithm. The error appears to be JAX-specific, as I have replicated the code with TensorFlow and no error occurs.
My jax version is 0.2.24 and my jaxlib version is 0.1.74+cuda11.cudnn82, running on an Nvidia 3080.
The example is taken from the Flax readme (https://github.com/google/flax). The bug appears to affect only convolutions, as the error does not occur for the MLP example.
I haven't been able to check whether this is specific to my hardware, as I don't have another GPU to test with, but I found a similar issue from someone who also uses a 3080 (https://github.com/google/jax/issues/7953).
import jax
import jax.numpy as jnp
import flax.linen as nn

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.PRNGKey(0), batch)
# output = model.apply(variables, batch)
Traceback (most recent call last):
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 25, in <module>
variables = model.init(jax.random.PRNGKey(0), batch)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 998, in init
_, v_out = self.init_with_output(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 968, in init_with_output
return self.apply(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 936, in apply
return apply(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/core/scope.py", line 687, in wrapper
y = fn(root, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 1178, in scope_fn
return fn(module.clone(parent=scope), *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/home/mark/Documents/programming/test-jax/flax_main.py", line 9, in __call__
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/module.py", line 275, in wrapped_module_method
y = fun(self, *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py", line 270, in __call__
y = lax.conv_general_dilated(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/lax/lax.py", line 653, in conv_general_dilated
return conv_general_dilated_p.bind(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/core.py", line 624, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 311, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 187, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/_src/util.py", line 180, in cached
return f(*args, **kwargs)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 334, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 653, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 769, in compile
self._executable = XlaCompiledComputation.from_xla_computation(
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 798, in from_xla_computation
compiled = compile_or_get_cached(backend, xla_computation, options)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 87, in compile_or_get_cached
return backend_compile(backend, computation, compile_options)
File "/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/jax/interpreters/xla.py", line 369, in backend_compile
return backend.compile(built_c, compile_options=options)
RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm: INTERNAL: All algorithms tried for %cudnn-conv = (f32[32,64,64,32]{2,1,3,0}, u8[0]{0}) custom-call(f32[32,64,64,10]{2,1,3,0} %copy.3, f32[3,3,10,32]{1,0,2,3} %copy.4), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="jit(conv_general_dilated)/conv_general_dilated[\n batch_group_count=1\n dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))\n feature_group_count=1\n lhs_dilation=(1, 1)\n lhs_shape=(32, 64, 64, 10)\n padding=((1, 1), (1, 1))\n precision=None\n preferred_element_type=None\n rhs_dilation=(1, 1)\n rhs_shape=(3, 3, 10, 32)\n window_strides=(1, 1)\n]" source_file="/home/mark/anaconda3/envs/test-jax/lib/python3.9/site-packages/flax/linen/linear.py" source_line=270}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm.
Convolution performance may be suboptimal. To ignore this failure and try to use a fallback algorithm, use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.
Process finished with exit code 1
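The error message itself suggests a workaround flag. A minimal sketch of trying it from Python, assuming it runs before any JAX operation; note that this only skips strict cuDNN algorithm picking and falls back to a default algorithm, so it masks the autotuning failure rather than fixing the root cause (which the comments below trace to GPU memory):

import os

# Flag taken from the error text above. Append to any existing XLA_FLAGS
# rather than overwriting if other flags are already set. It must be in the
# environment before JAX initialises its GPU backend, so set it before
# importing jax.
os.environ["XLA_FLAGS"] = "--xla_gpu_strict_conv_algorithm_picker=false"

import jax
import jax.numpy as jnp

print(jax.devices())  # the GPU backend initialises here with the flag applied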
About this issue
- State: closed
- Created 3 years ago
- Comments: 16 (5 by maintainers)
Hi everyone,
I had the exact same issue described above. I am running on WSL2 on Windows 10. I installed CUDA and cuDNN and then installed jax[gpu] via pip. After setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87, my program works perfectly on a 3080, but with 0.9 the same RuntimeError is thrown.
nvidia-smi output before and after updating the memory fraction: [screenshots omitted]
I am very new to any sort of collaboration on repositories, so apologies if my etiquette is somewhat off, but I was wondering whether there have been any updates on this? Any "best practice" ways to correct this? I am currently setting XLA_PYTHON_CLIENT_MEM_FRACTION=0.87 in my bash ~/.profile and then just running JAX as is.
Also, slightly unrelated, but what would be the best way to keep up with updates to this repository? I will be using JAX pretty religiously to build SVGP models as I love its flexibility, and so would like to keep up to date.
Thanks for any help!
Fixed it: it is a memory allocation issue, as suggested elsewhere in this thread, though with a different value: export XLA_PYTHON_CLIENT_MEM_FRACTION=0.7
I found this previous discussion with a very similar problem to mine: https://github.com/google/jax/discussions/6332
The discussion explains how JAX allocates GPU memory: by default it preallocates 90% of GPU memory on the first JAX operation, which for us is the convolution. Since this GPU also drives my display, there isn't enough free memory for JAX to allocate 90% of it.
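For completeness, a minimal sketch of applying this fix from Python rather than the shell, assuming it runs before the first JAX operation; the 0.7 fraction is the value from the export above, and a GPU that also drives the display may need an even smaller one:

import os

# Cap JAX's GPU memory preallocation at 70% instead of the default 90%.
# This must be set before JAX initialises its GPU backend, i.e. before
# importing jax (the shell export above achieves the same thing).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

import jax
import jax.numpy as jnp
import flax.linen as nn

# Minimal repro of the convolution that previously failed to autotune.
model = nn.Conv(features=32, kernel_size=(3, 3))
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C)
variables = model.init(jax.random.PRNGKey(0), batch)
print(variables["params"]["kernel"].shape)  # (3, 3, 10, 32)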
@rems75 does this fix the issue for you? If so, I think we can close the issue.
@pseudo-rnd-thoughts: Issue #8302 solved this problem for me when running the Flax ImageNet example (add the environment variable TF_FORCE_GPU_ALLOW_GROWTH before calling TF datasets).
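A sketch of that workaround in Python, assuming a tf.data / tensorflow_datasets input pipeline running alongside JAX as in the Flax ImageNet example; the "true" value is the usual setting for this TensorFlow variable (the comment above only names the variable itself):

import os

# Stop TensorFlow from preallocating the whole GPU so JAX still has room to
# allocate its own fraction. Must be set before TensorFlow initialises its
# GPU devices.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

import tensorflow as tf
import tensorflow_datasets as tfds  # input pipeline assumed here, as in the ImageNet example

# Alternatively, hide the GPU from TensorFlow entirely so only JAX uses it:
# tf.config.set_visible_devices([], "GPU")

import jax  # initialise JAX after TensorFlow has been configured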