jaxopt: Problem differentiating through `solver.run` in `OptaxSolver`

I’ve been trying to use OptaxSolver to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I’m not familiar with.

Here’s an MWE that reproduces the error:

import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    result, _ = solver.run(init_params = initial)

    return result

jax.value_and_grad(pipeline)(2., data=6.)

which yields this error:

CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.

My versions are:

jax==0.2.20
jaxlib==0.1.71
jaxopt==0.0.1
optax==0.0.9

Am I doing something very silly? I’m also wondering whether this example is within the scope of the solver API. I noticed that this doesn’t occur with solver.update, only with solver.run.

Thanks 😃

About this issue

  • State: open
  • Created 3 years ago
  • Reactions: 2
  • Comments: 21 (8 by maintainers)

Most upvoted comments

Thanks @phinate for the question and @fllinares for the answer!

Indeed, as Felipe explained, your param_for_grad was in scope (this is what is meant by a closed-over value), but it wasn’t an explicit argument of run.

By the way, since run calls init for you, the line

initial, _ = solver.init(init_params = 5.)

is not needed. You can just set initial = 5. and then call run(initial, param_for_grad).

Hi,

I think only a couple of small changes would be needed.

To use implicit differentiation with solver.run, you should (1) explicitly expose, in the signature of fun, the arguments with respect to which you’d like to differentiate the solver’s solution, and (2) avoid keyword arguments in the call to solver.run, passing init_params and those arguments positionally.

In your MWE:

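import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    # param_for_grad is now an explicit argument of the objective
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)

    initial, _ = solver.init(init_params = 5.)

    # init_params and param_for_grad are passed positionally, not as keywords
    result, _ = solver.run(initial, param_for_grad)

    return result

jax.value_and_grad(pipeline)(2., data=6.)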

P.S. I also changed the learning rate (from 3e-4 to 5e-2) so that Adam converges in this example within the default maximum number of steps.
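Why point (1) matters can be seen from a minimal pure-JAX sketch of implicit differentiation through a solver. This is not jaxopt's implementation: the helper `_gradient_descent` (fixed-step gradient descent instead of Adam, with assumed `lr` and `steps`) and the hand-written `jax.custom_vjp` rule are illustrative assumptions. The backward rule applies the implicit function theorem to the optimality condition ∂f/∂latent = 0 and returns exactly one cotangent per explicit argument (`theta`, `data`); a closed-over value would have no slot in that tuple, which is the situation the CustomVJPException describes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def objective(latent, theta, data):
    # The MWE's objective, with theta (param_for_grad) and data as explicit arguments.
    return -norm.logpdf(data, loc=theta * latent, scale=1.0)


def _gradient_descent(theta, data, init=5.0, lr=0.1, steps=200):
    # Hypothetical stand-in for solver.run: fixed-step gradient descent on latent.
    grad_latent = jax.grad(objective, argnums=0)

    def body(_, z):
        return z - lr * grad_latent(z, theta, data)

    return jax.lax.fori_loop(0, steps, body, jnp.asarray(init, dtype=jnp.float32))


@jax.custom_vjp
def argmin_latent(theta, data):
    return _gradient_descent(theta, data)


def _argmin_fwd(theta, data):
    z = _gradient_descent(theta, data)
    return z, (z, theta, data)


def _argmin_bwd(residuals, z_bar):
    z, theta, data = residuals
    # Optimality condition: g(z*, theta, data) = d objective / d latent = 0.
    g = jax.grad(objective, argnums=0)
    g_z = jax.grad(g, argnums=0)(z, theta, data)      # curvature in latent
    g_theta = jax.grad(g, argnums=1)(z, theta, data)  # mixed partial w.r.t. theta
    g_data = jax.grad(g, argnums=2)(z, theta, data)   # mixed partial w.r.t. data
    # Implicit function theorem: dz*/dp = -g_p / g_z, one cotangent per explicit argument.
    return (-z_bar * g_theta / g_z, -z_bar * g_data / g_z)


argmin_latent.defvjp(_argmin_fwd, _argmin_bwd)


def pipeline(param_for_grad, data):
    return argmin_latent(param_for_grad, data)


value, grad = jax.value_and_grad(pipeline)(2.0, 6.0)
# Analytically: latent* = data / theta = 3 and d latent*/d theta = -data / theta**2 = -1.5.
print(value, grad)
```

For this quadratic objective the derivative can be checked by hand, which is why the example keeps the MWE's numbers (theta = 2, data = 6).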

hi again @mblondel, sorry to resurrect this from the dead. My solution that uses closure_convert unexpectedly started leaking tracers after one of the jax/jaxlib updates, and it’s a bit of a nightmare to debug. Luckily, I found a fairly simple MWE:


from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax


# dummy model for test purposes
class Model:
    x: jax.Array
    def __init__(self, x) -> None:
        self.x = x
    def logpdf(self, pars, data):
        return jnp.sum(pars*data*self.x)

@partial(jax.jit, static_argnames=["objective_fn"])
def _minimize(
    objective_fn,
    init_pars,
    lr,
):
    # this is the line added from our discussion above
    converted_fn, aux_pars = jax.closure_convert(objective_fn, init_pars)
    # aux_pars seems to be empty -- would have assumed it was the closed-over vals or similar?
    solver = jaxopt.OptaxSolver(
        fun=converted_fn, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
    )
    return solver.run(init_pars, *aux_pars)[0]


@partial(jax.jit, static_argnames=["model"])
def fit(
    data,
    model,
    init_pars,
    lr = 1e-3,
):
    def fit_objective(pars):
        return -model.logpdf(pars, data)

    fit_res = _minimize(fit_objective, init_pars, lr)
    return fit_res

def pipeline(x):
    model = Model(x)
    mle_pars = fit(
        model=model,
        data=jnp.array([5.0, 5.0]),
        init_pars=jnp.array([1.0, 1.1]),
        lr=1e-3,
    )
    return mle_pars

jax.jacrev(pipeline)(jnp.asarray(0.5))
# >> JaxStackTraceBeforeTransformation

(this is jaxopt==0.6)

Another thing to note: the jaxpr tracing induced by closure_convert seems to really fill up the cache, which made this quite problematic in practice (I had to use @patrick-kidger’s hack from this JAX issue). Just a health warning for anyone else interested in this type of solution!

I can’t immediately see how, but if we could cast this example into the form you referenced above, with the decomposed derivatives, that would be the best way around this issue (i.e. avoiding closure_convert altogether).
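For anyone following along, the closure_convert pattern used in `_minimize` in the MWE above can be sketched in pure JAX, modeled on the example in the `jax.closure_convert` docstring. Assumptions: fixed-step gradient descent (hypothetical `LR`, `STEPS`) replaces the Optax solver, and the backward rule is a hand-written implicit-function-theorem VJP, not jaxopt's. Under `jax.jacrev`, `closure_convert` should hoist the closed-over `theta` into `aux_args`, so the `custom_vjp` rule can return a cotangent for it. The sketch does not reproduce or diagnose the tracer leak reported above.

```python
from functools import partial

import jax
import jax.numpy as jnp

LR, STEPS = 0.1, 200  # assumed fixed-step gradient descent settings


def minimize(objective_fn, x0):
    # Hoist closed-over values that take part in differentiation into explicit args.
    converted_fn, aux_args = jax.closure_convert(objective_fn, x0)
    return _minimize(converted_fn, x0, *aux_args)


def _gradient_descent(objective_fn, x0, *args):
    grad_x = jax.grad(objective_fn, argnums=0)

    def body(_, x):
        return x - LR * grad_x(x, *args)

    return jax.lax.fori_loop(0, STEPS, body, x0)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
    return _gradient_descent(objective_fn, x0, *args)


def _minimize_fwd(objective_fn, x0, *args):
    x_opt = _gradient_descent(objective_fn, x0, *args)
    return x_opt, (x_opt, args)


def _minimize_bwd(objective_fn, residuals, x_bar):
    x_opt, args = residuals
    # Implicit function theorem on grad_x f(x_opt, *args) = 0, with H symmetric:
    # args_bar = -(d grad_x f / d args)^T H^{-1} x_bar.
    hess = jax.hessian(objective_fn, argnums=0)(x_opt, *args)
    u = jnp.linalg.solve(hess, x_bar)
    _, vjp_fn = jax.vjp(lambda *a: jax.grad(objective_fn)(x_opt, *a), *args)
    # x0 does not affect the optimum, so its cotangent is zero.
    return (jnp.zeros_like(x_opt), *vjp_fn(-u))


_minimize.defvjp(_minimize_fwd, _minimize_bwd)


def pipeline(theta):
    data = jnp.array([6.0, 4.0])

    def objective(x):  # closes over theta and data
        return 0.5 * jnp.sum((theta * x - data) ** 2)

    return minimize(objective, jnp.array([1.0, 1.0]))


print(jax.jacrev(pipeline)(2.0))  # analytically -data / theta**2 = [-1.5, -1.0]
```

The backward rule is generic over however many values `closure_convert` hoists, which is why it uses `jax.vjp` over `*args` rather than naming `theta`.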

I’m not sure what your setup_objective function is doing, but I would try to decompose it like this:

    def pipeline(param_for_grad, latent):
        res = intermediary_step(param_for_grad, **kwargs)

        def objective_fun(params, intermediary_result):
            [...]   # do not use param_for_grad or res here!

        solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
        return solver.run(init_params, intermediary_result=res * latent).params

    jax.jacobian(pipeline)(param_for_grad, latent)

The key idea is to use function composition so that the chain rule applies. You may have to tweak it for your problem, but you get the idea.
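A concrete, runnable instance of this composition pattern can be sketched in pure JAX under clearly stated assumptions: the bodies of `intermediary_step` and `objective_fun` are hypothetical placeholders, and `run_solver` is a hand-written gradient-descent solver with a `jax.custom_vjp` implicit-differentiation rule standing in for `OptaxSolver.run` (it is not jaxopt's implementation). Because `objective_fun` only touches its explicit arguments, the custom rule returns a cotangent for the hyperparameter argument, and ordinary autodiff carries it back through `res * latent` to `param_for_grad`.

```python
from functools import partial

import jax
import jax.numpy as jnp

LR, STEPS = 0.1, 200  # assumed fixed-step gradient descent standing in for Adam


def intermediary_step(param_for_grad):
    # Hypothetical differentiable setup step (placeholder for setup_objective).
    return param_for_grad ** 2


def objective_fun(params, intermediary_result):
    # Uses only its explicit arguments: no closed-over param_for_grad or res.
    return 0.5 * jnp.sum((intermediary_result * params - 1.0) ** 2)


def _descend(fun, init_params, hyper):
    grad_p = jax.grad(fun, argnums=0)

    def body(_, p):
        return p - LR * grad_p(p, hyper)

    return jax.lax.fori_loop(0, STEPS, body, init_params)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def run_solver(fun, init_params, hyper):
    return _descend(fun, init_params, hyper)


def _run_fwd(fun, init_params, hyper):
    p_opt = _descend(fun, init_params, hyper)
    return p_opt, (p_opt, hyper)


def _run_bwd(fun, residuals, p_bar):
    p_opt, hyper = residuals
    # Implicit differentiation of the optimality condition grad_p fun(p_opt, hyper) = 0.
    hess = jax.hessian(fun, argnums=0)(p_opt, hyper)
    u = jnp.linalg.solve(hess, p_bar)
    _, vjp_fn = jax.vjp(lambda h: jax.grad(fun)(p_opt, h), hyper)
    (hyper_bar,) = vjp_fn(-u)
    return jnp.zeros_like(p_opt), hyper_bar


run_solver.defvjp(_run_fwd, _run_bwd)


def pipeline(param_for_grad, latent):
    res = intermediary_step(param_for_grad)
    # The solver only sees res * latent; ordinary autodiff chains back to param_for_grad.
    return run_solver(objective_fun, jnp.ones(2), res * latent)


# Analytically params* = 1 / (p**2 * latent), so d params*/dp = -2 / (p**3 * latent).
print(jax.jacobian(pipeline)(1.0, 2.0))
```

Only the solver boundary needs a custom rule; everything before it (`intermediary_step`, the product with `latent`) is differentiated by JAX as usual.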

We decided to use explicit variables because, with closed-over variables, there is no way to tell which ones need to be differentiated and which don’t. This is problematic if you have several big variables in scope, such as data matrices.

We are working on documentation; hopefully these things will become clearer soon.