transformers: Two bugs in AdamW
Environment info
- transformers version: 4.13.0.dev0
- Platform: Linux-3.10.0-1160.45.1.el7.x86_64-x86_64-with-glibc2.17
- Python version: 3.9.7
- PyTorch version (GPU?): 1.10.0+cu113 (True)
- Tensorflow version (GPU?): 2.7.0 (False)
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
@thomwolf and @stas00 should be able to help based on git blame
Information
There are two bugs in the implementation of AdamW.
Here’s the current code https://github.com/manuelciosici/transformers/blob/04683c0659aacf31a1e1df8aa2e6cf7b447a6f12/src/transformers/optimization.py#L324-L371
Weight decay bug
Look at lines 369-370. The weight decay is multiplied with p.data, which no longer corresponds to theta_{t-1}, since p.data was already modified by the Adam update a few lines earlier. Below is Algorithm 2 from the original AdamW paper, whose line 12 shows that the weight decay should be multiplied with the previous step’s parameters (i.e., theta_{t-1}).
[Figure: Algorithm 2 (AdamW with decoupled weight decay) from Loshchilov & Hutter, 2019]
From what I can tell, this is a regression since the original AdamW implementation in transformers applied weight decay properly. Here’s the commit that introduces the bug https://github.com/HuggingFace/transformers/commit/ec07cf5a660926833d6f5208b58730e4af8d1178#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0
For confirmation that weight decay is currently buggy, see the original AdamW implementation, where, on line 74, the weight decay is multiplied with the old parameters as opposed to the new parameters that are calculated on line 71.
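To make the ordering issue concrete, here is a minimal, purely illustrative sketch (the tensors and the lr and weight_decay values are placeholders, and adam_step stands in for the already-computed exp_avg / denom term):

```python
import torch

p = torch.tensor([1.0])            # theta_{t-1}
adam_step = torch.tensor([0.1])    # stands in for exp_avg / denom
lr, weight_decay = 1e-3, 0.1

# Buggy ordering (current code): the decay term uses the already-updated parameters.
p_buggy = p.clone()
p_buggy.add_(adam_step, alpha=-lr)               # theta_{t-1} - lr * update
p_buggy.add_(p_buggy, alpha=-lr * weight_decay)  # decays the *new* value, not theta_{t-1}

# Correct ordering (Algorithm 2, line 12): the decay term uses theta_{t-1}.
p_fixed = p.clone()
p_fixed.add_(p, alpha=-lr * weight_decay)        # decay the *old* parameters
p_fixed.add_(adam_step, alpha=-lr)

print((p_buggy - p_fixed).item())  # discrepancy of order lr**2 * weight_decay * adam_step
```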
Denominator computation bug
The second bug appears in the computation of the denominator corresponding to line 10 in Algorithm 2 above. In the current code (see link in the Information section), on line 351, the denominator excludes the division by math.sqrt(bias_correction2). On line 357, the division by math.sqrt(bias_correction2) is applied via the step size, but, by this time, eps has already been added to denom, making the update not equivalent to line 10 in Algorithm 2.
From what I can tell, this bug was also introduced as part of commit https://github.com/HuggingFace/transformers/commit/ec07cf5a660926833d6f5208b58730e4af8d1178#diff-40c6163602943c11431f1ec360299a7646bb436c691a646b9f54b2284f556ce0. The previous line update = next_m / (next_v.sqrt() + group['e']) was correct.
For confirmation that the denominator is not properly calculated, see the original AdamW implementation, where the denominator is computed on line 64.
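For illustration, here is a hedged sketch of the two orderings (the exp_avg_sq, beta2, step, and eps values are placeholders chosen only to show the difference):

```python
import math
import torch

exp_avg_sq = torch.tensor([4e-6])  # stands in for v_t
beta2, step, eps = 0.999, 10, 1e-6
bias_correction2 = 1.0 - beta2 ** step

# Current code: eps is added first and sqrt(bias_correction2) is folded into the
# step size afterwards, so the update effectively divides by
#   (sqrt(v_t) + eps) / sqrt(bias_correction2)
denom_current = exp_avg_sq.sqrt().add_(eps) / math.sqrt(bias_correction2)

# Line 10 of Algorithm 2: bias-correct first, then add eps, i.e.
#   sqrt(v_t) / sqrt(bias_correction2) + eps  ==  sqrt(v_hat_t) + eps
denom_paper = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

# The two differ by eps * (1 / sqrt(bias_correction2) - 1): eps is effectively
# inflated, most noticeably early in training when bias_correction2 is small.
print(denom_current.item(), denom_paper.item())
```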
To reproduce
Steps to reproduce the behavior:
- Check out the branch at https://github.com/manuelciosici/transformers/tree/reveal_broken_adamw
- Run the unit tests in tests/test_optimization.py
- The tests test_compare_adamw_no_weight_decay and test_compare_adamw_with_weight_decay should fail (see the attached failed_tests.txt); a self-contained sketch of such a comparison follows below
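For readers who do not want to switch branches, a comparison in the same spirit could look roughly like the sketch below (the setup, hyper-parameters, and final check are illustrative, not the actual tests from the branch):

```python
import torch
from torch.optim import AdamW as TorchAdamW
from transformers.optimization import AdamW as HFAdamW

def final_weights(optimizer_cls):
    # Train a single 10-dimensional parameter towards a fixed target.
    torch.manual_seed(0)
    w = torch.nn.Parameter(torch.randn(10))
    target = torch.randn(10)
    opt = optimizer_cls([w], lr=1e-3, eps=1e-6, weight_decay=0.1)
    for _ in range(100):
        opt.zero_grad()
        ((w - target) ** 2).mean().backward()
        opt.step()
    return w.detach()

diff = (final_weights(TorchAdamW) - final_weights(HFAdamW)).abs().max()
print(diff)  # with the bugs present, the two trajectories drift apart measurably
```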
Expected behavior
The two implementations of AdamW should produce matching parameter updates.
Proposed fix
Check out the branch at https://github.com/manuelciosici/transformers/tree/fix_adamw . It contains both the unit tests above and a fix for the two bugs described above.
I can make a PR once we agree on the two bugs and the fix.
About this issue
- State: closed
- Created 3 years ago
- Reactions: 1
- Comments: 16 (11 by maintainers)
The NVIDIA engineers have been profiling a few things, and torch’s AdamW is faster than ours (apparently apex’s is even faster), so I will add this to the performance docs once I’m able to run benchmarks when your PR is ready, @manuelciosici
https://github.com/huggingface/transformers/pull/14708
@stas00 Thank you. I will work on this during the weekend.
The key thing to understand is that it’s not implementing AdamW, but a slightly different algorithm.
Users expect an exact implementation of the algorithm out of the box, and if it’s not exact, it should be named differently.
Perhaps AdamWHF?

This implementation of AdamW, although slower, seems to give me better performance than the PyTorch one in terms of accuracy and F1. I’m not sure if I’m the only one with this result, but if this is the case for multiple people, deprecating it would be a shame.
Thanks for the summary @stas00, this looks great to me!
@stas00 I was reading it as a reference implementation while trying to understand deepspeed’s CPU AdamW implementation.

One thing to note is that the magnitude of both bugs is a function of AdamW’s hyper-parameters (i.e., it is influenced by the learning rate, epsilon, and weight decay). For example, for prompt tuning, where learning rates can be as high as 0.3, the effect of the buggy weight decay will be more pronounced.
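As a rough back-of-the-envelope illustration of that scaling (the numbers below are placeholders, not from any experiment):

```python
# Per step, the gap between decaying theta_t (current code) and theta_{t-1}
# (Algorithm 2) is roughly lr**2 * weight_decay * |adam_step|, so it grows
# quadratically with the learning rate.
weight_decay, adam_step = 0.01, 1.0
for lr in (1e-4, 1e-2, 0.3):
    print(lr, lr ** 2 * weight_decay * adam_step)
# 0.0001 -> 1e-10, 0.01 -> 1e-06, 0.3 -> 0.0009
```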
@sgugger I understand the concerns that fixing the optimizer will lead to misalignment with existing examples and documentation. However, ignoring the bugs is not good either. Since opening the issue, I have found that I was not the first one to discover the weight decay issue. I expect that, if the code stays as is, the two bugs will be rediscovered periodically.
An alternative to ignoring the bugs would be for transformers to deprecate its AdamW implementation with a removal target of, say, transformers>=5.0.0 (or 6.0.0 if a longer sunset is necessary) and add a comment in the AdamW implementation explaining the two bugs. This way, current examples & documentation can continue to work as expected, while users migrate to torch’s AdamW. How does this sound?
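A minimal sketch of what such a deprecation notice could look like at the top of AdamW.__init__ in src/transformers/optimization.py (the wording and the removal target are illustrative, not decided):

```python
import warnings

# Illustrative deprecation notice; the message and removal target are placeholders.
warnings.warn(
    "transformers' AdamW deviates from the AdamW paper: weight decay is applied to "
    "the updated parameters instead of theta_{t-1}, and epsilon is added to the "
    "denominator before the bias correction. It is deprecated in favour of "
    "torch.optim.AdamW and will be removed in a future major release.",
    FutureWarning,
)
```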