jaxopt: Possible memory leak when calling solver.run multiple times
I am trying to solve a problem where solver.run is called multiple times to minimize a series of functions while varying a parameter. Using memory_profiler, I can see that the allocated memory increases each time solver.run is called and never decreases.
Here is a minimal example to reproduce the issue:
```python
import jax.numpy as jnp
import jaxopt
from memory_profiler import profile


@profile
def optimize(min):
    # A new objective closure and a new solver are created on every call.
    def obj(x, min):
        return jnp.square(x - min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)
    solver = jaxopt.LBFGS(obj, maxiter=100)
    # Extra keyword arguments to run() are forwarded to obj.
    x = solver.run(x0, min=mm).params[0]
    print(x)


for i in range(10):
    optimize(i)
```
And here is the corresponding plot of the allocated memory:
Can you please confirm the issue or suggest a solution? Thanks,
Alessandro
About this issue
- State: open
- Created a year ago
- Reactions: 1
- Comments: 17 (1 by maintainers)
It’s nice to have a workaround, but shouldn’t garbage collection take care of this automatically?
This behavior can be avoided by using the newly implemented jax.clear_caches() in JAX (thanks @froystig!). For example, the code below doesn’t show the ever-increasing memory profile. Instead, it shows the more expected initial increase followed by a plateau:
Hi, has there been any progress on this issue?