transformers: Two bugs in AdamW

Environment info

  • transformers version: 4.13.0.dev0
  • Platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.17
  • Python version: 3.9.7
  • PyTorch version (GPU?): 1.10.0+cu113 (True)
  • Tensorflow version (GPU?): 2.7.0 (False)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@thomwolf and @stas00 should be able to help based on git blame

Information

There are two bugs in the implementation of AdamW.

Here’s the current code https://github.com/manuelciosici/transformers/blob/04683c0659aacf31a1e1df8aa2e6cf7b447a6f12/src/transformers/optimization.py#L324-L371

Weight decay bug

Look at lines 369-370. The weight decay is multiplied by p.data, which no longer corresponds to theta_{t-1} because p.data was already modified on line 369. Below is a picture of Algorithm 2 from the original AdamW paper, which shows on line 12 that the weight decay should be multiplied by the previous step’s parameters (i.e., theta_{t-1}).

[Image: Algorithm 2 from the original AdamW paper]

From what I can tell, this is a regression, since the original AdamW implementation in transformers applied weight decay properly. Here’s the commit that introduced the bug: https://github.com/HuggingFace/transformers/commit/ec07cf5a660926833d6f5208b58730e4af8d1178#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0

For confirmation that the weight decay is currently buggy, see the original AdamW implementation, where, on line 74, the weight decay is multiplied by the old parameters rather than the new parameters calculated on line 71.
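To make the ordering problem concrete, here is a minimal scalar sketch, not the library code. It assumes a single float parameter, an Adam step (adam_step) that has already been computed, and a decay factor of lr * weight_decay as in the current code; the function names are hypothetical.

```python
import math


def step_decay_after_update(theta_prev, adam_step, lr, weight_decay):
    """Ordering in the current code: apply the Adam step first, then compute
    the decay from the already-updated value (theta_t)."""
    theta = theta_prev - adam_step
    return theta - lr * weight_decay * theta


def step_decay_from_previous(theta_prev, adam_step, lr, weight_decay):
    """Ordering in Algorithm 2, line 12: the decay is computed from theta_{t-1}."""
    return theta_prev - adam_step - lr * weight_decay * theta_prev
```

Expanding the first function gives (theta_prev - adam_step) * (1 - lr * weight_decay), so the two orderings differ by lr * weight_decay * adam_step on every step (up to floating-point rounding), and agree only when weight_decay is 0.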

Denominator computation bug

The second bug is in the computation of the denominator, corresponding to line 10 of Algorithm 2 above. In the current code (see the link in the Information section), the denominator computed on line 351 omits the division by math.sqrt(bias_correction2). The math.sqrt(bias_correction2) correction is applied later, on line 357, but by then eps has already been added to denom, so the result is not equivalent to line 10 of Algorithm 2.

From what I can tell, this bug was also introduced as part of commit https://github.com/HuggingFace/transformers/commit/ec07cf5a660926833d6f5208b58730e4af8d1178#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0. The previous line update = next_m / (next_v.sqrt() + group['e']) was correct.

For confirmation that the denominator is not properly calculated, see the original AdamW implementation, where the denominator is computed on line 64.
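A scalar re-derivation may help show why the position of eps matters. This is a sketch, not the library code: it assumes plain floats, takes m and v as the raw (uncorrected) moment estimates, and assumes the current code folds the bias correction into the step size as lr * math.sqrt(bias_correction2) / bias_correction1. The function names are hypothetical.

```python
import math


def update_paper(m, v, step, lr, beta1, beta2, eps):
    """Algorithm 2, line 10 style: bias-correct both moments, then add eps."""
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    return lr * m_hat / (math.sqrt(v_hat) + eps)


def update_current(m, v, step, lr, beta1, beta2, eps):
    """Current-code style: eps is added to sqrt(v) before bias correction,
    and the correction is folded into the step size afterwards."""
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = math.sqrt(v) + eps
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
    return step_size * m / denom
```

Rearranging the second function gives lr * m_hat / (sqrt(v_hat) + eps / math.sqrt(bias_correction2)), so the current code behaves like the paper's update with eps scaled up by 1 / math.sqrt(1 - beta2 ** step). Early in training, when bias_correction2 is small, that factor is large.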

To reproduce

Steps to reproduce the behavior:

  1. Check out the branch at https://github.com/manuelciosici/transformers/tree/reveal_broken_adamw
  2. Run the unit tests in tests/test_optimization.py
  3. Tests test_compare_adamw_no_weight_decay and test_compare_adamw_with_weight_decay should fail (see the attached failed_tests.txt)

Expected behavior

The two implementations of AdamW should produce matching parameter updates.
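A multi-step comparison in the spirit of test_compare_adamw_no_weight_decay and test_compare_adamw_with_weight_decay can also be written without PyTorch. The sketch below is not the code on the branches above: it is a hypothetical pure-Python version that tracks a single scalar parameter over a fixed list of gradients, assumes a constant learning rate (schedule multiplier of 1), and scales the decay by lr * weight_decay as the current code does.

```python
import math


def adamw_paper(theta, grads, lr, beta1, beta2, eps, weight_decay):
    """Scalar AdamW following Algorithm 2 (constant schedule multiplier)."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        # eps is added after bias correction; decay uses theta_{t-1}
        theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta


def adamw_current(theta, grads, lr, beta1, beta2, eps, weight_decay):
    """Scalar version of the ordering described in this issue."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        denom = math.sqrt(v) + eps  # eps added before bias correction
        step_size = lr * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
        theta = theta - step_size * m / denom
        theta = theta - lr * weight_decay * theta  # decay uses the updated theta
    return theta
```

With eps=0 and weight_decay=0 the two functions agree up to rounding; a nonzero weight_decay exposes the first bug, and a large eps exposes the second.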

Proposed fix

The branch at https://github.com/manuelciosici/transformers/tree/fix_adamw contains both the unit tests above and a fix for both bugs.

I can make a PR once we agree on the two bugs and the fix.

About this issue

  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 16 (11 by maintainers)

Most upvoted comments

The NVIDIA engineers have been profiling a few things, and torch’s AdamW is faster than ours (apparently apex’s is even faster). I will add this to the performance docs once I’m able to benchmark it, when your PR is ready, @manuelciosici.

https://github.com/huggingface/transformers/pull/14708

@stas00 Thank you. I will work on this over the weekend.

The key thing to understand is that it’s not implementing AdamW, but a slightly different algorithm.

Users expect an exact implementation of the algorithm out of the box, and if it’s not exact, it should be named differently.

Perhaps AdamWHF?

This implementation of AdamW, although slower, seems to give me better performance than the PyTorch one in terms of accuracy and F1. I’m not sure if I’m the only one seeing this, but if multiple people see the same result, deprecating it could be a shame.

Thanks for the summary @stas00, this looks great to me!

@stas00 I was reading it as a reference implementation while trying to understand deepspeed’s CPU AdamW implementation.

One thing to note is that the magnitude of both bugs is a function of AdamW’s hyper-parameters (i.e., it is influenced by the learning rate, epsilon, and weight decay). For example, for prompt tuning, where learning rates can be as high as 0.3, the effect of the buggy weight decay will be more pronounced.
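A back-of-the-envelope sketch of how each discrepancy scales with the hyper-parameters follows. These are hypothetical helper functions, not part of either implementation. The first assumes the Adam step is lr times a normalized direction m_hat / (sqrt(v_hat) + eps), whose magnitude is exactly 1 on the first step when eps is negligible and is treated as roughly 1 afterwards, which is an approximation.

```python
import math


def decay_discrepancy(lr, weight_decay, direction=1.0):
    """Per-step gap between decaying theta_t and decaying theta_{t-1}.

    Assumes the Adam step is lr * direction, so the gap
    lr * weight_decay * |adam_step| grows roughly with lr squared.
    """
    adam_step = lr * direction
    return lr * weight_decay * abs(adam_step)


def effective_eps(eps, beta2, step):
    """eps as it effectively appears once bias correction is applied after it."""
    return eps / math.sqrt(1 - beta2 ** step)
```

Under these assumptions, moving from an illustrative learning rate of 1e-4 to the 0.3 used in prompt tuning grows the per-step weight decay gap by a factor of (0.3 / 1e-4) ** 2 = 9,000,000, while the eps distortion is largest on the first steps and fades as 1 - beta2 ** step approaches 1.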

@sgugger I understand the concerns that fixing the optimizer will lead to misalignment with existing examples and documentation. However, ignoring the bugs is not good either. Since opening the issue, I found that I was not the first one to discover the weight decay issue. I expect that, if the code stays as is, the two bugs will be rediscovered periodically.

An alternative to ignoring the bugs would be for transformers to deprecate its AdamW implementation with a removal target of, say, transformers>=5.0.0 (or 6.0.0 if a longer sunset is necessary) and add a comment in the AdamW implementation explaining the two bugs. This way, current examples and documentation can continue to work as expected while users migrate to torch’s AdamW. How does this sound?
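The proposed deprecation could look roughly like the sketch below. It is hypothetical: the class name, the placeholder default values, and the message wording are illustrative only, and a real version would keep the existing optimizer logic rather than just storing its arguments.

```python
import warnings


class DeprecatedAdamW:
    """Hypothetical stand-in for an optimizer class slated for removal.

    Only the deprecation path is shown; the update logic is omitted.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
        # Warn on construction so every caller sees it, but keep working as before.
        warnings.warn(
            "This AdamW implementation is deprecated and will be removed in a future "
            "major release. It applies weight decay to the already-updated parameters "
            "and adds eps before bias correction; use torch.optim.AdamW instead.",
            FutureWarning,
            stacklevel=2,
        )
        self.params = list(params)
        self.defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay}
```

FutureWarning is used here rather than DeprecationWarning because Python hides DeprecationWarning by default unless the warning is attributed to code in __main__, so users who construct the optimizer indirectly (for example, through library code) might never see it.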