jax: Cannot take gradient of VJP function involving new-style RNG key

Description

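from jax import custom_vjp, grad
from jax.random import key

# Stand-in for a fixed-point solver: it just returns the state unchanged.
@custom_vjp
def find_fixed_point(theta, state):
    return state

def _ffp_fwd(theta, state):
    # No residuals are needed for this reproducer.
    return state, None

def _ffp_bwd(residuals, state_bar):
    # Never reached: the ValueError below is raised instead of an AssertionError.
    assert False

find_fixed_point.defvjp(_ffp_fwd, _ffp_bwd)

def fixed_point_using_while_of_theta(theta) -> float:
    # The state pairs a float with a new-style typed RNG key.
    state = (8.0, key(123))
    x, _ = find_fixed_point(theta, state)
    return x

grad(fixed_point_using_while_of_theta)(3.0)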

Running this gives:

ValueError: Cannot convert_element_type to dtype=key<fry>

@froystig
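A possible workaround, sketched here under assumptions rather than taken from the issue: keep the typed key out of the custom_vjp function entirely, so that only floating-point values cross the differentiation boundary. The names `initial_state`, `_fwd`, `_bwd`, and `fixed_point_of_theta` are hypothetical, and the body `theta * x` is a placeholder chosen only so the backward rule has something to compute; it is not a real fixed-point solver.

```python
import jax
from jax import custom_vjp, grad
from jax.random import key

def initial_state():
    # Hypothetical helper: the typed key is consumed here, outside custom_vjp.
    rng = key(123)
    return 8.0 + jax.random.normal(rng)

@custom_vjp
def find_fixed_point(theta, x):
    # Placeholder computation (assumption): stands in for the real solver.
    return theta * x

def _fwd(theta, x):
    # Save both inputs as residuals for the backward pass.
    return theta * x, (theta, x)

def _bwd(residuals, out_bar):
    theta, x = residuals
    # Cotangents for (theta, x) of the placeholder theta * x.
    return x * out_bar, theta * out_bar

find_fixed_point.defvjp(_fwd, _bwd)

def fixed_point_of_theta(theta):
    # Only floats enter and leave find_fixed_point; no key is involved.
    return find_fixed_point(theta, initial_state())

g = grad(fixed_point_of_theta)(3.0)
```

This only sidesteps the error when the key is not needed inside the differentiated function; it does not address the underlying handling of key-typed outputs.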

What jax/jaxlib version are you using?

0.4.20

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

About this issue

  • State: closed
  • Created 8 months ago
  • Comments: 17 (12 by maintainers)

Most upvoted comments

Okay, I may eventually get to this when it blocks me, but I’m in the middle of a bigger project right now. I was hoping the stack trace might reveal the problem 😄