jaxopt: Possible memory leak when calling solver.run multiple times

I am solving a problem in which solver.run is called repeatedly to minimize a series of functions as a parameter varies. Using memory_profiler, I can see that the allocated memory increases on every call to solver.run and never decreases.

Here is a minimal example to reproduce the issue:

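import jax.numpy as jnp
import jaxopt
from memory_profiler import profile

@profile  # report line-by-line memory usage for every call
def optimize(min):
    # A new objective function is defined on every call to optimize.
    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    # A new solver is also built on every call.
    solver = jaxopt.LBFGS(obj, maxiter=100)
    # Extra keyword arguments to run are forwarded to obj.
    x = solver.run(x0, min=mm).params[0]
    print(x)

# Minimize once per parameter value; memory grows with each iteration.
for i in range(10):
    optimize(i)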

And here is the corresponding plot of the allocated memory: [image: Figure_1]

Can you please confirm the issue or provide a solution for that? Thanks. Alessandro

About this issue

  • State: open
  • Created a year ago
  • Reactions: 1
  • Comments: 17 (1 by maintainers)

Most upvoted comments

It’s nice to have a workaround but shouldn’t garbage collection be able to do this automatically?
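
One possible reason garbage collection does not help, offered as an assumption rather than a confirmed diagnosis of this issue: JAX keeps traced and compiled functions in its own caches, and a function object created inside optimize is new on every call, so it can never hit an existing cache entry. The sketch below uses plain jax.jit (not jaxopt) to count how often an objective is traced. The helper name count_traces and its arguments are hypothetical, introduced only for illustration.

```python
import jax
import jax.numpy as jnp


def count_traces(reuse_function: bool, n_calls: int = 3) -> int:
    """Return how many times the objective is traced over n_calls calls."""
    traces = []

    def make_objective():
        # Each call returns a brand-new function object.
        def obj(x, target):
            # This Python side effect runs only while JAX traces obj,
            # not when already-compiled code is executed.
            traces.append(1)
            return jnp.square(x - target).sum()
        return obj

    x0 = jnp.zeros(1)
    shared = jax.jit(make_objective())
    for i in range(n_calls):
        # Either reuse one jitted function or wrap a fresh closure each time.
        fn = shared if reuse_function else jax.jit(make_objective())
        fn(x0, jnp.array(i, dtype=jnp.float32)).block_until_ready()
    return len(traces)


fresh = count_traces(reuse_function=False)
reused = count_traces(reuse_function=True)
print("traces with a fresh closure per call:", fresh)
print("traces with one shared function:", reused)
```

If the fresh-closure count grows with the number of calls while the shared count stays at one, that is consistent with new cache entries being created on every iteration. It does not by itself prove that those entries are what memory_profiler is measuring.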

This behavior can be avoided with the newly implemented jax.clear_caches() (thanks @froystig!).

For example, the code below does not show the ever-increasing memory profile. Instead, it shows the expected pattern of an initial increase followed by a plateau:

[image: Figure_1]

import jax.numpy as jnp
import jaxopt
import jax


def optimize(min):
    # Same objective and solver setup as in the original report.
    def obj(x, min):
        return jnp.square(x-min).sum()

    x0 = jnp.zeros(1)
    mm = jnp.array(min)

    solver = jaxopt.LBFGS(obj, maxiter=100)
    x = solver.run(x0, min=mm).params[0]
    print(x)

for i in range(10):
    optimize(i)
    # Drop JAX's compilation caches after each solve so they do not accumulate.
    jax.clear_caches()

Hi, has there been any progress on this issue?