jax: Cannot take gradient of VJP function involving new-style RNG key
Description
from jax import custom_vjp, grad
from jax.random import key

@custom_vjp
def find_fixed_point(theta, state):
    return state

def _ffp_fwd(theta, state):
    return state, None

def _ffp_bwd(residuals, state_bar):
    assert False

find_fixed_point.defvjp(_ffp_fwd, _ffp_bwd)

def fixed_point_using_while_of_theta(theta) -> float:
    state = (8.0, key(123))
    x, _ = find_fixed_point(theta, state)
    return x

grad(fixed_point_using_while_of_theta)(3.0)
gives
ValueError: Cannot convert_element_type to dtype=key<fry>
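For readers unfamiliar with the `key<fry>` dtype named in the error, here is a minimal sketch of how a new-style typed key relates to its raw integer data. It is illustrative only and not a fix: the variable names are made up for this example, and the report does not confirm whether passing raw key data through the `custom_vjp` function avoids the error.

```python
import jax
import jax.numpy as jnp

# New-style typed key: a scalar array with an extended PRNG key dtype.
typed_key = jax.random.key(123)
is_typed = jnp.issubdtype(typed_key.dtype, jax.dtypes.prng_key)

# The underlying raw bits as an ordinary unsigned-integer array.
raw_bits = jax.random.key_data(typed_key)

# Wrap the raw bits back into a typed key (uses the default PRNG implementation).
rewrapped = jax.random.wrap_key_data(raw_bits)

# Both keys should drive the same random stream.
a = jax.random.normal(typed_key, (3,))
b = jax.random.normal(rewrapped, (3,))
print(typed_key.dtype, raw_bits.dtype, bool(jnp.array_equal(a, b)))
```

The typed key's dtype is what `convert_element_type` is being asked to produce in the traceback above, whereas `raw_bits` is a plain integer array.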
What jax/jaxlib version are you using?
0.4.20
Which accelerator(s) are you using?
No response
Additional system info
No response
NVIDIA GPU info
No response
About this issue
- State: closed
- Created 8 months ago
- Comments: 17 (12 by maintainers)
Okay, I may eventually get to this when it blocks me, but I'm in the middle of a bigger project right now. I was hoping the stack trace might reveal the problem.