pennylane: [BUG] jax.jit(jax.grad()) of a circuit with shots crashes

Expected behavior

Jitting the gradient of a QNode whose device uses shots leads to a crash when a PRNGKey is set. I would expect this to work. The snippet below reproduces the issue on master. Note that if you remove the jax.jit the gradient works, but only by accident.

I think I know what is causing the bug, but the explanation is a bit involved. I will first give a TLDR, then show exactly where the crash happens, and then reason about what is happening there.

TLDR

The problem arises because you are storing a tracer in DefaultQubitJax._prng_key, but you are not passing this PRNG key as an argument of the host callback in jax_jit.py:_execute. Conceptually, you should pass the PRNG key as an argument of the callback, just as you do for the parameters.

Instead, the device, and therefore its _prng_key, is captured in a nested series of lambdas/functions called from the callback. So when the callback is executed, it encounters a tracer object for the PRNG key that has not been substituted with a concrete value, and it crashes.
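The failure mode described above can be reproduced without PennyLane. In this minimal sketch, `ToyDevice` is a hypothetical stand-in, not a PennyLane class: an object stores the traced PRNG key on itself while `jax.jit` is tracing, and using that stored key after tracing has finished raises the same kind of "unexpected tracer" error seen in the log.

```python
import jax


class ToyDevice:
    """Hypothetical stand-in for a device that keeps its PRNG key as state."""

    def __init__(self, prng_key):
        self._prng_key = prng_key


captured = {}


@jax.jit
def build_device(key):
    # While tracing, `key` is a tracer. Storing it on an object that outlives
    # the trace is the side effect JAX forbids.
    captured["dev"] = ToyDevice(key)
    return key


build_device(jax.random.PRNGKey(0))

try:
    # The stored key is used outside the trace: the tracer has escaped.
    jax.random.uniform(captured["dev"]._prng_key, (3,))
    leaked_error = None
except jax.errors.UnexpectedTracerError as err:
    leaked_error = err
```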

Observing where the crash happens

As I am not very familiar with the internals of PennyLane, and as this crash happens inside a callback (which prevents proper stack traces from being printed), I had to resort to a very primitive way of debugging: I added several print statements to various PennyLane functions. You can install my instrumented copy of PennyLane by running:

pip install git+https://github.com/PhilipVinc/pennylane@pv/debug

Using this copy, and running the snippet below, you will see the following messages printed:

INSIDE THE non_diff_wrapper CALLBACK, EXECUTING PYTHON CODE. Called with args=([array(0.34564769), array(0.45750395)], [array([1.])])
 ...
 after batch vjp, ...
  INSIDE cache_execute called with tapes=[<QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>, <QuantumTape: wires=[0, 1], params=2>] and kwargs={}
  doing some stuff in the wrapper
   this qubit device has prng_key=Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
   executing the circuit
     inside circuit execute for self=<DefaultQubitJax device (wires=2, shots=1000) at 0x1342ecdf0>.circuit=<QuantumTape: wires=[0, 1], params=2>
     generating samples
      - Sampling basis states for a jax qubit device with self._prng_key=Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
      - IF THE PRNG KEY ABOVE IS A TRACER, THIS WILL CRASH!
ERROR:absl:Outside call <jax.experimental.host_callback._CallbackWrapper object at 0x134312230> threw exception Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.

The call chain at the point of the crash is the following:

In this function, ._prng_key is used to call some jax.random functions. But, as I said before, this is all executed inside a callback, so there should be no tracers there! Instead, because the device was captured in some lambdas, its PRNG key is a tracer, and this leads to the crash.

Possible solution

The solution is to pass the prng key as an argument to the callback. In a sense, you’d need to do something similar to cp_tape for the prng key of the device.
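The suggested fix can be sketched in isolation. This is a minimal sketch, not PennyLane's code: it uses `jax.pure_callback` rather than the `host_callback` module used in `jax_jit.py` at the time, and `_sample_on_host` and `sample_counts` are hypothetical names. The key is handed to the callback as an explicit operand, exactly like the circuit parameters, so the Python side only ever sees a concrete `uint32[2]` array.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def _sample_on_host(key, probs, shots):
    # Ordinary Python/NumPy: `key` and `probs` arrive as concrete arrays
    # because they were passed in as callback operands.
    rng = np.random.default_rng(np.asarray(key, dtype=np.uint32))
    p = np.asarray(probs, dtype=np.float64)
    counts = rng.multinomial(shots, p / p.sum())
    return counts.astype(np.int32)


@functools.partial(jax.jit, static_argnums=2)
def sample_counts(key, probs, shots):
    result_shape = jax.ShapeDtypeStruct(probs.shape, jnp.int32)
    host_fn = functools.partial(_sample_on_host, shots=shots)
    # The key is an operand of the callback, so JAX substitutes its runtime
    # value instead of leaking a tracer into the Python code.
    return jax.pure_callback(host_fn, result_shape, key, probs)


key = jax.random.PRNGKey(0)
probs = jnp.array([0.25, 0.25, 0.25, 0.25])
counts = sample_counts(key, probs, 1000)
```

Because the key is a traced argument, calling `sample_counts` with a different key reuses the compiled function; only a change to `shots` (a static argument) triggers recompilation.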

However, this seems complicated because you are not passing the device itself as an argument to those functions; instead, it is captured inside lambdas (I think). But maybe someone more familiar with PennyLane, like @antalszava, knows how to do this?

Source code

from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import numpy as np

import pennylane as qml

phys_qubits = 2
pars_q = np.random.rand(3)  # only pars_q[0] and pars_q[1] enter the circuit

def minimal_circ(params, prng_key=None):
    # Shot-based JAX device; under jax.jit, prng_key is a tracer that the device stores as _prng_key
    if prng_key is not None:
        dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000, prng_key=prng_key)
    else:
        dev = qml.device("default.qubit.jax", wires=tuple(range(phys_qubits)), shots=1000)

    @qml.qnode(dev, interface="jax", diff_method="parameter-shift")
    def _measure_operator():
        qml.RY(params[0], wires=0)
        qml.RY(params[1], wires=1)
        op = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)

    res = _measure_operator()
    return res

grad_fun = jax.grad(minimal_circ)

# Crashes; calling grad_fun without jax.jit works (by accident)
jax.jit(grad_fun)(pars_q, jax.random.PRNGKey(0))

System information

>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.27.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author:
Author-email:
License: Apache License 2.0
Location: /Users/filippovicentini/Documents/pythonenvs/dev-pennylane/python-3.10.7/lib/python3.10/site-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, retworkx, scipy, semantic-version, toml
Required-by: PennyLane-Lightning

Platform info:           macOS-13.0-x86_64-i386-64bit
Python version:          3.10.7
Numpy version:           1.23.4
Scipy version:           1.9.3
Installed devices:
- default.gaussian (PennyLane-0.27.0.dev0)
- default.mixed (PennyLane-0.27.0.dev0)
- default.qubit (PennyLane-0.27.0.dev0)
- default.qubit.autograd (PennyLane-0.27.0.dev0)
- default.qubit.jax (PennyLane-0.27.0.dev0)
- default.qubit.tf (PennyLane-0.27.0.dev0)
- default.qubit.torch (PennyLane-0.27.0.dev0)
- default.qutrit (PennyLane-0.27.0.dev0)
- null.qubit (PennyLane-0.27.0.dev0)
- lightning.qubit (PennyLane-Lightning-0.26.1)

Existing GitHub issues

  • I have searched existing GitHub issues to make sure the issue does not already exist.

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 23 (23 by maintainers)

Most upvoted comments

Yes, it’s on my to-do list. I have a deadline on Monday, so my ETA for phrasing something decently is about 10 days.

And thanks for implementing this. It’s very helpful for us to be able to work with JAX without worrying too much about how to work around issues…

For sure! Just wanted to leave a comment here, mentioning that although this issue is being closed, we’d like to track the improvements we discussed.

@antalszava changing all host_callback calls to pure_callback(..., vectorized=False) gets vmap and jacobian to work with no further effort!

Performance will be sub-optimal, though, because it will insert a loop. But I think it’s possible to make it work without switching to jvp…
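The trade-off mentioned here (batching works, but through a loop) can be illustrated by making the loop explicit. This sketch deliberately avoids `pure_callback`'s batching flag, whose name has changed across JAX releases (`vectorized=` in older versions, `vmap_method=` in newer ones); instead, `jax.lax.map` issues one host round-trip per batch element, which is roughly what sequential batching of a non-vectorized callback amounts to. `_host_expval` is a hypothetical stand-in for a device returning <Z> = cos(theta) after an RY(theta) rotation on |0>.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_expval(theta):
    # Hypothetical host-side "device": handles one parameter value per call.
    # <Z> after RY(theta) applied to |0> is cos(theta).
    return np.asarray(np.cos(theta), dtype=np.float32)


def expval(theta):
    out = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(_host_expval, out, theta)


@jax.jit
def batched_expval(thetas):
    # lax.map lowers to a sequential scan: one callback invocation per
    # element, so the host-side work is not vectorized.
    return jax.lax.map(expval, thetas)


values = batched_expval(jnp.array([0.0, jnp.pi / 2, jnp.pi]))
```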

Thanks for the snippet!

I’ll surely try this out. Just to understand: how will the RNG seed work in that case? Does it use some internal state that gets updated every time it calls back into Python/Lightning?

@PhilipVinc was there a specific reason for using the default.qubit.jax device? Although it is written natively in JAX, our original intent was for that device to be used with diff_method="backprop".

Not really, no. I think I used it because at first I (erroneously) thought I could not mix and match different devices and qml.qnode interfaces, and then it stuck. No other reason.

On top of this, it could be interesting to see how the originally suggested default.qubit.jax device with a PRNGKey performs compared to the above example. It doesn’t seem to unlock a completely different use case, though; it would rather do the computation entirely in JAX from start to end.

Probably worse. Jax is especially bad at using more than one or two cores on CPU (I think its BLAS implementation is particularly conservative about switching to multi-threading), and I wouldn’t be surprised if any purpose-written C kernel could beat XLA (Jax's compiler) at applying gates…

To answer the question in the brackets: though the callback calls Python code, it encapsulates the quantum device execution which may happen using a remote simulator/a remote QPU.

Yeah, this definitely makes sense. Though for the particular case of a local jax device you could drop it, but I agree with your analysis.

because the pipeline described originally is a Device API that is used by (almost) all devices in our ecosystem

I’m not sure I understand what is limiting here. Probably also because I fail to see exactly where the device is captured in your execute_fn.

My very uninformed understanding is that you are taking the object-oriented/Pythonic approach of passing around the method of an object, which implicitly (and somewhat opaquely) captures the underlying instance.

The standard way to do this in functional programming would be to split the functions from the data structure, so that you are obliged, in a sense, to pass the data structure as an argument. Jax likes that because it can do its tracer magic on the arguments.
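The functional style described here can be sketched as follows. `ToyDevice` is a hypothetical class, not PennyLane's device API: it is registered as a JAX pytree whose only leaf is the PRNG key, while the shot count is static auxiliary data. The sampling function receives the device as an argument instead of capturing it in a closure or a bound method, so JAX traces the key like any other input and different keys reuse one compiled function.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ToyDevice:
    def __init__(self, shots, prng_key):
        self.shots = shots        # static: fixes the shape of the sample array
        self.prng_key = prng_key  # dynamic: traced like any other input

    def tree_flatten(self):
        # (children, aux_data): the key is a leaf, shots is static metadata
        return (self.prng_key,), self.shots

    @classmethod
    def tree_unflatten(cls, shots, children):
        return cls(shots, *children)


TRACE_COUNT = 0


@jax.jit
def sampled_z(dev, theta):
    global TRACE_COUNT
    TRACE_COUNT += 1  # Python side effect: runs only while JAX is tracing
    p0 = jnp.cos(theta / 2) ** 2  # probability of outcome 0 after RY(theta)
    outcomes = jax.random.bernoulli(dev.prng_key, p0, shape=(dev.shots,))
    return jnp.mean(jnp.where(outcomes, 1.0, -1.0))  # estimate of <Z>


estimates = [
    sampled_z(ToyDevice(1000, jax.random.PRNGKey(seed)), 0.3) for seed in range(3)
]
```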

Yes, what you propose would work, in principle.

However, Jax retraces/recompiles every time you change some static information, which it detects through the hash of the static data. If you feed in different devices with different PRNG seeds, the code will be recompiled a lot.

In my use case, where I have a hybrid structure coupling a neural network and a quantum circuit, recompiling leads to very, very large increases in computational time (at least when not using shots).

As a side note, to make this work you’d need to correctly compute the hash of your devices from the static data they contain (like the PRNG key), so that if a user changes the PRNG key in the device, the hash changes and Jax recompiles.
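The recompilation cost and the hashing requirement can be sketched with a static argument. `StaticToyDevice` is hypothetical: it is passed through `static_argnums`, so JAX keys its compilation cache on the object's `__hash__` and `__eq__`. Including the seed in both makes a new seed correctly trigger a retrace, but it also means every new seed pays for a full recompilation.

```python
import functools

import jax
import jax.numpy as jnp


class StaticToyDevice:
    def __init__(self, shots, seed):
        self.shots = shots
        self.seed = seed

    # Static arguments must be hashable. The hash/eq must cover every field the
    # traced code reads (including the seed), or stale compiled code is reused.
    def __hash__(self):
        return hash((self.shots, self.seed))

    def __eq__(self, other):
        return isinstance(other, StaticToyDevice) and (
            (self.shots, self.seed) == (other.shots, other.seed)
        )


TRACES = 0


@functools.partial(jax.jit, static_argnums=0)
def sampled_z_static(dev, theta):
    global TRACES
    TRACES += 1  # increments only when JAX (re)traces
    key = jax.random.PRNGKey(dev.seed)  # the seed is baked in as a constant
    p0 = jnp.cos(theta / 2) ** 2
    outcomes = jax.random.bernoulli(key, p0, shape=(dev.shots,))
    return jnp.mean(jnp.where(outcomes, 1.0, -1.0))


sampled_z_static(StaticToyDevice(1000, 0), 0.3)
sampled_z_static(StaticToyDevice(1000, 0), 0.3)  # equal device: cache hit
sampled_z_static(StaticToyDevice(1000, 1), 0.3)  # new seed: retrace and recompile
```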

Thank you, actually!

If you would like any opinion, or want to discuss some of these Jax-related mysteries more interactively on a call, feel free to drop me an email. This issue is currently blocking the last section of a paper we’re writing, so I have a strong interest in giving any assistance you might need to address it (cc @co9olguy)