accelerate: A BUG? when performing gradient accumulation
System Info
- `Accelerate` version: 0.16.0
- Platform: Linux-4.14.105-1-tlinux3-0013-x86_64-with-glibc2.2.5
- Python version: 3.8.12
- Numpy version: 1.21.4
- PyTorch version (GPU?): 1.10.0 (True)
Information
- The official example scripts
- My own modified scripts
Tasks
- One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- My own task or dataset (give details below)
Reproduction
step = 0
while True:
    for batch in training_dataloader:
        step += 1
        with accelerator.accumulate(model):
            inputs, targets = batch
            outputs = model(inputs)
            loss = loss_function(outputs, targets)
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        # Periodic evaluation, nested inside the training loop.
        if step % 100 == 0:
            for batch in eval_dataloader:
                inputs, targets = batch
                outputs = model(inputs)
                loss = loss_function(outputs, targets)
Expected behavior
As you can see, we alternate between training and evaluation (this is common, since we monitor performance on the dev set while training the model), but the scheduler's step count ends up wrong.
It seems the forward passes performed during evaluation are also counted toward gradient accumulation.
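A possible workaround (not from the original report) is to handle the accumulation bookkeeping manually instead of relying on `accelerator.accumulate`, so that whatever state the eval dataloader touches no longer matters. A minimal sketch, assuming the same prepared objects as the reproduction above and an illustrative `gradient_accumulation_steps` of 8:

```python
import torch

gradient_accumulation_steps = 8  # illustrative value, not from the report
step = 0
while True:
    for batch in training_dataloader:
        step += 1
        inputs, targets = batch
        outputs = model(inputs)
        # Scale the loss so the accumulated gradient averages over micro-batches.
        loss = loss_function(outputs, targets) / gradient_accumulation_steps
        accelerator.backward(loss)
        # Step on our own counter, independent of any dataloader-driven state.
        if step % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        if step % 100 == 0:
            model.eval()
            with torch.no_grad():
                for eval_batch in eval_dataloader:
                    inputs, targets = eval_batch
                    outputs = model(inputs)
                    eval_loss = loss_function(outputs, targets)
            model.train()
```

Under DDP this syncs gradients on every backward call; wrapping the non-stepping iterations in `accelerator.no_sync(model)` could avoid the extra communication, at the cost of a bit more bookkeeping.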
About this issue
- Original URL
- State: closed
- Created a year ago
- Comments: 22 (5 by maintainers)
In my opinion, the singleton `GradientState` is a bad design. Each dataloader has a `GradientState`, yet they all share one global state, which seems pretty weird. Maybe it should be changed so that each dataloader owns its own `GradientState`, and the accelerator provides a method to get the current `GradientState` inside each loop. @muellerzr
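Something like the following is what I mean. This is purely a hypothetical sketch of the proposed interface (none of these class or method names exist in Accelerate), just to illustrate per-dataloader state plus a lookup of whichever loop is currently active:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class PerLoaderGradientState:
    # Hypothetical: one state object per dataloader instead of a global singleton.
    end_of_dataloader: bool = False

class StatefulDataLoader:
    """Hypothetical wrapper that owns its own gradient state."""
    def __init__(self, dataloader, accelerator):
        self.dataloader = dataloader
        self.gradient_state = PerLoaderGradientState()
        self._accelerator = accelerator

    def __iter__(self):
        self.gradient_state.end_of_dataloader = False
        self._accelerator._active_loaders.append(self)  # track loop nesting
        try:
            batches = list(self.dataloader)  # materialized only for sketch simplicity
            for i, batch in enumerate(batches):
                self.gradient_state.end_of_dataloader = (i == len(batches) - 1)
                yield batch
        finally:
            self._accelerator._active_loaders.pop()

class SketchAccelerator:
    """Hypothetical: accumulate() would consult the innermost active loop's state."""
    def __init__(self):
        self._active_loaders: List[StatefulDataLoader] = []

    def current_gradient_state(self) -> PerLoaderGradientState:
        return self._active_loaders[-1].gradient_state
```

With something like this, finishing an inner eval loop would only touch that loader's own state, leaving the outer training loop's accumulation untouched.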
I think the problem is that when we finish the `for batch in eval_dataloader` loop, the `end_of_dataloader` state is set to `True`. But note that we are still inside the `for batch in training_dataloader` loop, and this makes `with accelerator.accumulate(model)` not work as we expected.
I don't think `with accelerator.disable_gradient_accumulation()` is good enough, because, as you can see, the eval code is not inside the `with accelerator.accumulate(model)` block.
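One way to observe this directly, assuming `eval_dataloader` was prepared by `accelerator.prepare` and that the internal `accelerator.gradient_state` attribute (what `accumulate` consults) is accessible:

```python
# Diagnostic sketch: check the shared flag around the eval loop, run from
# inside the training loop (i.e. mid-epoch of training_dataloader).
print("before eval:", accelerator.gradient_state.end_of_dataloader)  # expected False mid-epoch
for batch in eval_dataloader:
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
# Per this issue, the flag is now True even though we are still inside the
# training loop, so accumulate() will sync and step on every remaining batch.
print("after eval:", accelerator.gradient_state.end_of_dataloader)
```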