jax: Extremely slow when I jit compile a function on GPU.
When I jit-compile a function that performs an FFT on the GPU, as shown below (note that I use a global variable, `pulse`, here), compilation takes more than 10 s on the GPU but only about 0.1 s on the CPU. I cannot figure out the reason. Can you explain what happened?
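```python
import time
import jax
import numpy as np
import jax.numpy as jnp

# `pulse` is a global array; f() reads it from the enclosing scope.
pulse = jax.device_put(np.random.rand(8000))

def f():
    sigTx = jnp.fft.fft(pulse)
    return sigTx

# First call = tracing + compilation + execution.
t0 = time.time()
a = jax.jit(f, backend='gpu')()
a.block_until_ready()  # JAX dispatches asynchronously; wait for the result
t1 = time.time()
print('time cost on gpu:', t1 - t0)

t0 = time.time()
a = jax.jit(f, backend='cpu')()
a.block_until_ready()
t1 = time.time()
print('time cost on cpu:', t1 - t0)
```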
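The timings above lump tracing, compilation, and the first execution together. A minimal sketch, assuming a JAX version that has the ahead-of-time `jax.jit(...).lower(...).compile()` API, that times each phase separately so you can see which one the extra seconds go into. It uses the default backend instead of `backend='gpu'`, so it also runs on a CPU-only machine.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

pulse = jax.device_put(np.random.rand(8000))

def f():
    # `pulse` is read from the global scope, not passed in.
    return jnp.fft.fft(pulse)

t0 = time.perf_counter()
lowered = jax.jit(f).lower()   # trace f and lower it (no XLA compilation yet)
t1 = time.perf_counter()
compiled = lowered.compile()   # XLA compilation only
t2 = time.perf_counter()
out = compiled()               # run the precompiled executable
out.block_until_ready()        # wait for asynchronous dispatch to finish
t3 = time.perf_counter()

print(f"lower: {t1 - t0:.3f}s  compile: {t2 - t1:.3f}s  run: {t3 - t2:.3f}s")
```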
Surprisingly, when I pass the global variable `pulse` into the function as an argument, the compile time drops back to a normal level:
```python
import time
import jax
import numpy as np
import jax.numpy as jnp

pulse = jax.device_put(np.random.rand(10000))

# `pulse` is now a parameter instead of a global.
def f(pulse):
    return jnp.fft.fft(pulse)

t0 = time.time()
a = jax.jit(f, backend='gpu')(pulse)
t1 = time.time()
print('time cost on gpu:', t1 - t0)

t0 = time.time()
a = jax.jit(f, backend='cpu')(pulse)
t1 = time.time()
print('time cost on cpu:', t1 - t0)
```
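One way to see what differs between the two versions is to inspect the traced program with `jax.make_jaxpr`. In this sketch (the names `f_closure` and `f_arg` are only illustrative), the closure version has no inputs because `pulse` is captured as a constant, while the argument version takes `pulse` as an input.

```python
import jax
import jax.numpy as jnp
import numpy as np

pulse = jax.device_put(np.random.rand(10000))

def f_closure():
    # Reads the global `pulse`; JAX captures it as a constant of the trace.
    return jnp.fft.fft(pulse)

def f_arg(pulse):
    # `pulse` is an input, so it is traced as an abstract value.
    return jnp.fft.fft(pulse)

closed_c = jax.make_jaxpr(f_closure)()
closed_a = jax.make_jaxpr(f_arg)(pulse)

print(closed_c)  # no input variables: the array is baked into the program
print(closed_a)  # one input variable of shape f32[10000]
print(len(closed_c.jaxpr.invars), len(closed_a.jaxpr.invars))
```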
About this issue
- State: closed
- Created 2 years ago
- Comments: 15 (9 by maintainers)
This is related to constant folding, IMO.
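One reading of this: because `pulse` is a global, it is embedded in the compiled program as a constant, so XLA may try to constant-fold the whole FFT at compile time instead of leaving it for run time. A common workaround is to keep large arrays out of the closure and pass them as runtime arguments. A minimal sketch, assuming you still want a zero-argument callable; `make_fft_fn` and `_fft` are hypothetical helper names, not JAX APIs.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def _fft(x):
    # `x` is a traced argument here, not a captured constant.
    return jnp.fft.fft(x)

def make_fft_fn(x):
    """Return a zero-argument callable that computes fft(x).

    The array is passed to the jitted function at call time, so it is
    not embedded in the compiled program as a constant.
    """
    def run():
        return _fft(x)
    return run

pulse = jax.device_put(np.random.rand(8000))
f = make_fft_fn(pulse)
sigTx = f()
print(sigTx.shape, sigTx.dtype)
```

The jitted function is compiled once per input shape and dtype, so calling `f()` repeatedly reuses the same executable.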