DeepSpeedExamples: DeepSpeed-Chat: prefetch of layers during reward model forward pass leads to error during sample generation

When running step 3 with ZeRO stage 3 enabled for both the actor and critic models, I get the following error (line numbers may be offset due to debug statements I’ve added):

File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
  main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 438, in main
  out = trainer.generate_experience(prompts)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 103, in generate_experience
    seq = self._generate_sequence(prompts)
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
  seq = self.actor_model.module.generate(prompts,
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 293, in generate
  self.fuse_lora_weight()
File "/path/site-packages/deepspeed/runtime/hybrid_engine.py", line 128, in _fuse_lora
    weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
RuntimeError: The size of tensor a (8192) must match the size of tensor b (2048) at non-singleton dimension 1

This happens because the shape of weight.data does not match the shape of the tensor produced by the LoRA matmul.
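To make the mismatch concrete, here is a standalone repro of the failing line with shapes taken from the error above. The pairing of base weight and LoRA factors is hypothetical, chosen only because it reproduces the exact message; the real pairing presumably goes wrong because some params were never gathered (see below):

import torch

lora_scaling = 1.0
weight = torch.zeros(2048, 8192)            # a gathered fc2-shaped weight
lora_left_weight = torch.zeros(128, 8192)   # LoRA factors shaped for a
lora_right_weight = torch.zeros(2048, 128)  # different (fc1-shaped) layer

# (8192, 128) @ (128, 2048) -> (8192, 2048), which cannot be added to the
# (2048, 8192) base weight:
weight.data += lora_scaling * torch.matmul(lora_left_weight.t(), lora_right_weight.t())
# RuntimeError: The size of tensor a (8192) must match the size of
# tensor b (2048) at non-singleton dimension 1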

I am using a system with 4x 16GB V100 GPUs per node, running DeepSpeed 0.9.1. I trained a 1.3b-parameter model in step 1 and a 350m-parameter model in step 2.

My step 3 run command launches 4 processes on one node, binding one process per GPU:

cd training/step3_rlhf_finetuning
OUTPUT=${OUTPUTDIR}/step3-models/1.3b
mkdir -p $OUTPUT
ACTOR_MODEL_PATH=${OUTPUTDIR}/actor-models/1.3b
CRITIC_MODEL_PATH=${OUTPUTDIR}/reward-models/1.3b
ACTOR_ZERO_STAGE=3
CRITIC_ZERO_STAGE=3
jsrun -r 1 --tasks_per_rs 4 -c ALL_CPUS -g ALL_GPUS python3 main.py \
   --per_device_train_batch_size 4 \
   --per_device_mini_train_batch_size 4 \
   --inference_tp_size 1 \
   --max_answer_seq_len 256 \
   --max_prompt_seq_len 256 \
   --actor_model_name_or_path $ACTOR_MODEL_PATH \
   --critic_model_name_or_path $CRITIC_MODEL_PATH \
   --actor_zero_stage $ACTOR_ZERO_STAGE \
   --critic_zero_stage $CRITIC_ZERO_STAGE \
   --num_padding_at_beginning 1 \
   --gradient_accumulation_steps 1 \
   --deepspeed \
   --actor_lora_dim 128 \
   --enable_hybrid_engine \
   --actor_gradient_checkpointing \
   --critic_gradient_checkpointing \
   --output_dir $OUTPUT

After some debugging, I found that the above error arises because the GatheredParameters context does not gather all layers. If I print the tensor shape for each parameter of each layer immediately after GatheredParameters like so:

https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L238

                with GatheredParameters(non_active_layers):
                    if rank == 0:
                        for layer_id in range(len(self.layer_params)):
                            for p_id, p in enumerate(self.layer_params[layer_id]):
                                print("after gather layer_id", layer_id, p_id, p.shape, flush=True)
                    self._gather_latency = time.time() - self._t0

then I see the following output on the step just before the error:

nonactive all layers 931
after gather layer_id 0 0 torch.Size([0])
after gather layer_id 0 1 torch.Size([0])
after gather layer_id 0 2 torch.Size([0])
after gather layer_id 0 3 torch.Size([0])
after gather layer_id 0 4 torch.Size([0])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([0])
after gather layer_id 0 9 torch.Size([0])
after gather layer_id 0 10 torch.Size([0])
after gather layer_id 0 11 torch.Size([0])
after gather layer_id 0 12 torch.Size([0])
after gather layer_id 0 13 torch.Size([0])
after gather layer_id 0 14 torch.Size([0])
after gather layer_id 0 15 torch.Size([0])
after gather layer_id 1 0 torch.Size([2048])
after gather layer_id 1 1 torch.Size([2048])
after gather layer_id 1 2 torch.Size([2048])
after gather layer_id 1 3 torch.Size([2048])
after gather layer_id 1 4 torch.Size([8192, 2048])
after gather layer_id 1 5 torch.Size([8192])
after gather layer_id 1 6 torch.Size([2048, 8192])
after gather layer_id 1 7 torch.Size([2048])
after gather layer_id 1 8 torch.Size([2048, 2048])
after gather layer_id 1 9 torch.Size([2048])
after gather layer_id 1 10 torch.Size([2048, 2048])
after gather layer_id 1 11 torch.Size([2048])
after gather layer_id 1 12 torch.Size([2048, 2048])
after gather layer_id 1 13 torch.Size([2048])
after gather layer_id 1 14 torch.Size([2048, 2048])
after gather layer_id 1 15 torch.Size([2048])

Note that the parameters in layer_id=0 mostly have shape torch.Size([0]), which (as far as I can tell) is what a ZeRO-3 parameter reports while its data is still partitioned, i.e., these parameters were never actually gathered. On steps that complete without an error, those parameters have their full shapes, as shown below. Note also the count of non-active layers: 962 below vs. 931 above.

nonactive all layers 962
after gather layer_id 0 0 torch.Size([2048])
after gather layer_id 0 1 torch.Size([2048])
after gather layer_id 0 2 torch.Size([2048])
after gather layer_id 0 3 torch.Size([2048])
after gather layer_id 0 4 torch.Size([8192, 2048])
after gather layer_id 0 5 torch.Size([8192])
after gather layer_id 0 6 torch.Size([2048, 8192])
after gather layer_id 0 7 torch.Size([2048])
after gather layer_id 0 8 torch.Size([2048, 2048])
after gather layer_id 0 9 torch.Size([2048])
after gather layer_id 0 10 torch.Size([2048, 2048])
after gather layer_id 0 11 torch.Size([2048])
after gather layer_id 0 12 torch.Size([2048, 2048])
after gather layer_id 0 13 torch.Size([2048])
after gather layer_id 0 14 torch.Size([2048, 2048])
after gather layer_id 0 15 torch.Size([2048])

To get further detail, I added the following lines:

https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/hybrid_engine.py#L234-L238

            else:
                from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
                rank = dist.get_rank(group=self.mp_group)

                non_active_layers = get_inactive_params(self.all_layers_params)
                if rank == 0:
                    print("nonactive layers", len(non_active_layers))
                    for lay_id, lay in enumerate(self.all_layers_params):
                        print("all layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)

                non_active_lora_params = get_inactive_params(self.all_lora_params)
                if rank == 0:
                    print("nonactive lora layers", len(non_active_lora_params))
                    for lay_id, lay in enumerate(self.all_lora_params):
                        print("lora layers", lay_id, hasattr(lay, 'ds_id'), lay.ds_status == ZeroParamStatus.NOT_AVAILABLE, lay.ds_status)

                non_active_layers.extend(non_active_lora_params)

It seems that the 0-shape parameters are marked with ds_status == ZeroParamStatus.INFLIGHT before GatheredParameters is called:

[2023-04-17 15:33:56,759] [INFO] [loss_scaler.py:181:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 32768, reducing to 16384
epoch: 0|step: 2|ppo_ep: 1|act_loss: nan|cri_loss: nan|unsuper_loss: 0.0
average reward score: 3.267578125
-------------------------------------------------------------------------------------
|E2E latency=17.17s |Gather latency=0.46s (2.70%) |Generate time=7.04s (41.02%) |Training time=6.82s (39.71%) |Others=3.31 (19.27%)|CurSamplesPerSec=0.93 |AvgSamplesPerSec=0.60
nonactive layers 651
all layers 0 True False ZeroParamStatus.INFLIGHT
all layers 1 True False ZeroParamStatus.INFLIGHT
all layers 2 True False ZeroParamStatus.INFLIGHT
all layers 3 True False ZeroParamStatus.INFLIGHT
all layers 4 True False ZeroParamStatus.INFLIGHT
all layers 5 True False ZeroParamStatus.INFLIGHT
all layers 6 True False ZeroParamStatus.INFLIGHT
all layers 7 True False ZeroParamStatus.INFLIGHT
all layers 8 True False ZeroParamStatus.INFLIGHT
all layers 9 True False ZeroParamStatus.INFLIGHT
all layers 10 True False ZeroParamStatus.INFLIGHT
all layers 11 True False ZeroParamStatus.INFLIGHT
all layers 12 True False ZeroParamStatus.INFLIGHT
all layers 13 True False ZeroParamStatus.INFLIGHT
all layers 14 True False ZeroParamStatus.INFLIGHT
all layers 15 True False ZeroParamStatus.INFLIGHT
all layers 16 True False ZeroParamStatus.INFLIGHT
all layers 17 True False ZeroParamStatus.INFLIGHT
all layers 18 True False ZeroParamStatus.INFLIGHT
all layers 19 True False ZeroParamStatus.INFLIGHT
all layers 20 True False ZeroParamStatus.INFLIGHT
all layers 21 True True ZeroParamStatus.NOT_AVAILABLE
all layers 22 True True ZeroParamStatus.NOT_AVAILABLE
all layers 23 True True ZeroParamStatus.NOT_AVAILABLE
all layers 24 True True ZeroParamStatus.NOT_AVAILABLE
all layers 25 True True ZeroParamStatus.NOT_AVAILABLE
all layers 26 True True ZeroParamStatus.NOT_AVAILABLE
all layers 27 True True ZeroParamStatus.NOT_AVAILABLE
all layers 28 True False ZeroParamStatus.INFLIGHT
all layers 29 True False ZeroParamStatus.INFLIGHT
all layers 30 True True ZeroParamStatus.NOT_AVAILABLE
all layers 31 True True ZeroParamStatus.NOT_AVAILABLE

<snip>

nonactive lora layers 280
lora layers 0 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 1 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 2 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 3 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 4 True False ZeroParamStatus.INFLIGHT
lora layers 5 True False ZeroParamStatus.INFLIGHT
lora layers 6 True False ZeroParamStatus.INFLIGHT
lora layers 7 True False ZeroParamStatus.INFLIGHT
lora layers 8 True False ZeroParamStatus.INFLIGHT
lora layers 9 True False ZeroParamStatus.INFLIGHT
lora layers 10 True False ZeroParamStatus.INFLIGHT
lora layers 11 True False ZeroParamStatus.INFLIGHT
lora layers 12 True True ZeroParamStatus.NOT_AVAILABLE
lora layers 13 True True ZeroParamStatus.NOT_AVAILABLE

I think those parameters are marked as INFLIGHT because they have been prefetched. I added some more debugging lines to print the stack at the point where the status is set to INFLIGHT (the additions assume import sys, traceback at the top of the file):

https://github.com/microsoft/DeepSpeed/blob/050aee287d70157e29f242b3a629a4cb97b4e4e7/deepspeed/runtime/zero/partition_parameters.py#L873-L885

        def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = True) -> AllGatherCoalescedHandle:

            # fetches from nvme if the partition is not available and in nvme
            self._ensure_availability_of_partitioned_params(params)

            if self.world_size == 1:
                return _no_gather_coalesced(params)

            #for param in params:
            for p_id, param in enumerate(params):
                if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                    raise RuntimeError(param.ds_summary())
                param.ds_status = ZeroParamStatus.INFLIGHT
                if dist.get_rank() == 0:
                    print(p_id, "INFLIGHT2")
                    if p_id > 20:
                        print(traceback.print_stack(file=sys.stdout))

I can see those layers are set to INFLIGHT here:

File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 529, in <module>
  main()
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/main.py", line 452, in main
  actor_loss, critic_loss = trainer.train_rlhf(exp_data)
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py", line 180, in train_rlhf
  value = self.critic_model.forward_value(**batch,
File "/path/DeepSpeedExamples/applications/DeepSpeed-Chat/training/utils/model/reward_model.py", line 125, in forward_value
  transformer_outputs = self.rwtranrsformer(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
  result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 759, in forward
  decoder_outputs = self.decoder(
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
    result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 665, in forward
  layer_outputs = torch.utils.checkpoint.checkpoint(
File "/path/site-packages/torch/utils/checkpoint.py", line 235, in checkpoint
  return CheckpointFunction.apply(function, preserve, *args)
File "/path/site-packages/torch/utils/checkpoint.py", line 96, in forward
  outputs = run_function(*args)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 661, in custom_forward
  return module(*inputs, output_attentions, None)
File "/path/site-packages/torch/nn/modules/module.py", line 1148, in _call_impl
  result = forward_call(*input, **kwargs)
File "/path/site-packages/transformers/models/opt/modeling_opt.py", line 337, in forward
  hidden_states = self.activation_fn(hidden_states)
File "/path/site-packages/torch/nn/modules/module.py", line 1137, in _call_impl
  result = hook(self, input)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 366, in _pre_forward_module_hook
  self.pre_sub_module_forward_function(module)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
  return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/parameter_offload.py", line 478, in pre_sub_module_forward_function
  param_coordinator.fetch_sub_module(sub_module)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
  return func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 333, in fetch_sub_module
  self.__all_gather_params(params_to_prefetch)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py", line 381, in __all_gather_params
  handle = partitioned_params[0].all_gather_coalesced(partitioned_params)
File "/path/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
  ret_val = func(*args, **kwargs)
File "/path/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 878, in all_gather_coalesced
  print(traceback.print_stack(file=sys.stdout))

It seems that the layers are being prefetched during the call to the critic model forward pass:

https://github.com/microsoft/DeepSpeedExamples/blob/2aa7a31b8fdcb34b8ccdc554021a1f5789752ab3/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L174

They are still in INFLIGHT status when the next sample generation starts. The get_inactive_params function only includes params marked as NOT_AVAILABLE:

https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/utils.py#L972-L975
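For reference, get_inactive_params at that commit is (roughly, reformatted here):

def get_inactive_params(param_list):
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    return [param for param in param_list
            if (hasattr(param, 'ds_id')
                and param.ds_status == ZeroParamStatus.NOT_AVAILABLE)]

so INFLIGHT params fall through this filter entirely.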

GatheredParameters, in turn, appears to only consider params whose status is NOT_AVAILABLE:

https://github.com/microsoft/DeepSpeed/blob/48297c4841374776a3c4d9319a9b57378f598d65/deepspeed/runtime/zero/partition_parameters.py#L1058

Assuming that diagnosis is correct, I’m not sure what the recommended fix would be. Should get_inactive_params include INFLIGHT params?
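For concreteness, the naive version of that change would be something like this untested sketch:

def get_inactive_params(param_list):
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    # Untested: also treat prefetched-but-not-yet-arrived (INFLIGHT) params
    # as inactive, so that GatheredParameters does not silently skip them.
    inactive = (ZeroParamStatus.NOT_AVAILABLE, ZeroParamStatus.INFLIGHT)
    return [param for param in param_list
            if (hasattr(param, 'ds_id') and param.ds_status in inactive)]

Given the assertion in all_gather_coalesced shown earlier (it raises a RuntimeError for any param that is not NOT_AVAILABLE), this alone would presumably just move the failure, so the gather path would likely also need to wait on the in-flight handles.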

Most upvoted comments

@tjruwase, I think I found the cause.

I believe the problem is that all four models share the same ReLU module object. In the transformers version I’m using, ACT2FN appears to hand out a single pre-instantiated module, so every OPTDecoderLayer in every model ends up holding the same object (which is why the workaround at the end of this comment helps). Each model registers a forward hook on that module in setup_zero_stage3_hooks(). So when the critic model’s forward pass reaches the ReLU module, the hook from the actor model is also invoked, which triggers the prefetch of the actor layers.
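The effect is easy to reproduce in isolation. A minimal sketch with hypothetical toy models (not the DeepSpeed code): both models hold the same ReLU instance, so a forward hook registered while setting up one model also fires when the other model runs its forward pass:

import torch
import torch.nn as nn

shared_act = nn.ReLU()  # one instance shared by both models
actor = nn.Sequential(nn.Linear(4, 4), shared_act)
critic = nn.Sequential(nn.Linear(4, 4), shared_act)

# Imitate setup_zero_stage3_hooks(): each engine registers hooks on "its"
# submodules, but the second registration lands on the same object.
shared_act.register_forward_pre_hook(lambda m, inp: print("actor hook fired"))
shared_act.register_forward_pre_hook(lambda m, inp: print("critic hook fired"))

critic(torch.randn(1, 4))  # prints BOTH lines: the actor's hook fires too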

I found this by adding the following code in deepspeed/runtime/zero/parameter_offload.py to print object addresses of all child modules of each model:

# Recursively print the name and object address (id) of every child module.
def print_children(module, indent):
    for name, m in module.named_children():
        spaces = " " * indent
        print(spaces, name, id(m))
        print_children(m, indent + 2)

<snip>

    def setup_zero_stage3_hooks(self):
        self.hierarchy = 0

        #reset step if in inference mode
        @instrument_w_nvtx
        def _end_of_forward_hook(module, *args):

            if not torch._C.is_grad_enabled():
                self.get_param_coordinator(training=False).reset_step()

        #likely one of them should be enough but just to be safe
        self._register_hooks_recursively(self.module)
        self.module.register_forward_hook(_end_of_forward_hook)

        # Add top module to stack trace
        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(self.module)
        if dist.get_rank() == 0:
            print("FWD_MODULE_STACK SETUP length", len(FWD_MODULE_STACK), "id(module)", id(self.module), type(self.module))
            print(str(self.module))
            for p_id, param in enumerate(iter_params(self.module, recurse=True)):
                key = id(param) if hasattr(param, 'ds_id') else id(param.ds_param_alias)
                print("  ", p_id, id(param), type(param), key)
            print_children(self.module, 2)

With that, I get the following example output for the actor and reward models. You can see that the activation_fn module has the same address in every layer of both models.

FWD_MODULE_STACK SETUP length 1 id(module) 35188069303104 <class 'transformers.models.opt.modeling_opt.OPTForCausalLM'>
3: 0:    model 35188109644992
3: 0:      decoder 35188109647776
3: 0:        embed_tokens 35188109647248
3: 0:        embed_positions 35188109647440
3: 0:        layers 35188109646144
3: 0:          0 35188109647392
3: 0:            self_attn 35188109644800
3: 0:              k_proj 35188109646480
3: 0:                lora_dropout 35188109646576
3: 0:              v_proj 35188109644320
3: 0:                lora_dropout 35188092098880
3: 0:              q_proj 35188109644848
3: 0:                lora_dropout 35188092098832
3: 0:              out_proj 35188109647296
3: 0:                lora_dropout 35188092096912
3: 0: -->        activation_fn 35186791027472  <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188109645040
3: 0:            fc1 35188109645616
3: 0:              lora_dropout 35188092098736
3: 0:            fc2 35188109644656
3: 0:              lora_dropout 35188092100560
3: 0:            final_layer_norm 35188108844336
3: 0:          1 35188108843472
3: 0:            self_attn 35188108843328
3: 0:              k_proj 35188109643984
3: 0:                lora_dropout 35188107776880
3: 0:              v_proj 35188108844432
3: 0:                lora_dropout 35188107776592
3: 0:              q_proj 35188108842800
3: 0:                lora_dropout 35188107777408
3: 0:              out_proj 35188108841264
3: 0:                lora_dropout 35188107777456
3: 0: -->        activation_fn 35186791027472  <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188108841456
3: 0:            fc1 35188108842512
3: 0:              lora_dropout 35188107777696
3: 0:            fc2 35188108841888
3: 0:              lora_dropout 35188107776160
3: 0:            final_layer_norm 35188104685504

<snip>

FWD_MODULE_STACK SETUP length 4 id(module) 35188109548272 <class 'utils.model.reward_model.RewardModel'>
3: 0:    v_head 35188109667440
3: 0:    rwtranrsformer 35188109548224
3: 0:      decoder 35188109548320
3: 0:        embed_tokens 35188109548416
3: 0:        embed_positions 35188109548368
3: 0:        project_out 35188109548512
3: 0:        project_in 35188109548560
3: 0:        layers 35188109548656
3: 0:          0 35188109548608
3: 0:            self_attn 35188109548704
3: 0:              k_proj 35188109548800
3: 0:              v_proj 35188109548848
3: 0:              q_proj 35188109548992
3: 0:              out_proj 35188109549040
3: 0: -->        activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188109548752
3: 0:            fc1 35188109549136
3: 0:            fc2 35188109549184
3: 0:            final_layer_norm 35188109549232
3: 0:          1 35188109549328
3: 0:            self_attn 35188109549376
3: 0:              k_proj 35188109549472
3: 0:              v_proj 35188109549520
3: 0:              q_proj 35188109664320
3: 0:              out_proj 35188109664368
3: 0: -->        activation_fn 35186791027472 <--- same address for all layers, in each model
3: 0:            self_attn_layer_norm 35188109549424
3: 0:            fc1 35188109547552
3: 0:            fc2 35188109547792
3: 0:            final_layer_norm 35188109547840

As a test, I then found that I could work around the problem by modifying the OPT model to instantiate a unique ReLU object for each layer in transformers/models/opt/modeling_opt.py:

class OPTDecoderLayer(nn.Module):
    def __init__(self, config: OPTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
-->     #self.activation_fn = ACT2FN[config.activation_function]
-->     self.activation_fn = nn.ReLU()
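An alternative that avoids patching transformers might be to break the sharing after the models are constructed and before the DeepSpeed engines register their hooks. A hypothetical, untested sketch (the model attribute paths below are illustrative):

import copy

# Give every decoder layer its own activation module so that no two
# layers (or models) share one instance.
for layer in actor_model.model.decoder.layers:            # OPTForCausalLM
    layer.activation_fn = copy.deepcopy(layer.activation_fn)
for layer in critic_model.rwtranrsformer.decoder.layers:  # RewardModel
    layer.activation_fn = copy.deepcopy(layer.activation_fn)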