pennylane: [BUG] With many observables `generate_shifted_tapes()` is called "unreasonably often" resulting in massive performance loss
Expected behavior
When taking the parameter-shift Hessian of a QNode that returns the expectation values of several compatible observables, I expect the construction of shifted tapes to take place only once per gate; the derivatives of all observables can then be computed from this single set of shifted tapes.
Actual behavior
In reality, `expval_param_shift` (`parameter_shift.py:152`) is called separately for each observable, so `generate_shifted_tapes()` is called (number of parametrized gates) × (number of observables) times.
This results in a massive performance hit. In the 10-qubit example below, PennyLane spends over 60% of the time copying tapes, compared to only ~4% actually simulating circuits, even on the slow `default.qubit` simulator.
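The scaling argument can be sketched with a toy count (illustrative only; these helper functions are not PennyLane API). Building shifted tapes per observable multiplies the tape-construction work by the number of observables, while sharing one set of shifted tapes keeps it proportional to the number of parameters alone:

```python
# Illustrative sketch (not PennyLane internals) of the two scaling behaviours.

def shifted_tapes_per_observable(n_params, n_obs):
    # Observed behaviour: each observable triggers its own tape construction,
    # so every parametrized gate is copied once per observable.
    builds = 0
    for _ in range(n_obs):
        for _ in range(n_params):
            builds += 2  # two shifted tapes per parameter (+shift and -shift)
    return builds

def shifted_tapes_shared(n_params, n_obs):
    # Expected behaviour: shifted tapes are built once per parameter and their
    # results are reused for all observables measured on the same circuit.
    return 2 * n_params
```

With the 5 parameters and 90 observables of the 10-qubit example below, the per-observable approach performs 900 tape constructions where 10 would suffice.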
Additional information
To get a nice graphical overview of what is happening, run the example below under cProfile, e.g. like this: `python -m cProfile -o timing test.py && gprof2dot -f pstats timing -o graph.dot && dot -Tsvg graph.dot -o graph.svg`
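If gprof2dot/Graphviz are not available, the cProfile data can also be inspected directly with the standard-library `pstats` module. The sketch below profiles a small dummy workload in-process; for the script above you would instead load the saved file with `pstats.Stats("timing")`:

```python
# Inspect cProfile data with pstats instead of the gprof2dot pipeline.
import cProfile
import io
import pstats

def workload():
    # stand-in for running the QNode; any Python callable can be profiled
    return sum(i * i for i in range(10_000))

profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

# Sort by cumulative time and show the top 5 entries, writing into a string
# so the report can be post-processed (e.g. grepped for tape-copying calls).
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(5)
report = stream.getvalue()
print(report)
```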
Source code
```python
# test.py
# run with:
# python -m cProfile -o timing test.py && gprof2dot -f pstats timing -o graph.dot && dot -Tsvg graph.dot -o graph.svg
import pennylane as qml
from pennylane import numpy as np

num_wires = 10
wires = range(num_wires)
dev = qml.device('default.qubit', num_wires)

np.random.seed(42)
init_params = np.random.randn(num_wires // 2)

@qml.qnode(dev, diff_method='parameter-shift')
def qnode(params):
    for par_idx, wire in enumerate(wires[::2]):
        qml.Hadamard(wire)
        qml.CRY(params[par_idx], wires=[wire, wire + 1])
    return [
        qml.expval(qml.PauliZ(wire1) @ qml.PauliZ(wire2))
        for wire1 in wires
        for wire2 in wires
        if wire1 != wire2
    ]

# number of observables returned by the QNode
print(len([0 for wire1 in wires for wire2 in wires if wire1 != wire2]))

for _ in range(4):
    print(qml.jacobian(qnode)(init_params))
```
Tracebacks
No response
System information
Python 3.8.8 | packaged by conda-forge | (default, Feb 20 2021, 16:22:27)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pennylane as qml; qml.about()
Name: PennyLane
Version: 0.21.0.dev0
Summary: PennyLane is a Python quantum machine learning library by Xanadu Inc.
Home-page: https://github.com/XanaduAI/pennylane
Author: None
Author-email: None
License: Apache License 2.0
Location: /home/cvjjm/src/covqcstack/qcware/pennylane
Requires: numpy, scipy, networkx, retworkx, autograd, toml, appdirs, semantic-version, autoray, cachetools, pennylane-lightning
Required-by: pytket-pennylane, PennyLane-Qchem, PennyLane-Lightning, covvqetools
Platform info: Linux-5.10.102.1-microsoft-standard-WSL2-x86_64-with-glibc2.10
Python version: 3.8.8
Numpy version: 1.20.1
Scipy version: 1.6.1
Installed devices:
- default.gaussian (PennyLane-0.21.0.dev0)
- default.mixed (PennyLane-0.21.0.dev0)
- default.qubit (PennyLane-0.21.0.dev0)
- default.qubit.autograd (PennyLane-0.21.0.dev0)
- default.qubit.jax (PennyLane-0.21.0.dev0)
- default.qubit.tf (PennyLane-0.21.0.dev0)
- default.qubit.torch (PennyLane-0.21.0.dev0)
- pytket.pytketdevice (pytket-pennylane-0.1.0)
- lightning.qubit (PennyLane-Lightning-0.20.2)
- cov.qubit (covvqetools-0.1.1)
Existing GitHub issues
- I have searched existing GitHub issues to make sure the issue does not already exist.
About this issue
- Original URL
- State: closed
- Created 2 years ago
- Comments: 20 (20 by maintainers)
Great! The implementation in #2645 looks very nice.
Hi @cvjjm, this is really helpful. Running your provided example takes around 40s locally. Attacking the offending function allows me to bring this down to ~27s. It may take a few days, but I’ll aim to put together a general solution that helps here.
Awesome! Just tested: current master now even beats the performance of v0.18.0 (175.19s vs. 250.03s) in a test that took several hours to run with v0.22–v0.24 😃