optax: optax.MultiSteps out of memory

I always get an out-of-memory error when using optax.MultiSteps, even with every_k_schedule=1. Using optax.apply_every(k=1) in a chain works fine.

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
    #optax.apply_every(k=1)
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)

Later I’m using opt_state = optimizer.init(params) and

updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

I have no idea what I could be doing wrong. I’m not changing anything else, like batch size.
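
For completeness, here is a minimal sketch of the full setup as I'm using it. loss_fn, params, and batch are placeholders, lr = 1e-3 is just an illustrative value, and I'm assuming a jitted train step:

import jax
import optax

lr = 1e-3
optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state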

About this issue

  • State: open
  • Created a year ago
  • Comments: 15 (5 by maintainers)

Most upvoted comments

Just added a PR merging the apply_every logic into MultiSteps. From my initial tests, it reduces the memory footprint (I'm able to train Llama 2 7B on a v3-8 now) without affecting convergence.

As a follow-up, I did some debugging myself and it seems the problem is in this part of the code (line 414):

new_updates, new_state = jax.lax.cond(
    state.mini_step < k_steps - 1,
    _mid_step, _final_step, *(state, params, acc_grads))

If I got it right, JAX allocates memory for the outputs of both branches (_mid_step and _final_step), so this basically doubles the space needed to store the optimizer state and gradients.
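
To illustrate the pattern I mean, here is a simplified, self-contained sketch (not the actual optax code; k and the shapes are arbitrary):

import jax
import jax.numpy as jnp

k = 4  # hypothetical number of accumulation steps
grads = {"w": jnp.ones((1024, 1024))}  # stand-in gradient pytree
acc_grads = jax.tree_util.tree_map(jnp.zeros_like, grads)
mini_step = jnp.asarray(0)

def _mid_step(acc, g):
    # Accumulate gradients and emit zero updates for this mini-step.
    new_acc = jax.tree_util.tree_map(lambda a, b: a + b, acc, g)
    zeros = jax.tree_util.tree_map(jnp.zeros_like, g)
    return zeros, new_acc

def _final_step(acc, g):
    # Emit the averaged gradients and reset the accumulator.
    updates = jax.tree_util.tree_map(lambda a, b: (a + b) / k, acc, g)
    new_acc = jax.tree_util.tree_map(jnp.zeros_like, g)
    return updates, new_acc

# Both branches return full-size pytrees, so the suspicion is that XLA
# reserves output buffers for each branch, roughly doubling peak memory.
updates, acc_grads = jax.lax.cond(
    mini_step < k - 1, _mid_step, _final_step, acc_grads, grads)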

Still trying to figure out a way to solve it, though.

I can confirm that the MultiSteps implementation has a much larger memory overhead than just one extra gradient buffer (something like 4x extra buffers). This is very problematic when using this class with large models.

Awesome work @celiolarcher!

jax.lax.cond seems to be suboptimal in some use cases. Here, in theory, it should understand that only one of _mid_step and _final_step will be executed, so it shouldn't allocate memory for both outputs. It might be something the JAX/XLA devs would like to have a look at. Let me know if you'd like me to file a bug at https://github.com/google/jax/issues, or feel free to do it yourself, of course!
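
For what it's worth, one branch-free way to express the same logic, in the spirit of apply_every's 0 * updates masking, would be something like the sketch below. The function name is hypothetical, and whether XLA actually reuses buffers here is not guaranteed, so treat it as an idea rather than a fix:

import jax
import jax.numpy as jnp

def accumulate_and_maybe_emit(acc_grads, grads, mini_step, k):
    # Always accumulate; no lax.cond, so only one output pytree is built.
    new_acc = jax.tree_util.tree_map(lambda a, g: a + g, acc_grads, grads)
    emit = mini_step == k - 1  # scalar boolean
    # Emit averaged updates on the final mini-step, zeros otherwise.
    updates = jax.tree_util.tree_map(
        lambda a: jnp.where(emit, a / k, jnp.zeros_like(a)), new_acc)
    # Reset the accumulator once updates have been emitted.
    new_acc = jax.tree_util.tree_map(
        lambda a: jnp.where(emit, jnp.zeros_like(a), a), new_acc)
    return updates, new_acc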

Hi everyone, thanks for flagging this. I just merged a new version of optax.MultiSteps which should be more memory-friendly; could you check it, please?

Hi! Interesting - thanks for reporting this!

Are you also at more than ~2/3 memory usage when you use apply_every? From a first look, I can see that apply_every returns 0 * updates for skipped steps, while MultiSteps constructs a new array of zeros (even if every_k_schedule=1), so the former has a smaller memory footprint. This would explain up to 50% higher memory usage, but not more.
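
Concretely, the two skipped-step patterns look roughly like this (a paraphrase for illustration, not the exact library code):

import jax
import jax.numpy as jnp

updates = {"w": jnp.ones((8, 8))}  # stand-in updates pytree

# apply_every-style skipped step: scale the incoming updates by zero.
skipped_a = jax.tree_util.tree_map(lambda u: 0 * u, updates)

# MultiSteps-style skipped step: build a fresh pytree of zeros.
skipped_b = jax.tree_util.tree_map(jnp.zeros_like, updates)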

I'm not sure why the two functions use completely different code paths; we should be able to merge them (and deprecate one of them).