pennylane: JIT-ed calculation of Hessian [grad(grad)] fails with JAX
Expected behavior
I was trying to compute the Hessian and saw that the JAX interface breaks when JIT is on. Without JIT, it works fine. The error seems to be due to JVPs not being available in the host_callback bridge between PL and JAX. To make it work, just remove the @jax.jit from the definition of the circuit.
@josh146 and I discussed this over Slack, and it seems a bit strange that something that works with JIT off stops working when we simply JIT things.
Actual behavior
can’t apply forward-mode autodiff (jvp) to a custom_vjp function. JVP rule is implemented only for id_tap, not for call.
Additional information
No response
Source code
import jax
import pennylane as qml

dev = qml.device("default.qubit.jax", wires=1, shots=100)

@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(a):
    qml.RX(a, wires=0)
    return qml.expval(qml.PauliZ(0))

hess = jax.grad(jax.grad(circuit))
print("hessian", hess(0.5))
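For reference, a noiseless version of this circuit has a closed form: RX(a) applied to |0> gives an expectation value of PauliZ equal to cos(a), so the Hessian should be -cos(a). The sketch below uses plain Python only (no PennyLane or JAX) to compute that reference value and cross-check it with a central finite difference. The names expval_z and hessian_fd are illustrative and not part of either library, and with shots=100 a device result would scatter around this value rather than match it exactly.

```python
import math


def expval_z(a):
    # Exact <Z> after RX(a) on |0> (noiseless closed form, assumed here).
    return math.cos(a)


def hessian_fd(f, a, h=1e-4):
    # Second-order central finite difference for the second derivative.
    return (f(a + h) - 2.0 * f(a) + f(a - h)) / h**2


a = 0.5
exact = -math.cos(a)              # analytic second derivative of cos(a)
approx = hessian_fd(expval_z, a)  # numerical cross-check
print("exact hessian:", exact)
print("finite-difference hessian:", approx)
```

A value close to this reference is what the non-JIT run should fluctuate around, which makes it a convenient check when comparing JIT and non-JIT behaviour.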
Tracebacks
Traceback (most recent call last):
File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
print("hessian", hess(0.5))
File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/qnode.py", line 549, in __call__
res = qml.execute(
File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/__init__.py", line 412, in execute
res = _execute(
File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 83, in execute
return _execute(
File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 218, in _execute
return wrapped_exec(params)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: JVP rule is implemented only for id_tap, not for call.
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
print("hessian", hess(0.5))
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 918, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 993, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 2312, in _vjp
out_primal, out_vjp = ad.vjp(
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 116, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 103, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 513, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 918, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 1000, in value_and_grad_f
g = vjp_py(np.ones((), dtype=dtype))
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/tree_util.py", line 326, in <lambda>
func = lambda *args, **kw: original_func(*args, **kw)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/api.py", line 2219, in _vjp_pullback_wrapper
ans = fun(*args)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/_src/tree_util.py", line 326, in <lambda>
func = lambda *args, **kw: original_func(*args, **kw)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 123, in unbound_vjp
arg_cts = backward_pass(jaxpr, reduce_axes, consts, dummy_args, cts)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 222, in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 558, in call_transpose
out_flat = primitive.bind(fun, *all_args, **new_params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 323, in process_call
result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 202, in process_call
jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 311, in partial_eval
out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1632, in bind
return call_bind(self, fun, *args, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1623, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 1635, in process
return trace.process_call(self, fun, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 627, in process_call
return primitive.impl(f, *tracers, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 687, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 263, in memoized_fun
ans = call(fun, *args)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 759, in _xla_callable_uncached
return lower_xla_callable(fun, device, backend, name, donated_invars,
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/xla.py", line 771, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1542, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1520, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 228, in backward_pass
cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 690, in _custom_lin_transpose
cts_in = bwd.call_wrapped(*res, *cts_out)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 180, in wrapped_exec_bwd
vjps = host_callback.call(
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 658, in call
return _call(callback_func, arg, result_shape=result_shape,
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 708, in _call
flat_results = outside_call_p.bind(*flat_args, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/core.py", line 272, in bind
out = top_trace.process_primitive(self, tracers, params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/interpreters/ad.py", line 288, in process_primitive
primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 1204, in _outside_call_jvp_rule
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
jax._src.traceback_util.UnfilteredStackTrace: NotImplementedError: JVP rule is implemented only for id_tap, not for call.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/shahnawaz/Dropbox/dev/pennylane/tests/devices/test_jit_vs_no_jit.py", line 14, in <module>
print("hessian", hess(0.5))
File "/Users/shahnawaz/Dropbox/dev/pennylane/pennylane/interfaces/batch/jax_jit.py", line 180, in wrapped_exec_bwd
vjps = host_callback.call(
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 658, in call
return _call(callback_func, arg, result_shape=result_shape,
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 708, in _call
flat_results = outside_call_p.bind(*flat_args, **params)
File "/Users/shahnawaz/miniconda3/envs/jaxnn/lib/python3.9/site-packages/jax/experimental/host_callback.py", line 1204, in _outside_call_jvp_rule
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
NotImplementedError: JVP rule is implemented only for id_tap, not for call.
System information
Name: PennyLane
Version: 0.21.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/shahnawaz/Dropbox/dev/pennylane
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic_version, toml
Required-by: PennyLane-Forest, PennyLane-Lightning
Platform info: macOS-11.2.3-x86_64-i386-64bit
Python version: 3.9.7
Numpy version: 1.21.4
Scipy version: 1.7.3
Installed devices:
- lightning.qubit (PennyLane-Lightning-0.20.2)
- forest.numpy_wavefunction (PennyLane-Forest-0.20.0)
- forest.qvm (PennyLane-Forest-0.20.0)
- forest.wavefunction (PennyLane-Forest-0.20.0)
- default.gaussian (PennyLane-0.21.0.dev0)
- default.mixed (PennyLane-0.21.0.dev0)
- default.qubit (PennyLane-0.21.0.dev0)
- default.qubit.autograd (PennyLane-0.21.0.dev0)
- default.qubit.jax (PennyLane-0.21.0.dev0)
- default.qubit.tf (PennyLane-0.21.0.dev0)
- default.qubit.torch (PennyLane-0.21.0.dev0)
None
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.
About this issue
- State: open
- Created 2 years ago
- Comments: 15 (15 by maintainers)
I just stumbled on this issue. I think a way to avoid it would be to define a new JAX primitive operation, jax.core.Primitive("qml_expval"), defining the CPU and GPU implementations and also how to differentiate it. Then qml_expval will be perfectly equivalent to any other native JAX operation (so everything can be supported).
The main ‘complication’ is that I’m not sure if you can feed a host_callback as the implementation. I think it could be possible, but it should be tried. I’m sure you could feed a C function (because that’s what it natively supports), which then trampolines back into Python code.
The interface is quite stable and has not changed in the last 2 years.
PS: If there’s any interest in that, I don’t have the time to put in, but I can provide some guidance. I did that already for two different packages.
Hi everyone, getting back to this thread as I saw that in JAX it is possible to implement higher-order gradients (VJPs) with host_callback using an outside implementation for the gradient computation (e.g., TensorFlow). See the discussion here: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-tensorflow-function-with-reverse-mode-autodiff-support
I had a look at the implementation here: https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py#L100 but haven’t completely figured out what is happening in the custom backward pass that allows one to compute higher-order gradients with host_callback. It feels like they are just hooking up the TensorFlow autodiff mechanism to the JAX custom_vjp definitions, and it works all the way up to the higher-order derivative.
Just putting this out here for future reference in case we look into this again and it is helpful.
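One way to read that pattern is that the backward pass of a custom_vjp function calls another custom_vjp function, so JAX never has to differentiate the host callback itself. The following is a minimal sketch under several assumptions, not PennyLane's implementation: it uses jax.pure_callback in place of host_callback.call (host_callback is deprecated in newer JAX releases, as far as I know), the hypothetical host_derivative stands in for a device evaluation by using the closed form cos(a) and its derivatives, and expval is an illustrative name.

```python
import functools

import jax
import numpy as np


def host_derivative(a, order):
    # Runs on the host, outside XLA. Hypothetical stand-in for a device call:
    # the order-th derivative of cos(a) is cos(a + order * pi / 2).
    a = np.asarray(a)
    return np.asarray(np.cos(a + order * np.pi / 2), dtype=a.dtype)


def expval(a, order=0):
    # Each derivative order gets its own custom_vjp wrapper, so JAX uses the
    # rules below instead of asking for a JVP rule of the callback.
    @jax.custom_vjp
    def f(x):
        out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
        callback = functools.partial(host_derivative, order=order)
        return jax.pure_callback(callback, out_type, x)

    def f_fwd(x):
        # Keep the input as the residual for the backward pass.
        return f(x), x

    def f_bwd(x, g):
        # The cotangent g multiplies the callback result; it is not passed
        # into the callback. The next-order evaluation is again a custom_vjp.
        return (g * expval(x, order + 1),)

    f.defvjp(f_fwd, f_bwd)
    return f(a)


circuit = jax.jit(expval)
hess = jax.grad(jax.grad(circuit))
print("hessian", hess(0.5))
```

The design choice here is that differentiability is supplied one order at a time by the host, which mirrors how a device could return derivatives of an expectation value; whether this maps cleanly onto batched tapes and shot noise is left open.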
Are you aware of jax.custom_batching.custom_vmap? If you define a custom_vmap rule for your custom_vjp you might sidestep the issue entirely.
This is interesting, there is a nice example here of how this is done all the way up to JITing: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
But the problem still remains that we cannot use XLA operations to do the expval evaluation, as the expval evaluation happens purely on the quantum device, right? So there has to be a break in the computational graph somewhere (which is what host_callback does). So I don’t know if this would still work.
Unless there is a way to use the XLA custom call (https://www.tensorflow.org/xla/custom_call) to get the value of the expval.
Note: this issue could potentially be resolved by a refactor of the JAX JIT interface. We have this on our radar and would like to look into a resolution in the coming weeks.
Specifically, at the moment the g cotangent value is being used as an input parameter here. This g value should instead be applied to the result of host_callback.call.
Thanks @antalszava for the explanation. Feel free to leave this issue open till there is a resolution or close it since this is an issue with JAX and not PL.