pennylane: [BUG] jax.jit(jax.grad()) of a circuit with shots crashes
Expected behavior
Jitting the gradient of a QNode whose device uses shots and has its PRNGKey set explicitly leads to a crash. I would expect this to work.
Below is a snippet that reproduces the issue on master. Note that if you remove the jax.jit, the gradient works, but only by accident.
I think I know what is causing the bug, but the explanation is a bit involved. I will first give a TLDR, then show exactly where the crash happens, and finally reason about what is happening there.
TLDR
The problem arises because you are storing a tracer in DefaultQubitJax._prng_key, but you are not passing this PRNG key as an argument of the host callback in jax_jit.py:_execute. Conceptually, you should pass the PRNG key as an arg of the callback, like you do for the parameters.
Instead, the device, and therefore its _prng_key, is captured in a nested series of lambdas/functions called from the callback. When the callback is executed, it therefore encounters a tracer object for the PRNG key that has not been substituted with a concrete value, and crashes.
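To make the failure mode concrete, here is a minimal sketch in plain JAX, without PennyLane. ToyDevice, build_device and leaked are hypothetical names invented for this illustration; they only mimic the pattern of storing a PRNG key on an object during tracing and using it later from ordinary Python code.

```python
import jax
import jax.numpy as jnp


class ToyDevice:
    """Hypothetical stand-in for a device that keeps a PRNG key as state."""

    def __init__(self, prng_key=None):
        self._prng_key = prng_key

    def sample(self):
        # Uses whatever is stored on the instance, tracer or concrete array.
        return jax.random.uniform(self._prng_key, (4,))


leaked = {}


@jax.jit
def build_device(prng_key):
    # Inside jit, prng_key is a tracer. Storing it on an object that
    # outlives the trace is the side effect JAX complains about.
    leaked["dev"] = ToyDevice(prng_key)
    return jnp.zeros(())


build_device(jax.random.PRNGKey(0))

try:
    # Plain Python code, run after tracing, now touches the stale tracer.
    leaked["dev"].sample()
    leak_detected = False
except jax.errors.UnexpectedTracerError:
    leak_detected = True
```

A device built with a concrete key, for example ToyDevice(jax.random.PRNGKey(0)), samples without error, which matches the observation that the un-jitted version appears to work.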
Observing where the crash happens
As I am not very familiar with the internals of PennyLane, and as this crash happens inside a callback, which prevents proper stack traces from being printed, I had to resort to a very primitive way of debugging: I added several print statements to various PennyLane functions. You can install my instrumented copy of PennyLane by running
pip install git+https://github.com/PhilipVinc/pennylane@pv/debug
Using this copy, and running the snippet below, you will see the following messages printed:
INSIDE THE non_diff_wrapper CALLBACK, EXECUTING PYTHON CODE. Called with args=([array(0.34564769), array(0.45750395)], [array([1.])])
...
after batch vjp, ...
INSIDE cache_execute called with tapes=[<QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>] and kwargs={}
doing some stuff in the wrapper
this qubit device has prng_key=Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
executing the circuit
inside circuit execute for self=<DefaultQubitJax device (wires=2, shots=1000) at 0x1342ecdf0>.circuit=<QuantumTape: wires=[0, 1], params=2>
generating samples
- Sampling basis states for a jax qubit device with self._prng_key=Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
- IF THE PRNG KEY ABOVE IS A TRACER, THIS WILL CRASH!
ERROR:absl:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x134312230> threw exception Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The call chain at the point of the crash is the following:
- This crash happens after compilation. Inside _execute, when executing the host callback, we hit the execute_fn in the callback. Notice that the callback is essentially executing Python code, so this code is not jitted itself (question: do you really need a callback here?). Therefore, any tracer that is still around at this point will lead to a crash.
- The execute_fn is a series of wrappers around QubitDevice.batch_execute, which then calls self.execute.
- In QubitDevice.execute there is a branch that, if a finite number of shots is specified, calls generate_samples and then DefaultQubitJax.sample_basis_states.
In this function we use ._prng_key to execute some jax random functions. But as I said before, this is all being executed inside a callback, so there should be no tracers there! Instead, because the device was captured in some lambdas, the device holds a tracer as its PRNG key, which leads to the crash.
Possible solution
The solution is to pass the PRNG key as an argument to the callback. In a sense, you’d need to do something similar to cp_tape for the PRNG key of the device.
However, this seems complicated to do, because you are not passing the device itself as an argument to those functions; it is captured inside lambdas (I think). But maybe someone who is more familiar with PennyLane (@antalszava) might know how to do this?
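As a rough sketch of the idea (not PennyLane's actual code), the key can be handed to the callback next to the parameters, so that it arrives on the host as a concrete array. The example uses jax.pure_callback instead of the host callback used in jax_jit.py, and host_expval, expval and SHOTS are hypothetical names. The "circuit" is a single RY rotation whose Z expectation value is estimated from simulated shots with NumPy.

```python
import jax
import jax.numpy as jnp
import numpy as np

SHOTS = 1000  # illustrative shot count


def host_expval(theta, key_data):
    # Runs as ordinary Python on the host. Both arguments are concrete
    # arrays here, because they were passed explicitly to the callback.
    seed = np.asarray(key_data, dtype=np.uint32).tolist()
    rng = np.random.default_rng(seed)
    theta_np = np.asarray(theta)
    p_zero = float(np.cos(float(theta_np) / 2.0) ** 2)  # P(|0>) after RY(theta)
    n_zero = rng.binomial(SHOTS, p_zero)
    estimate = (2.0 * n_zero - SHOTS) / SHOTS  # shot estimate of <Z>
    return np.asarray(estimate, dtype=theta_np.dtype)


@jax.jit
def expval(theta, prng_key):
    out_spec = jax.ShapeDtypeStruct((), theta.dtype)
    # The key is an argument of the callback, like the parameters,
    # instead of being captured through a closure or an object attribute.
    return jax.pure_callback(host_expval, out_spec, theta, prng_key)


theta = jnp.asarray(0.3)
estimate = expval(theta, jax.random.PRNGKey(0))
```

Because the key is an ordinary traced argument, calling expval with a different key reuses the compiled function, and the same key reproduces the same estimate.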
Source code
from jax.config import config
config.update("jax_enable_x64", True)
import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml
phys_qubits = 2
pars_q = np.random.rand(3)
def minimal_circ(params, prng_key=None):
    if prng_key is not None:
        dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000, prng_key=prng_key)
    else:
        dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000)

    @qml.qnode(dev, interface="jax", diff_method="parameter-shift")
    def _measure_operator():
        qml.RY(params[0], wires=0)
        qml.RY(params[1], wires=1)
        op = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)

    res = _measure_operator()
    return res
grad_fun = jax.grad(minimal_circ)
jax.jit(grad_fun)(pars_q, jax.random.PRNGKey(0))
System information
>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.27.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/filippovicentini/Documents/pythonenvs/dev-pennylane/python-3.10.7/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning
Platform info: macOS-13.0-x86_64-i386-64bit
Python version: 3.10.7
Numpy version: 1.23.4
Scipy version: 1.9.3
Installed devices:
- default.gaussian (PennyLane-0.27.0.dev0)
- default.mixed (PennyLane-0.27.0.dev0)
- default.qubit (PennyLane-0.27.0.dev0)
- default.qubit.autograd (PennyLane-0.27.0.dev0)
- default.qubit.jax (PennyLane-0.27.0.dev0)
- default.qubit.tf (PennyLane-0.27.0.dev0)
- default.qubit.torch (PennyLane-0.27.0.dev0)
- default.qutrit (PennyLane-0.27.0.dev0)
- null.qubit (PennyLane-0.27.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.26.1)
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.
About this issue
- State: closed
- Created 2 years ago
- Comments: 23 (23 by maintainers)
Yes, it’s on my to-do list. I have a deadline on Monday, so my ETA for being able to phrase something decently is 10 days.
And thanks for implementing this thing. It’s very helpful for us to be able to work with Jax without worrying too much about how to work around issues…
For sure! Just wanted to leave a comment here, mentioning that although this issue is being closed, we’d like to track the improvements we discussed.
@antalszava changing all host callback to pure_callback(, vectorized=False) gets vmap and jacobian to work with no further effort!
Though performance will be sub-optimal, because it will be inserting a loop. But I think it’s possible to make it work without switching to jvp…
Thanks for the snippet!
I’ll surely try this out. Just to understand… how will the RNG seed work in that case? Is it using some internal state that gets updated every time it calls back into Python/Lightning?
Not really, no. I think I used it because at first I (erroneously) thought I could not mix and match different devices and qml.qnode interfaces, and then it stuck. No other reason.
Probably worse. Jax is especially bad at using more than 1 or 2 cores on CPU (I think its BLAS implementation is particularly conservative before switching to multi-threading), and I wouldn’t be surprised if any purpose-written C kernel could beat XLA (the Jax compiler) when applying gates…
Yeah, this definitely makes sense. Though for the particular case of a local jax device you could drop it, but I agree with your analysis.
I’m not sure I understand what is limiting here. Probably also because I fail to see exactly where the device is captured in your execute_fn.
.My very uninformed understanding is that you are taking the Object-Oriented/Pythonic approach of passing around the method of an object, which implicitly captures (in a somewhat opaque manner) the underlying instance.
The standard way to do this in functional programming would be to split the functions from the data structure, so that you are obliged, in a sense, to pass the data structure as an argument. Jax likes that because it can do its tracer magic on the arguments.
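A small sketch of what this looks like for random sampling in JAX: the key is data that goes in as an argument and the updated key comes back as an output, so nothing is hidden on an object. sample_bits is a hypothetical function written for this comment, not something from PennyLane.

```python
import jax
import jax.numpy as jnp


def sample_bits(prng_key, p_one, shots):
    """Pure function: the state (the key) comes in and goes out explicitly."""
    next_key, subkey = jax.random.split(prng_key)
    bits = jax.random.bernoulli(subkey, p_one, (shots,))
    return next_key, bits


# shots fixes the output shape, so it is marked static; the key and the
# probability are ordinary traced arguments.
jitted_sample = jax.jit(sample_bits, static_argnames="shots")

key = jax.random.PRNGKey(0)
key, bits = jitted_sample(key, 0.25, shots=1000)
```

Since the key is a traced argument, calling jitted_sample again with the returned key does not trigger a recompilation.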
Yes, what you propose would work, in principle.
However, Jax retraces/recompiles every time you change some static information, detected by the hash of the static data. If you feed different devices, with different PRNG seeds, you will recompile the code a lot.
In my use case, where I have a hybrid structure coupling a neural network and a quantum circuit, recompiling leads to very, very large increases in computational time (at least, when not using shots).
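The retracing can be observed with a toy example. ToyDevice and run are hypothetical names, and the hash here is built only from a seed; a real device would have to include all of its static data. The counter is incremented by a Python side effect, so it only changes when Jax actually traces the function.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


class ToyDevice:
    """Hypothetical hashable device whose identity is defined by its seed."""

    def __init__(self, seed):
        self.seed = seed

    def __hash__(self):
        return hash(("ToyDevice", self.seed))

    def __eq__(self, other):
        return isinstance(other, ToyDevice) and self.seed == other.seed


@partial(jax.jit, static_argnums=0)
def run(dev, x):
    global trace_count
    trace_count += 1  # executed only while tracing, not on cache hits
    return x * dev.seed


x = jnp.ones(3)
run(ToyDevice(0), x)  # first call: traced and compiled
run(ToyDevice(0), x)  # equal device with the same hash: cache hit
run(ToyDevice(1), x)  # different seed: traced again
```

After these three calls trace_count is 2, so every new seed passed as static data costs one more trace and compilation.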
As a side note, to make this work, you’d need to correctly compute the hash of your devices starting from the static data contained inside (like the PRNG key), so that if a user changes the PRNG key in the device, the hash changes, and Jax recompiles.
Thank you, actually!
If you would like any opinion or discuss more interactively some of those Jax-related mysteries on a call, feel free to drop me an email. This issue is currently blocking the last section of a paper we’re writing, so I have a strong interest in giving any assistance you might need to address it (cc @co9olguy )