jaxopt: Problem differentiating through `solver.run` in `OptaxSolver`
I've been trying to use `OptaxSolver` to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I'm not familiar with.
Here's an MWE that reproduces the error message:
```python
import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad * latent, scale=1)
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)
    initial, _ = solver.init(init_params=5.)
    result, _ = solver.run(init_params=initial)
    return result

jax.value_and_grad(pipeline)(2., data=6.)
```
which yields this error:
```
CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.
```
My versions are:
```
jax==0.2.20
jaxlib==0.1.71
jaxopt==0.0.1
optax==0.0.9
```
Am I doing something very silly? I guess I'm also wondering whether this example is within the scope of the solver API. I noticed that this doesn't occur with `solver.update`, just with `solver.run`.
Thanks 😃
Thanks @phinate for the question and @fllinares for the answer!

Indeed, as Felipe explained, your `param_for_grad` was in the scope (this is what is meant by "closed-over value") but it wasn't an explicit argument of `run`.

By the way, since `run` calls `init` for you, the line calling `solver.init` is not needed. You can just set `initial = 5.` and then call `run(initial, param_for_grad)` (see the corrected MWE sketched below).
Hi,
I think only a couple of small changes would be needed.
To use implicit differentiation with `solver.run`, you should (1) expose the `args` with respect to which you'd like to differentiate the solver's solution explicitly in the signature of `fun`, and (2) avoid using keyword arguments in the call to `solver.run`.
In your MWE:
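A minimal sketch of the adjusted MWE, assuming a larger learning rate such as `optax.adam(1e-1)`:

```python
import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    # (1) param_for_grad and data are explicit arguments of the objective,
    # not closed-over values
    def to_minimize(latent, param_for_grad, data):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad * latent, scale=1)

    # a larger learning rate (an assumed value) so Adam converges within the
    # default maximum number of steps
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(1e-1), implicit_diff=True)
    # (2) positional arguments; run calls init for you, so no separate solver.init
    result, _ = solver.run(5., param_for_grad, data)
    return result

jax.value_and_grad(pipeline)(2., data=6.)
```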
P.S. I also made a small change to the learning rate so that Adam converges in this example with the default maximum number of steps.
hi again @mblondel, sorry to resurrect this from the dead – my solution that uses `closure_convert` randomly started leaking tracers with one of the `jax`/`jaxlib` updates, and it's a bit of a nightmare to debug. Luckily, I found a fairly simple MWE (this is `jaxopt==0.6`).

Another thing to note: the `jaxpr` tracing induced by `closure_convert` seems to really fill up the cache, which made this quite problematic in practice (I had to use @patrick-kidger's hack from this JAX issue). Just a health warning for anyone else interested in this type of solution!

I can't see an immediate way, but if we could cast this example into the form you referenced above with the decomposed derivatives, that would be the best way to get around this issue (i.e. avoid `closure_convert` altogether).
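For reference, the kind of `closure_convert`-based workaround under discussion looks roughly like this (a sketch based on the original MWE's objective, not the exact code from the thread; this is the pattern that later ran into the tracer leaks and cache growth mentioned above):

```python
import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        # param_for_grad and data are closed over here on purpose
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad * latent, scale=1)

    # hoist the closed-over traced values into explicit trailing arguments
    converted_fun, aux_args = jax.closure_convert(to_minimize, 5.)
    solver = OptaxSolver(fun=converted_fun, opt=optax.adam(1e-1), implicit_diff=True)
    # the hoisted values are now explicit arguments of run, so implicit
    # differentiation can see them
    result, _ = solver.run(5., *aux_args)
    return result

jax.value_and_grad(pipeline)(2., data=6.)
```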
I'm not sure what your `setup_objective` function is doing, but I would try to decompose it, along the lines sketched below. The key idea is to use function composition so that the chain rule will apply. You may have to tweak it for your problem, but you get the idea.
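A rough sketch of that composition, with hypothetical helper names (`preprocess`, `solve_latent`) standing in for whatever `setup_objective` actually does:

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def solve_latent(theta, data):
    # every quantity we want gradients for is an explicit argument of the objective
    def to_minimize(latent, theta, data):
        return -jsp.stats.norm.logpdf(data, loc=theta * latent, scale=1)

    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(1e-1), implicit_diff=True)
    return solver.run(5., theta, data).params

def preprocess(param_for_grad):
    # hypothetical differentiable setup step
    return jnp.exp(param_for_grad)

def pipeline(param_for_grad, data):
    # plain function composition: JAX's chain rule handles preprocess and any
    # post-processing, while implicit differentiation handles the inner solve,
    # without closing over traced values
    theta = preprocess(param_for_grad)
    latent = solve_latent(theta, data)
    return latent ** 2

jax.value_and_grad(pipeline)(2., data=6.)
```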
We decided to use explicit variables because, with closed-over variables, there is no way to tell which need to be differentiated and which don't. This is problematic if you have several big variables in your scope, such as data matrices.
We are working on documentation; hopefully these things will become clearer soon.