pennylane: JIT-ed calculation of Hessian [grad(grad)] fails with JAX

Expected behavior

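Computing the Hessian of the QNode below with jax.grad(jax.grad(circuit)) should work whether or not the QNode is JIT-compiled. In practice, the JAX interface breaks when @jax.jit is applied, while the same computation without JIT works fine. The error seems to come from the missing JVP (forward-mode) rule in the host_callback bridge between PennyLane (PL) and JAX. The traceback below shows the failure inside the custom backward pass (wrapped_exec_bwd), which calls host_callback.call and then has to be differentiated a second time. To make it work, just remove @jax.jit from the definition of the circuit.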

@josh146 and I discussed this over Slack, and it seems strange for something that works with JIT off to fail simply because it is JIT-compiled.
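A minimal pure-JAX sketch of the suspected mechanism follows. It is not PennyLane's code, and it rests on assumptions: jax.pure_callback (which recent JAX versions recommend in place of the deprecated host_callback.call) stands in for the host bridge, and the hypothetical helpers _host_cos and _host_neg_sin (the exact expectation value of PauliZ after RX(a), and its derivative) stand in for the device. Because both custom_vjp rules call the host directly, a first derivative works, but differentiating those rules a second time reaches a callback that has no JVP rule.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_cos(x):
    # Hypothetical device stand-in: <Z> after RX(a) on |0> is cos(a).
    x = np.asarray(x)
    return np.asarray(np.cos(x), dtype=x.dtype)


def _host_neg_sin(x):
    # Hypothetical device-side gradient routine: d cos(a) / da.
    x = np.asarray(x)
    return np.asarray(-np.sin(x), dtype=x.dtype)


def _call_host(fn, a):
    # Leave JAX and run fn on the host (the role host_callback.call plays in PL).
    a = jnp.asarray(a)
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(a.shape, a.dtype), a)


@jax.custom_vjp
def expval(a):
    return _call_host(_host_cos, a)


def expval_fwd(a):
    # Both rules call the host directly, so neither can be differentiated again.
    return _call_host(_host_cos, a), a


def expval_bwd(a, g):
    return (g * _call_host(_host_neg_sin, a),)


expval.defvjp(expval_fwd, expval_bwd)

first = jax.grad(jax.jit(expval))(0.5)  # first derivative works

try:
    jax.grad(jax.grad(jax.jit(expval)))(0.5)
    second_order_failed = False
except Exception as err:  # JAX refuses to differentiate the raw callback
    print(type(err).__name__, err)
    second_order_failed = True
```

In this stand-in the failure happens with or without jax.jit; in PennyLane the difference presumably comes from only the JIT path routing through a host callback, as the traceback suggests.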

Actual behavior

Calling hess(0.5) fails with: can’t apply forward-mode autodiff (jvp) to a custom_vjp function. The traceback below ends in: NotImplementedError: JVP rule is implemented only for id_tap, not for call.


Source code

import jax
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=1, shots=100)


# Removing @jax.jit makes the Hessian computation below succeed.
@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(a):
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

hess = jax.grad(jax.grad(circuit))
print("hessian", hess(0.5))

Tracebacks

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
    print("hessian", hess(0.5))
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/qnode.py", line 549, in __call__
    res = qml.execute(
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/__init__.py", line 412, in execute
    res = _execute(
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 83, in execute
    return _execute(
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 218, in _execute
    return wrapped_exec(params)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: JVP rule is implemented only for id_tap, not for call.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
    print("hessian", hess(0.5))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 918, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 2312, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 116, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 918, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 1000, in value_and_grad_f
    g = vjp_py(np.ones((), dtype=dtype))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/tree_util.py", line 326, in <lambda>
    func = lambda *args, **kw: original_func(*args, **kw)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 2219, in _vjp_pullback_wrapper
    ans = fun(*args)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/tree_util.py", line 326, in <lambda>
    func = lambda *args, **kw: original_func(*args, **kw)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 222, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 558, in call_transpose
    out_flat = primitive.bind(fun, *all_args, **new_params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
    return call_bind(self, fun, *args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
    return trace.process_call(self, fun, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 687, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
    ans = call(fun, *args)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 771, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 228, in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 690, in _custom_lin_transpose
    cts_in = bwd.call_wrapped(*res, *cts_out)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 180, in wrapped_exec_bwd
    vjps = host_callback.call(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 658, in call
    return _call(callback_func, arg, result_shape=result_shape,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 708, in _call
    flat_results = outside_call_p.bind(*flat_args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 272, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 288, in process_primitive
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 1204, in _outside_call_jvp_rule
    raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: JVP rule is implemented only for id_tap, not for call.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
    print("hessian", hess(0.5))
  File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 180, in wrapped_exec_bwd
    vjps = host_callback.call(
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 658, in call
    return _call(callback_func, arg, result_shape=result_shape,
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 708, in _call
    flat_results = outside_call_p.bind(*flat_args, **params)
  File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 1204, in _outside_call_jvp_rule
    raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
NotImplementedError: JVP rule is implemented only for id_tap, not for call.

System information

Name: PennyLane
Version: 0.21.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /Users/shahnawaz/Dropbox/dev/pennylane
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic_version, toml
Required-by: PennyLane-Forest, PennyLane-Lightning
Platform info:           macOS-11.2.3-x86_64-i386-64bit
Python version:          3.9.7
Numpy version:           1.21.4
Scipy version:           1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.20.2)
- forest.numpy_wavefunction (PennyLane-Forest-0.20.0)
- forest.qvm (PennyLane-Forest-0.20.0)
- forest.wavefunction (PennyLane-Forest-0.20.0)
- default.gaussian (PennyLane-0.21.0.dev0)
- default.mixed (PennyLane-0.21.0.dev0)
- default.qubit (PennyLane-0.21.0.dev0)
- default.qubit.autograd (PennyLane-0.21.0.dev0)
- default.qubit.jax (PennyLane-0.21.0.dev0)
- default.qubit.tf (PennyLane-0.21.0.dev0)
- default.qubit.torch (PennyLane-0.21.0.dev0)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 15 (15 by maintainers)

Most upvoted comments

I just stumbled on this issue. I think one way to avoid it would be to define a new JAX primitive, jax.core.Primitive("qml_expval"), together with its CPU and GPU implementations and the rules for differentiating it. qml_expval would then be fully equivalent to any native JAX operation, so every transformation could be supported.
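A rough sketch of that suggestion, assuming a recent JAX release (Primitive is imported from jax.extend.core, falling back to jax.core on older versions). Only the name qml_expval comes from the comment: the implementation is a jnp.cos stand-in for the device, so it sidesteps the host_callback question entirely, and the lowering is produced with mlir.lower_fun rather than hand-written CPU and GPU kernels. Because the JVP rule is written with ordinary JAX ops, grad(grad(...)) works under jax.jit; reverse mode needs no extra rule here because the tangent only enters through a multiplication, which JAX can transpose.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad, mlir

try:  # recent JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# Primitive name taken from the comment; everything else is illustrative.
qml_expval_p = Primitive("qml_expval")


def _qml_expval_impl(a):
    # Stand-in for the device: <Z> after RX(a) on |0> is cos(a).
    # A real implementation would have to reach the quantum device here.
    return jnp.cos(a)


def _qml_expval_abstract_eval(a):
    # Elementwise: the output has the same shape and dtype as the input.
    return a


def _qml_expval_jvp(a_dot, a):
    # Tangent rule built from ordinary JAX ops, so it can be differentiated again.
    return -jnp.sin(a) * a_dot


qml_expval_p.def_impl(_qml_expval_impl)
qml_expval_p.def_abstract_eval(_qml_expval_abstract_eval)
ad.defjvp(qml_expval_p, _qml_expval_jvp)
# Lower to XLA by tracing the implementation (works on CPU and GPU alike).
mlir.register_lowering(
    qml_expval_p, mlir.lower_fun(_qml_expval_impl, multiple_results=False)
)


def qml_expval(a):
    return qml_expval_p.bind(a)
```

With this in place, jax.grad(jax.grad(jax.jit(qml_expval)))(0.5) should come out close to -cos(0.5).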

The main complication is that I’m not sure whether a host_callback can be used as the implementation. I think it could be possible, but it should be tried. I’m sure you could provide a C function (since that is what is natively supported), which would then trampoline back into Python code.
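One way to try a host call as the implementation without writing a C function, sketched under assumptions: jax.pure_callback (the successor of the now-deprecated host_callback.call) plays the role of the implementation, and it is wrapped in jax.custom_jvp instead of a full custom primitive. The derivative uses the parameter-shift rule for RX, 0.5 * [f(a + pi/2) - f(a - pi/2)], expressed through expval itself, so each derivative is again a differentiable call and grad can be nested under jax.jit. _device_expval is a hypothetical NumPy stand-in for the device.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _device_expval(a):
    # Hypothetical stand-in for the quantum device: <Z> after RX(a) on |0> is cos(a).
    a = np.asarray(a)
    return np.asarray(np.cos(a), dtype=a.dtype)


@jax.custom_jvp
def expval(a):
    a = jnp.asarray(a)
    out_spec = jax.ShapeDtypeStruct(a.shape, a.dtype)
    # Leave JAX and evaluate on the host (the role host_callback plays in PL).
    return jax.pure_callback(_device_expval, out_spec, a)


@expval.defjvp
def expval_jvp(primals, tangents):
    (a,), (a_dot,) = primals, tangents
    shift = jnp.pi / 2
    # Parameter-shift rule for RX, written in terms of expval itself, so the
    # derivative is again a differentiable call.
    grad = 0.5 * (expval(a + shift) - expval(a - shift))
    return expval(a), grad * a_dot
```

For example, jax.grad(jax.grad(jax.jit(expval)))(0.5) should be close to -cos(0.5). The design choice is that the callback is only ever evaluated, never differentiated: every derivative request is answered by further evaluations.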

The JAX primitive interface is quite stable and has not changed in the last two years.

P.S. If there is interest in this, I don’t have the time to implement it myself, but I can provide some guidance; I have already done this for two different packages.

Hi everyone, I’m getting back to this thread because I saw that JAX makes it possible to compute higher-order gradients (via VJPs) through host_callback when the gradient computation is implemented outside JAX (e.g., in TensorFlow). See the discussion here: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-tensorflow-function-with-reverse-mode-autodiff-support

I had a look at the implementation here: https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py#L100 but haven’t completely figured out what happens in the custom backward pass that makes higher-order gradients possible with host_callback. It seems they hook the TensorFlow autodiff mechanism into the JAX custom_vjp definitions, and that works all the way up to higher-order derivatives.
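One plausible reading of that test, sketched under assumptions (jax.pure_callback in place of host_callback.call, a hypothetical NumPy stand-in _device_expval for the device, and the parameter-shift rule in place of TensorFlow's autodiff): the trick is recursion. The custom_vjp's forward and backward rules call the custom_vjp-wrapped function itself rather than the raw callback, so whenever an outer grad differentiates the backward pass, it meets another custom_vjp instead of a callback with no JVP rule.

```python
import jax
import jax.numpy as jnp
import numpy as np

SHIFT = np.pi / 2  # parameter shift for an RX gate


def _device_expval(a):
    # Hypothetical stand-in for the device (or for TensorFlow in the linked test).
    a = np.asarray(a)
    return np.asarray(np.cos(a), dtype=a.dtype)


def _call_device(a):
    a = jnp.asarray(a)
    return jax.pure_callback(_device_expval, jax.ShapeDtypeStruct(a.shape, a.dtype), a)


@jax.custom_vjp
def expval(a):
    return _call_device(a)


def expval_fwd(a):
    # Call the custom_vjp function, not the raw callback, so that an outer
    # transformation differentiating this rule reaches expval's own rules.
    return expval(a), a


def expval_bwd(a, g):
    # The backward pass is also built from expval, hence differentiable again.
    grad = 0.5 * (expval(a + SHIFT) - expval(a - SHIFT))
    return (g * grad,)


expval.defvjp(expval_fwd, expval_bwd)
```

Here jax.grad(jax.grad(jax.jit(expval)))(0.5) should be close to -cos(0.5); each extra level of differentiation just triggers more host evaluations.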

Just putting this out here for reference in the future in case we look into this again and it is helpful.

At the same time, passing g along with the other arguments creates issues, because g may become a BatchTracer object when certain transforms are applied (e.g., jax.jacobian, which uses jax.vmap), and this seems to be the culprit for the original error:

Are you aware of jax.custom_batching.custom_vmap? If you define a custom_vmap rule for your custom_vjp, you might sidestep the issue entirely.
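A minimal sketch of a custom_vmap rule around a host call, assuming _host_expvals as a hypothetical NumPy stand-in for the device; it shows only the batching half, not how such a rule would interact with custom_vjp. The rule receives the batched argument with the batch axis moved to the front and returns the output together with a flag marking it as batched, so the host receives one stacked array instead of JAX tracers.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import custom_batching

seen_shapes = []  # records what the host actually receives, for illustration


def _host_expvals(angles):
    # Hypothetical host routine: one <Z> value per RX angle, for any batch shape.
    angles = np.asarray(angles)
    seen_shapes.append(angles.shape)
    return np.asarray(np.cos(angles), dtype=angles.dtype)


def _call_host(a):
    return jax.pure_callback(_host_expvals, jax.ShapeDtypeStruct(a.shape, a.dtype), a)


@custom_batching.custom_vmap
def expval(a):
    return _call_host(a)


@expval.def_vmap
def expval_vmap_rule(axis_size, in_batched, a):
    # a arrives with the batch axis at position 0; one host call covers the batch.
    # Returning True marks the output as batched along axis 0.
    return _call_host(a), True
```

For instance, jax.jit(jax.vmap(expval))(jnp.array([0.0, 0.5, 1.0])) hands the host a single array of shape (3,).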

This is interesting; there is a nice example of how this is done, all the way up to JIT compilation, here: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

But the problem remains that we cannot use XLA operations to evaluate the expval, since that evaluation happens purely on the quantum device, right? So there has to be a break in the computational graph somewhere (which is what host_callback provides), and I don’t know whether this approach would still work.

Unless there is a way to use an XLA custom call (https://www.tensorflow.org/xla/custom_call) to obtain the value of the expval.

Note: this issue could potentially be resolved by refactoring the JAX JIT interface. We have this on our radar and would like to look into a resolution in the coming weeks.

Specifically, the cotangent g is currently passed to the callback as one of its input parameters here. Instead, g should be applied to the result of host_callback.call.
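The change described above can be sketched as follows, under assumptions: jax.pure_callback stands in for host_callback.call, and the hypothetical NumPy helpers _host_expvals and _host_jacobian_diag stand in for the device (one independent RX angle per output, so the Jacobian is diagonal). The callback receives only primal values, and the cotangent g is multiplied in afterwards with ordinary JAX ops, so when jax.jacobian vmaps the backward pass, only that multiplication is batched. This sketch is first-order only: its rules call the host directly, so nesting grad would still fail.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_expvals(a):
    # Hypothetical device stand-in: one RX(a_i) circuit per entry, <Z_i> = cos(a_i).
    a = np.asarray(a)
    return np.asarray(np.cos(a), dtype=a.dtype)


def _host_jacobian_diag(a):
    # Hypothetical device-side gradient routine; it only ever sees primal values.
    a = np.asarray(a)
    return np.asarray(-np.sin(a), dtype=a.dtype)


def _call_host(fn, a):
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(a.shape, a.dtype), a)


@jax.custom_vjp
def expvals(a):
    return _call_host(_host_expvals, a)


def expvals_fwd(a):
    return _call_host(_host_expvals, a), a


def expvals_bwd(a, g):
    # g never enters the callback: the host returns the (diagonal) Jacobian and
    # the cotangent is applied afterwards with ordinary JAX ops. When
    # jax.jacobian vmaps this rule over g, only the multiplication is batched.
    jac_diag = _call_host(_host_jacobian_diag, a)
    return (g * jac_diag,)


expvals.defvjp(expvals_fwd, expvals_bwd)
```

For example, jax.jit(jax.jacobian(expvals))(jnp.array([0.1, 0.2, 0.3])) should return a 3x3 diagonal matrix with entries -sin(a_i).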

Thanks @antalszava for the explanation. Feel free to leave this issue open until there is a resolution, or to close it, since this is an issue with JAX rather than PL.