jax: Extremely slow when I jit compile a function on GPU.

When I jit a function containing an FFT on the GPU, as below (note that the function reads the global variable pulse), compilation takes more than 10 s on the GPU but only 0.1 s on the CPU. I cannot figure out the reason. Can you explain what happened?

import time
import jax
import numpy as np
import jax.numpy as jnp

pulse = jax.device_put(np.random.rand(8000))

def f():
    sigTx = jnp.fft.fft(pulse)
    return sigTx

t0 = time.time()
a = jax.jit(f, backend='gpu')()
t1 = time.time()
print('time cost on gpu:',t1-t0)


t0 = time.time()
a = jax.jit(f, backend='cpu')()
t1 = time.time()
print('time cost on cpu:',t1-t0)
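Each timing above lumps tracing, XLA compilation and execution into one number, and a jitted call can return before its result is ready. A minimal sketch that times the compile and run stages separately with JAX's ahead-of-time lower()/compile() API. The helper name timed_stages is hypothetical, and the backend= argument is dropped so the sketch also runs on a CPU-only machine (it uses whatever backend JAX picks by default):

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# Same setup as the question: f() reads the module-level array `pulse`.
pulse = jax.device_put(np.random.rand(8000))

def f():
    return jnp.fft.fft(pulse)

def timed_stages(fn, *args):
    """Hypothetical helper: time compilation and execution of jit(fn) separately."""
    t0 = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()  # tracing + XLA compilation
    t1 = time.perf_counter()
    out = compiled(*args).block_until_ready()      # execution only; wait for the device
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, out

compile_s, run_s, a = timed_stages(f)
print(f"{jax.default_backend()}: compile {compile_s:.3f}s, run {run_s:.3f}s")
```

If the slowdown is in compilation, it shows up in the first number only.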

[Screenshot, 2022-05-06 10:06:15 AM]

Surprisingly, when I pass the global variable pulse into the function as an argument, the compile time drops back to a normal level:

import time
import jax
import numpy as np
import jax.numpy as jnp

pulse = jax.device_put(np.random.rand(10000))

def f(pulse):
    return jnp.fft.fft(pulse)

t0 = time.time()
a = jax.jit(f, backend='gpu')(pulse)
t1 = time.time()
print('time cost on gpu:',t1-t0)


t0 = time.time()
a = jax.jit(f, backend='cpu')(pulse)
t1 = time.time()
print('time cost on cpu:',t1-t0)
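The second snippet avoids the slow compile by making pulse an argument of the jitted function. If a zero-argument f() is still convenient, one option is to apply jax.jit only to an argument-taking inner function and read the global outside it. A minimal sketch; the name fft_of is hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np

pulse = jax.device_put(np.random.rand(8000))

# Jit only the computation; the data enters as a runtime argument, so XLA
# sees a parameter rather than a compile-time constant.
@jax.jit
def fft_of(x):
    return jnp.fft.fft(x)

def f():
    # The global is read outside the jit boundary and passed in as an argument.
    return fft_of(pulse)

a = f()
```

A side benefit of this layout is that calling fft_of with a different array of the same shape and dtype reuses the already compiled code.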

[Screenshot, 2022-05-06 10:08:59 AM]

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 15 (9 by maintainers)

Most upvoted comments

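This is related to constant folding, IMO. Because f reads the global pulse, the array is captured as a constant in the compiled program, so XLA tries to evaluate the FFT at compile time; when pulse is passed as an argument, there is nothing to fold. XLA logs: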

2022-05-06 12:57:16.669751: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:61] Constant folding an instruction is taking > 1s:

  %fft.0 = c64[8000]{0} fft(c64[8000]{0} %constant), fft_type=FFT, fft_length={8000}, metadata={op_name="jit(f)/jit(main)/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(8000,)]" source_file="/home/jiacheng/xxx/xxx.py" source_line=9}
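One way to see why only the first version hits constant folding is to compare the traced programs with jax.make_jaxpr. In the closure version the jaxpr has no inputs, so the array can only reach XLA as a captured value (the %constant operand in the log above), while in the argument version it is an input whose value is unknown at compile time. How captured constants are represented differs between JAX versions, so this sketch only inspects the inputs:

```python
import jax
import jax.numpy as jnp
import numpy as np

pulse = jax.device_put(np.random.rand(8000))

def f_closure():
    return jnp.fft.fft(pulse)   # `pulse` is captured from the enclosing scope

def f_arg(x):
    return jnp.fft.fft(x)       # the array is an explicit input

closure_jaxpr = jax.make_jaxpr(f_closure)()
arg_jaxpr = jax.make_jaxpr(f_arg)(pulse)

# Inputs of each traced program: none for the closure, one for the argument version.
print(len(closure_jaxpr.jaxpr.invars), len(arg_jaxpr.jaxpr.invars))
print(closure_jaxpr)
```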

This isn't necessarily a bug; constant folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with the environment variable XLA_FLAGS=--xla_dump_to=/tmp/foo set and attach the dumped files.
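A minimal sketch of producing such a dump from Python. XLA_FLAGS has to be set before JAX initializes its backend, so it is set before importing jax and appended to any flags already present. A temporary directory stands in for /tmp/foo, and the exact dump file names vary by XLA version:

```python
import os
import tempfile

# Must happen before JAX/XLA initializes; keep any flags that are already set.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + f" --xla_dump_to={dump_dir}"
).strip()

import jax  # imported after setting XLA_FLAGS on purpose
import jax.numpy as jnp
import numpy as np

pulse = jax.device_put(np.random.rand(8000))

def f():
    return jnp.fft.fft(pulse)

jax.jit(f)().block_until_ready()

# XLA writes one or more HLO dump files per compiled module.
for name in sorted(os.listdir(dump_dir)):
    print(name)
```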