Cirq: ParamResolver is slow

Rebuilding a circuit operation by operation, replacing each parametrized op with a new op that carries the updated parameter value, is 10-50x faster than resolving sympy symbols with cirq.ParamResolver…

Example code to reproduce:

from timeit import default_timer as timer
import numpy as np
import sympy
import cirq


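def update_params(circuit, params):
    """Rebuild the circuit with numeric exponents instead of resolving symbols.

    Alternative to cirq.ParamResolver. Assumes one parameter (in radians) per operation.
    """
    new_op_tree = []
    for op, param in zip(circuit.all_operations(), params):
        # Rx(theta) is an XPowGate with exponent theta/pi; _with_exponent is a private cirq method.
        new_op_tree.append(op.gate._with_exponent(param/np.pi)(*op.qubits))
    return cirq.Circuit.from_ops(new_op_tree)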


trials = 100
for depth in [10, 15, 20]:
    sympy_circuit = cirq.Circuit.from_ops([cirq.Rx(sympy.Symbol(str(k)))(cirq.LineQubit(0)) for k in range(depth)])
    random_params = np.random.randn(trials, depth)
    # time `trials` runs using sympy symbols + ParamResolver
    start = timer()
    for j in range(trials):
        resolver = dict(zip([str(k) for k in range(depth)], random_params[j]))
        wf1 = cirq.Simulator().simulate(sympy_circuit, param_resolver=resolver).final_state
    end = timer() - start
    print(f"{depth} parameters, {trials} trials using Sympy+ParamResolver: {end} seconds")

    start = timer()
    for j in range(trials):
        float_circuit = update_params(sympy_circuit, random_params[j])
        wf2 = cirq.Simulator().simulate(float_circuit).final_state
    end = timer() - start
    print(f"{depth} parameters, {trials} trials using reconstructed circuit: {end} seconds")

Output (cirq v0.5.0, Windows 10, 7th-generation Intel Core i7):

>>> 10 parameters, 100 trials using Sympy+ParamResolver: 2.408036000095308 seconds
>>> 10 parameters, 100 trials using reconstructed circuit: 0.1671589999459684 seconds
>>> 15 parameters, 100 trials using Sympy+ParamResolver: 4.347879000008106 seconds
>>> 15 parameters, 100 trials using reconstructed circuit: 0.25207799999043345 seconds
>>> 20 parameters, 100 trials using Sympy+ParamResolver: 7.1194350000005215 seconds
>>> 20 parameters, 100 trials using reconstructed circuit: 0.31734399998094887 seconds
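To make the size of the gap explicit, the speedup at each depth can be computed directly from the timings reported above. This is only arithmetic on the numbers already printed; the names `timings` and `speedups` are illustrative and not part of the original script.

```python
# Timings in seconds, copied from the output above:
# depth -> (Sympy+ParamResolver, reconstructed circuit)
timings = {
    10: (2.408036000095308, 0.1671589999459684),
    15: (4.347879000008106, 0.25207799999043345),
    20: (7.1194350000005215, 0.31734399998094887),
}

# Ratio of ParamResolver time to reconstruction time at each depth.
speedups = {depth: slow / fast for depth, (slow, fast) in timings.items()}

for depth, ratio in speedups.items():
    print(f"{depth} parameters: reconstruction is {ratio:.1f}x faster")
```

For this run the ratio is roughly 14x, 17x, and 22x at depths 10, 15, and 20, so the gap widens as parameters are added.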

Plots of how this scales with large numbers of parameters (generated from a different script using the same update_params method): [image]

About this issue

  • State: closed
  • Created 5 years ago
  • Reactions: 1
  • Comments: 17 (1 by maintainers)

Most upvoted comments

I took a stab at short-circuiting sympy for common cases in PR #2394. It does not help the code above as written, but if the circuit is flattened, or if it uses XPowGate (which doesn’t add the “x/pi” formula), the speedup is significant:

10 parameters, 100 trials using flattened circuit: 0.1600349259097129 seconds
10 parameters, 100 trials using xpow circuit: 0.17739228205755353 seconds
10 parameters, 100 trials using reconstructed circuit: 0.16374025191180408 seconds:
15 parameters, 100 trials using flattened circuit: 0.2257283239159733 seconds
15 parameters, 100 trials using xpow circuit: 0.2516175350174308 seconds
15 parameters, 100 trials using reconstructed circuit: 0.23388241999782622 seconds:
20 parameters, 100 trials using flattened circuit: 0.2913768459111452 seconds
20 parameters, 100 trials using xpow circuit: 0.32354101818054914 seconds
20 parameters, 100 trials using reconstructed circuit: 0.29962173686362803 seconds:

The changed code is below, for reference:

trials = 100
for depth in [10, 15, 20]:
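    # Rx(symbol) stores its exponent as the expression symbol/pi.
    sympy_circuit1 = cirq.Circuit.from_ops([cirq.Rx(sympy.Symbol(str(k)))(cirq.LineQubit(0)) for k in range(depth)])
    # XPowGate takes the bare symbol as its exponent, so no formula is introduced.
    sympy_circuit2 = cirq.Circuit.from_ops([cirq.XPowGate(exponent=sympy.Symbol(str(k)))(cirq.LineQubit(0)) for k in range(depth)])
    # Flattening replaces each expression with a new symbol; expr_map records the mapping.
    sympy_circuit, expr_map = cirq.flatten(sympy_circuit1)
    random_params = np.random.randn(trials, depth)
    # Built once, outside the timed loop, from the first row of parameters.
    resolver = expr_map.transform_params(dict(zip([str(k) for k in range(depth)], random_params[0])))
    # time `trials` runs of the flattened circuit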
    start = timer()
    for j in range(trials):
        wf = cirq.Simulator().simulate(sympy_circuit, param_resolver=resolver).final_state
    end = timer() - start
    print(f"{depth} parameters, {trials} trials using flattened circuit: {end} seconds")

    # time `trials` runs of the XPowGate circuit
    start = timer()
    for j in range(trials):
        resolver = dict(zip([str(k) for k in range(depth)], random_params[j]))
        wf = cirq.Simulator().simulate(sympy_circuit2, param_resolver=resolver).final_state
    end = timer() - start
    print(f"{depth} parameters, {trials} trials using xpow circuit: {end} seconds")

    start = timer()
    for j in range(trials):
        float_circuit = update_params(sympy_circuit, random_params[j])
        wf2 = cirq.Simulator().simulate(float_circuit).final_state
    end = timer() - start
    print(f"{depth} parameters, {trials} trials using reconstructed circuit: {end} seconds:")
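The short-circuiting idea mentioned in the comment above can be illustrated with a toy resolver. This is a minimal sketch under stated assumptions, not cirq's implementation and not the code from the PR: `ShortCircuitResolver`, `value_of`, and `slow_path_calls` are hypothetical names, and a plain Python callable stands in for a sympy expression such as theta/pi. It shows why a bare symbol (as with XPowGate or a flattened circuit) can skip the expensive path, while a formula cannot.

```python
import math
from typing import Callable, Dict, Union

# Hypothetical stand-in for a symbolic expression such as "theta / pi":
# a callable that needs the parameter table to be evaluated.
Expression = Callable[[Dict[str, float]], float]
Param = Union[float, str, Expression]


class ShortCircuitResolver:
    """Toy resolver (not cirq's) with fast paths for the common cases."""

    def __init__(self, values: Dict[str, float]) -> None:
        self.values = values
        self.slow_path_calls = 0  # how often the general path was taken

    def value_of(self, param: Param) -> float:
        # Fast path 1: already a number, nothing to resolve.
        if isinstance(param, (int, float)):
            return float(param)
        # Fast path 2: a bare symbol name is a single dictionary lookup.
        if isinstance(param, str):
            return self.values[param]
        # Slow path: a general expression has to be evaluated in full.
        self.slow_path_calls += 1
        return param(self.values)


resolver = ShortCircuitResolver({"theta": math.pi / 2})

# Bare symbol, analogous to XPowGate(exponent=theta): fast path.
bare = resolver.value_of("theta")

# Formula, analogous to the theta/pi exponent created by Rx(theta): slow path.
scaled = resolver.value_of(lambda v: v["theta"] / math.pi)

print(bare, scaled, resolver.slow_path_calls)
```

Only the formula reaches the slow path here, which matches the observation that the original Rx-based benchmark does not benefit unless the circuit is flattened first.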