optax: optax.MultiSteps out of memory
I always get an out-of-memory error using `optax.MultiSteps`, even when `every_k_schedule=1`. Using `optax.apply_every(k=1)` in a chain works fine.
import optax

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
    # optax.apply_every(k=1)  # this variant works without OOM
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
Later I’m using
opt_state = optimizer.init(params)
and
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
I have no idea what I could be doing wrong. I’m not changing anything else, like batch size.
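For reference, here is a self-contained version of the snippets above; the toy parameter shapes, `lr` value, and loss function are placeholders to make it runnable, not my actual model:

```python
import jax
import jax.numpy as jnp
import optax

# Placeholder learning rate and toy params, just to make this self-contained.
lr = 1e-3
params = {"w": jnp.ones((1024, 1024)), "b": jnp.zeros((1024,))}

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
opt_state = optimizer.init(params)

def loss_fn(p):
    # Dummy quadratic loss standing in for the real model loss.
    return jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)

@jax.jit
def train_step(params, opt_state):
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

params, opt_state = train_step(params, opt_state)
```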
About this issue
- State: open
- Created a year ago
- Comments: 15 (5 by maintainers)
Commits related to this issue
- Optimise memory usage in `MultiSteps`. Change the implementation to allow JAX/XLA to re-use memory buffers. #472 PiperOrigin-RevId: 561129449 — committed to google-deepmind/optax by hbq1 10 months ago
- Optimise memory usage in `MultiSteps`. Change the implementation to allow JAX/XLA to re-use memory buffers. #472 PiperOrigin-RevId: 561390202 — committed to google-deepmind/optax by hbq1 10 months ago
Just added a PR merging the `apply_every` logic into `MultiSteps`. From my initial tests, it reduces the memory footprint (I'm able to train Llama 2 7B on a v3-8 now) without affecting convergence.
As a follow-up, I did some debugging myself, and the problem seems to be in this part of the code (line 414):
If I understand it correctly, JAX allocates memory for both branch outputs (`_mid_step` and `_final_step`), which roughly doubles the space needed to store the optimizer state and gradients.
Still trying to figure out a way to solve it, though.
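Roughly, the pattern looks like this (a simplified sketch of the suspected issue, not the actual optax source; the branch bodies here are hypothetical stand-ins):

```python
import jax
import jax.numpy as jnp

# Both branches return pytrees of the same full size, so XLA may
# allocate output buffers for each branch of the cond instead of
# reusing a single one.
def _mid_step(acc_grads):
    # Keep accumulating; emit zero updates for this step.
    return jnp.zeros_like(acc_grads), acc_grads

def _final_step(acc_grads):
    # Emit the accumulated gradients; reset the accumulator.
    return acc_grads, jnp.zeros_like(acc_grads)

def step(should_emit, acc_grads):
    return jax.lax.cond(should_emit, _final_step, _mid_step, acc_grads)

updates, acc = jax.jit(step)(True, jnp.ones((8, 8)))
```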
I can confirm that the `MultiSteps` implementation has a much larger memory overhead than just one extra buffer for the gradients (something like 4x extra buffers). This is very problematic when using this class with large models.
Awesome work @celiolarcher!
`jax.lax.cond` seems to be suboptimal in some use cases, e.g. here: in theory, it should understand that either `_mid_step` or `_final_step` needs to be executed, so it shouldn't allocate memory for both outputs. It might be something that the JAX/XLA devs would like to have a look at. Let me know if you'd like me to file a bug at https://github.com/google/jax/issues, or feel free to do it yourself, of course!

Hi everyone, thanks for flagging this up. I just merged a new version of `optax.MultiSteps` which should be more memory friendly. Could you check this, please?

Hi! Interesting - thanks for reporting this!
Are you also at more than ~2/3 memory usage when you use `apply_every`? From a first look, I could see that the implementation of `apply_every` returns `0 * updates` for skipped steps, while `MultiSteps` constructs a new array of zeros (even if `every_k_schedule=1`), so the former has a better memory footprint. This would explain higher memory usage by up to 50% - but not more. I'm not sure why the two functions use completely different code paths; we should be able to merge them (and deprecate one of them).
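To illustrate the difference (a minimal sketch; whether XLA actually reuses or fuses buffers here depends on jit fusion and buffer donation, so this is indicative rather than definitive):

```python
import jax
import jax.numpy as jnp

updates = {"w": jnp.ones((4096, 4096))}

# apply_every-style: scale the existing updates by zero, which XLA can
# often fuse with surrounding ops rather than allocating separately.
zeroed = jax.tree_util.tree_map(lambda u: 0 * u, updates)

# MultiSteps-style (before the fix): materialise fresh arrays of zeros
# with the same shapes/dtypes, which requires new output buffers.
zeroed_new = jax.tree_util.tree_map(jnp.zeros_like, updates)
```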