DeepSpeed: [BUG] `TrainSchedule` seems to use one more buffer than what's needed
I believe that `TrainSchedule`, which implements the synchronous 1F1B pipeline training schedule, uses one more buffer than what is actually needed in every pipeline stage except for the last.

- Imagine a 4-stage pipeline with 8 microbatches. The first stage will do four forward computations and then one backward computation, so at a high level it should only need four buffers to remember the activations of the first four forward passes. However, `num_pipe_buffers` will return `5` for such a case (`self.stages == 4`, `self.stage_id == 0`, and `self.micro_batches == 8`).
- In any case, the largest value of `self.stage_id` is `self.stages - 1`, so `buffers` is never smaller than 2. However, this function currently returns `max(2, buffers)`, which hints at an off-by-one mistake. `max`ing with 2 itself makes sense, because the last stage, which has at most one in-flight microbatch and theoretically needs only one buffer, actually needs two buffers so that `send` and `recv` do not deadlock or overwrite each other.
Changing it to the following, training should run fine in theory:

```python
def num_pipe_buffers(self):
    buffers = min(self.stages - self.stage_id, self.micro_batches)
    return max(2, buffers)
```
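For concreteness, here is a quick sketch of what the proposed method would return for each stage of the 4-stage, 8-microbatch example above:

```python
# Buffer counts the proposed num_pipe_buffers would return per stage for a
# 4-stage pipeline with 8 microbatches. As reported above, the existing
# implementation returns one more for every stage except the last
# (e.g. 5 instead of 4 for stage 0).
stages, micro_batches = 4, 8
for stage_id in range(stages):
    buffers = max(2, min(stages - stage_id, micro_batches))
    print(f"stage {stage_id}: {buffers} buffers")
# stage 0: 4 buffers
# stage 1: 3 buffers
# stage 2: 2 buffers
# stage 3: 2 buffers
```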
However, it dies with an assertion error triggered here, which is part of `_exec_send_grads` and ensures that the intermediate activations from the previous stage held in the `'input'` buffer have the `.grad` attribute, i.e. the Jacobian-vector products produced by `_exec_backward_pass`.
Why does this happen? Inspecting the `PipeInstruction`s generated by the modified `TrainSchedule` class for the penultimate stage (`self.stage_id == 2`, where `self.stages == 4` and `self.micro_batches == 8`):
```
>>> pprint(list(FixBufferTrainSchedule(8, 4, 2)), width=120)
[[-1],
 [-2],
 [0, RecvActivation(buffer_id=0), ForwardPass(buffer_id=0)],
 [-1, SendActivation(buffer_id=0)],
 [1, RecvActivation(buffer_id=1), ForwardPass(buffer_id=1)],
 [0, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [2, RecvActivation(buffer_id=0), SendGrad(buffer_id=0), ForwardPass(buffer_id=0)], # WRONG!
 [1, SendActivation(buffer_id=0), RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [3, RecvActivation(buffer_id=1), SendGrad(buffer_id=1), ForwardPass(buffer_id=1)],
 [2, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [4, RecvActivation(buffer_id=0), SendGrad(buffer_id=0), ForwardPass(buffer_id=0)],
 [3, SendActivation(buffer_id=0), RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [5, RecvActivation(buffer_id=1), SendGrad(buffer_id=1), ForwardPass(buffer_id=1)],
 [4, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [6, RecvActivation(buffer_id=0), SendGrad(buffer_id=0), ForwardPass(buffer_id=0)],
 [5, SendActivation(buffer_id=0), RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [7, RecvActivation(buffer_id=1), SendGrad(buffer_id=1), ForwardPass(buffer_id=1)],
 [6, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [8, SendGrad(buffer_id=0)],
 [7, RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [9, SendGrad(buffer_id=1)],
 [8, ReduceTiedGrads(), ReduceGrads(), OptimizerStep()]]
```
The integer (-1, -2, 0, -1, …) in front of the list on each line is the current microbatch number for that step; I hacked the source code to visualize this better.
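For reference, the listing above can be reproduced with something like the following sketch. It assumes `FixBufferTrainSchedule` is simply `TrainSchedule` with `num_pipe_buffers` overridden as proposed; the stock classes do not print the hacked-in step numbers shown above.

```python
from pprint import pprint

from deepspeed.runtime.pipe.schedule import TrainSchedule


class FixBufferTrainSchedule(TrainSchedule):
    """TrainSchedule with only num_pipe_buffers overridden (assumed definition)."""

    def num_pipe_buffers(self):
        buffers = min(self.stages - self.stage_id, self.micro_batches)
        return max(2, buffers)


# micro_batches=8, stages=4, stage_id=2 (the penultimate stage)
pprint(list(FixBufferTrainSchedule(8, 4, 2)), width=120)
```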
See the line marked `# WRONG!`, where `RecvActivation` is performed on buffer 0, which is still holding the gradients produced by `BackwardPass(buffer_id=0)` in the previous step. I believe this overwrites the buffer, so `buffer.grad` would be `None`, triggering the assertion error.
This also means that, when generating the instructions, all the `buffer_id`s in the `# WRONG!` step were 0, i.e. `prev_buffer` and `curr_buffer` were both 0.
This happens because the number of buffers returned by the fixed `num_pipe_buffers` method is 2, which coincides with the distance between the previous and current microbatch ids (0 and 2, specifically). Thus, merely taking the modulo with 2 (in `self._buffer_idx`) yields the same buffer id 0 for both. With the current (suspected to be wrong) `num_pipe_buffers`, we would take the modulo with 3, which pretty much “solves” this buffer overlap issue.
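To illustrate the collision (a sketch, assuming `self._buffer_idx` boils down to `micro_batch_id % num_pipe_buffers()` as described above):

```python
def buffer_idx(micro_batch_id, num_buffers):
    # Assumed equivalent of TrainSchedule._buffer_idx
    return micro_batch_id % num_buffers

prev_micro_batch, curr_micro_batch = 0, 2

# With the proposed 2 buffers on stage 2, both microbatches map to buffer 0:
print(buffer_idx(prev_micro_batch, 2), buffer_idx(curr_micro_batch, 2))  # 0 0 -> collision

# With the current 3 buffers, they map to different buffers:
print(buffer_idx(prev_micro_batch, 3), buffer_idx(curr_micro_batch, 3))  # 0 2 -> no collision
```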
I suspect that it’s not right to derive the buffer id from the previous and current microbatch; it may need to take the step number into account, for example.
Ideally I wanted to submit a PR that fixes this, but the logic in `TrainSchedule` is quite complex and I couldn't fix it quickly; I also wanted to ask whether this “extra buffer” observation even makes sense. While this touches a chunk of code that has worked well for three years and may be a convoluted task, I still believe it's worth fixing (if it's actually wrong), because at the core of DeepSpeed is ZeRO, a technique that trades off performance for memory efficiency.
CC @ShadenSmith
Commits related to this issue
- fix buffer size for pipeline parallel (#2800) — committed to microsoft/DeepSpeed by tohtana a year ago
- Fix buffer size for pipeline parallel and communication schedule (#2862) * fix buffer size for pipeline parallel (#2800) * improve explanation of buffer size for pipeline parallelism Co-authore... — committed to microsoft/DeepSpeed by tohtana a year ago
@jaywonchung @ShadenSmith
I also checked with Megatron-DeepSpeed (GPT, 165M parameters) and confirmed that the loss values match exactly. The following figures show the loss values resulting from 4 stages.
I will submit a PR for the fix. Thank you again, @jaywonchung, for your great help!
@jaywonchung, thanks so much for this fantastic find!! Great work, findings, and report; we super appreciate it!
@jaywonchung Thank you for the information!
I found that my settings for reproducibility didn't work properly, so I used the code you mentioned. (I needed to remove the pooling layer in AlexNet because it doesn't support deterministic algorithms.) I finally observed that the loss values match exactly. In the following figure, the plots resulting from the same number of stages overlap.
Different numbers of stages lead to different values because the RNG for dropout runs a different number of times. This shouldn't be a problem.