DeepSpeed: [BUG] `TrainSchedule` seems to use one more buffer than what's needed

I believe that TrainSchedule, which implements the synchronous 1F1B pipeline training schedule, uses one more buffer than what is actually needed in every pipeline stage except for the last.

https://github.com/microsoft/DeepSpeed/blob/d323abd80f62bebb9924bb85feb72b57c25af50d/deepspeed/runtime/pipe/schedule.py#L243-L247

  1. Imagine a 4-stage pipeline with 8 microbatches. The first stage does four forward computations and then one backward computation, so at a high level it should only need four buffers to remember the activations of the first four forward passes. However, num_pipe_buffers returns 5 in this case (self.stages == 4, self.stage_id == 0, and self.micro_batches == 8).
  2. In any case, the largest value of self.stage_id is self.stages - 1, so buffers is never smaller than 2. However, the function currently returns max(2, buffers), which hints at an off-by-one mistake. Maxing with 2 itself makes sense: the last stage has at most one in-flight microbatch and theoretically needs only one buffer, but it actually needs two so that send and recv do not deadlock or overwrite each other.

With the following change, training should in theory run fine.

    def num_pipe_buffers(self):
        # Drop the "+ 1" from the current min(self.stages - self.stage_id + 1, ...).
        buffers = min(self.stages - self.stage_id, self.micro_batches)
        return max(2, buffers)
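
To make the difference concrete, here is a small standalone sketch (my own illustration, not DeepSpeed code) comparing the per-stage buffer counts of the current formula and the proposed one for the 4-stage, 8-microbatch example above:

    # Illustration only: per-stage buffer counts for a 4-stage pipeline
    # with 8 microbatches, under the current and the proposed formula.
    stages, micro_batches = 4, 8

    def current_num_pipe_buffers(stage_id):
        # current implementation, with the suspected extra "+ 1"
        return max(2, min(stages - stage_id + 1, micro_batches))

    def proposed_num_pipe_buffers(stage_id):
        # proposed change: drop the "+ 1"
        return max(2, min(stages - stage_id, micro_batches))

    for stage_id in range(stages):
        print(stage_id, current_num_pipe_buffers(stage_id), proposed_num_pipe_buffers(stage_id))
    # prints: 0 5 4, 1 4 3, 2 3 2, 3 2 2

Every stage except the last would save one buffer.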

However, training then dies with an assertion error triggered here:

https://github.com/microsoft/DeepSpeed/blob/d323abd80f62bebb9924bb85feb72b57c25af50d/deepspeed/runtime/pipe/engine.py#L1032

which is part of _exec_send_grads and ensures that the intermediate activations from the previous stage, held in the input buffer, have a populated .grad attribute, i.e., the gradients (vector-Jacobian products) produced by _exec_backward_pass.
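
To illustrate in isolation what the assertion guards (a minimal standalone PyTorch sketch, not DeepSpeed's actual buffer handling): a backward pass populates .grad on the tensor held in the buffer, whereas a tensor that no backward pass has touched has a .grad of None.

    import torch

    # Standalone sketch of the asserted condition, not DeepSpeed code.
    received = torch.randn(4, requires_grad=True)   # stand-in for a buffered activation
    received.sum().backward()                       # a backward pass populates .grad
    assert received.grad is not None                # the condition being asserted

    fresh = torch.randn(4, requires_grad=True)      # no backward pass has touched this one
    print(fresh.grad)                               # None: the assertion would fail here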

Why does this happen? Inspecting the PipeInstructions generated by the modified TrainSchedule class for the penultimate stage (self.stage_id == 2, with self.stages == 4 and self.micro_batches == 8):

>>> pprint(list(FixBufferTrainSchedule(8, 4, 2)), width=120)
[[-1],
 [-2],
 [0, RecvActivation(buffer_id=0), ForwardPass(buffer_id=0)],
 [-1, SendActivation(buffer_id=0)],
 [1, RecvActivation(buffer_id=1), ForwardPass(buffer_id=1)],
 [0, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [2, RecvActivation(buffer_id=0), SendGrad(buffer_id=0), ForwardPass(buffer_id=0)],  # WRONG!
 [1, SendActivation(buffer_id=0), RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [3, RecvActivation(buffer_id=1), SendGrad(buffer_id=1), ForwardPass(buffer_id=1)],
 [2, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [4, RecvActivation(buffer_id=0), SendGrad(buffer_id=0), ForwardPass(buffer_id=0)],
 [3, SendActivation(buffer_id=0), RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [5, RecvActivation(buffer_id=1), SendGrad(buffer_id=1), ForwardPass(buffer_id=1)],
 [4, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [6, RecvActivation(buffer_id=0), SendGrad(buffer_id=0), ForwardPass(buffer_id=0)],
 [5, SendActivation(buffer_id=0), RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [7, RecvActivation(buffer_id=1), SendGrad(buffer_id=1), ForwardPass(buffer_id=1)],
 [6, SendActivation(buffer_id=1), RecvGrad(buffer_id=0), BackwardPass(buffer_id=0)],
 [8, SendGrad(buffer_id=0)],
 [7, RecvGrad(buffer_id=1), BackwardPass(buffer_id=1)],
 [9, SendGrad(buffer_id=1)],
 [8, ReduceTiedGrads(), ReduceGrads(), OptimizerStep()]]

The integer (-1, -2, 0, -1, …) at the front of each line is the current microbatch number for that step; I hacked the source code to print it for easier visualization.
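
For reference, a similar (unannotated) listing should be reproducible with the stock class, assuming the constructor argument order used above, i.e. (micro_batches, stages, stage_id); the leading microbatch number is the part I hacked in:

    from pprint import pprint
    from deepspeed.runtime.pipe.schedule import TrainSchedule

    # Same arguments as the call above: 8 microbatches, 4 stages, stage_id 2.
    # The stock class yields only the per-step instruction lists.
    pprint(list(TrainSchedule(8, 4, 2)), width=120)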

See the line marked # WRONG!, where RecvActivation is performed on buffer 0, which is still holding the gradients produced by BackwardPass(buffer_id=0) in the previous step. I believe this overwrites the buffer, so when SendGrad(buffer_id=0) in the same step reads it, buffer.grad is None and the assertion fires.

This also means that when the instructions were generated, all the buffer_ids in the # WRONG! step came out as 0, i.e., prev_buffer and curr_buffer were both 0.

https://github.com/microsoft/DeepSpeed/blob/d323abd80f62bebb9924bb85feb72b57c25af50d/deepspeed/runtime/pipe/schedule.py#L198-L201

This happens because the number of buffers returned by the fixed num_pipe_buffers method is 2, which is exactly the distance between the previous and current microbatch ids (0 and 2). Taking the modulo with 2 (in self._buffer_idx) therefore maps both to the same buffer id, 0. With the current (suspected-wrong) num_pipe_buffers, the modulo is taken with 3, which happens to “solve” this buffer overlap.

https://github.com/microsoft/DeepSpeed/blob/d323abd80f62bebb9924bb85feb72b57c25af50d/deepspeed/runtime/pipe/schedule.py#L105-L117
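
Here is a minimal sketch of the collision, assuming _buffer_idx boils down to micro_batch_id % num_pipe_buffers() as the linked code suggests:

    # At the step marked # WRONG! (stage_id == 2), the previous microbatch id is 0
    # and the current one is 2, so their distance equals the fixed buffer count of 2.
    prev_micro_batch_id, curr_micro_batch_id = 0, 2

    # With the fixed 2 buffers, both map to buffer 0, so prev_buffer == curr_buffer.
    print(prev_micro_batch_id % 2, curr_micro_batch_id % 2)   # 0 0

    # With the current 3 buffers, they land in different slots, masking the issue.
    print(prev_micro_batch_id % 3, curr_micro_batch_id % 3)   # 0 2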

I suspect that it’s not right to derive the buffer id from the previous and current microbatch; it may need to take the step number into account, for example.

Ideally I wanted to submit a PR that fixes this, but the logic in TrainSchedule is quite complex and I couldn’t fix it quickly; I also wanted to ask first whether this “extra buffer” observation even makes sense. While this touches a chunk of code that has worked well for three years and may be a convoluted change, I still believe it’s worth fixing (if it’s actually wrong), because at the core of DeepSpeed is ZeRO, a technique that trades off performance for memory efficiency.

CC @ShadenSmith


Most upvoted comments

@jaywonchung @ShadenSmith

I also checked with Megatron-DeepSpeed (GPT, 165M parameters) and confirmed that the loss values match exactly. The following figures show the loss values for 4 stages.

I will submit the PR for the fix. Thank you @jaywonchung again for your great help!

[Figures: loss_gpt_no_zero, loss_gpt_zero1]

@jaywonchung, thanks so much for this fantastic find!! Great work, findings, and report; we super appreciate it!

@jaywonchung Thank you for the information!

I found that my settings for reproducibility didn’t work properly, so I used the code you mentioned. (I needed to remove the pooling layer in AlexNet because it doesn’t support deterministic algorithms.) I then observed that the loss values match exactly. In the following figure, the plots for the same number of stages overlap.

[Figures: loss_no_zero, loss_zero1]

Different numbers of stages lead to different loss values because the RNG for dropout is invoked a different number of times. This shouldn’t be a problem.