DeepSpeed: [BUG] Recommended way to implement EMA

Describe the bug

Hi DeepSpeed team, I have some code that maintains an exponential moving average (EMA) of the weights while training a UNet model. The code relies on the model's named_parameters() and parameters() to store and update the EMA weights; a simplified implementation is below:


# Part 0 ====================================================================

from typing import Iterable
from contextlib import contextmanager  # used by ema_scope in Part 2

import torch
import torch.nn as nn

class LitEma(nn.Module):

    def __init__(
        self,
        model:nn.Module,
        decay:float=0.9999,
        use_num_updates:bool=True,
    ):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int) if use_num_updates else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # '.' is not allowed in buffer names, so strip it
                s_name = name.replace('.','')
                self.m_name2s_name.update({ name: s_name })
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model:nn.Module):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1+self.num_updates)/(10+self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay*(shadow_params[sname]-m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model:nn.Module):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters:Iterable[nn.Parameter]):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters:Iterable[nn.Parameter]):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

# Part 1 ====================================================================

    if self.use_ema:
        self.model_ema = LitEma(self.model)
        print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

# Part 2 ====================================================================

    @contextmanager
    def ema_scope(self, context:str=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

# Part 3 ====================================================================

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

When using DeepSpeed, .parameters() and .named_parameters() both return empty results, so the code above no longer works. What is the recommended way of implementing the LitEma class above with DeepSpeed? Sorry if this seems like a dumb question, but I'm new here, and with offload and sharding it is unclear to me how to implement it correctly.
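For reference, under ZeRO Stage 3 the parameters still exist but are partitioned across ranks, and their .data shows up as a zero-size placeholder outside of a gathering context. A minimal, untested sketch of reading the full weights with the public deepspeed.zero.GatheredParameters API (snapshot_full_params is just an illustrative name) looks like this:

import torch
import deepspeed

def snapshot_full_params(model: torch.nn.Module) -> dict:
    # Temporarily gather the partitioned parameters so the full tensors are
    # visible on every rank, then take detached copies. Gathering everything
    # at once is memory-heavy; per-module gathers would be cheaper.
    shadow = {}
    with torch.no_grad():
        with deepspeed.zero.GatheredParameters(list(model.parameters())):
            for name, p in model.named_parameters():
                shadow[name] = p.detach().clone()
    return shadow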


About this issue

  • Original URL
  • State: open
  • Created 2 years ago
  • Reactions: 7
  • Comments: 21 (6 by maintainers)

Most upvoted comments

As another point of reference, setting devices=1 fixes the issue and results in ema_val_loss and val_loss becoming exactly equivalent.

The following code should work (note: not tested). Please see here for a guide on manipulating ZeRO-3 (z3) models.

import torch
import deepspeed

def clone_zero_model(src_model, dst_model, zero_stage=0):
    # Copy src_model's weights into dst_model, gathering ZeRO-3 partitioned
    # parameters on the fly so the full tensors are visible during the copy.
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
            # TODO: use prefiltering for efficiency
            params_to_fetch = _z3_params_to_fetch([src_param, dst_param]) if zero_stage_3 else []
            should_gather_param = len(params_to_fetch) > 0
            with deepspeed.zero.GatheredParameters(params_to_fetch,
                                                   enabled=should_gather_param):
                dst_param.data.copy_(src_param.data)
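The snippet above references _z3_params_to_fetch, which is not defined in the comment. A minimal version (assumed here, modeled on the helper of the same name in DeepSpeedExamples) just selects the parameters that ZeRO-3 currently holds in partitioned form:

from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

def _z3_params_to_fetch(param_list):
    # A parameter managed by ZeRO-3 carries a ds_id attribute; NOT_AVAILABLE
    # means its data is currently partitioned away and must be gathered.
    return [
        p for p in param_list
        if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE
    ]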

For a z3 model it can be used as follows:

main_model = ...
ema_model = ...  # constructed similarly to main_model

def on_train_batch_end(self, *args, **kwargs):
    clone_zero_model(src_model=main_model, dst_model=ema_model, zero_stage=3)
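Note that clone_zero_model is a straight copy into the EMA model. To apply an actual decayed update instead, a hedged variant (the function name, decay factor, and in-place blend below are my additions, not part of the comment above) could reuse the imports and _z3_params_to_fetch helper from above with the same gathering pattern:

def ema_update_zero_model(src_model, dst_model, decay=0.9999, zero_stage=0):
    # dst <- decay * dst + (1 - decay) * src, gathering ZeRO-3 partitioned
    # parameters so the full tensors are visible during the update.
    zero_stage_3 = (zero_stage == 3)
    with torch.no_grad():
        for src_param, dst_param in zip(src_model.parameters(), dst_model.parameters()):
            params_to_fetch = _z3_params_to_fetch([src_param, dst_param]) if zero_stage_3 else []
            with deepspeed.zero.GatheredParameters(params_to_fetch,
                                                   enabled=len(params_to_fetch) > 0):
                dst_param.data.mul_(decay).add_(src_param.data, alpha=1.0 - decay)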

What would be the recommended strategy for creating a copy of the model so that its parameter partitioning is handled correctly?

Hey, I'm trying to use this code myself. Could you please share how this is being implemented within your Lightning module? I'm currently trying something like this and it appears not to be working:

In the init:

self.ema_model = DSEma(self.model)

Then:

def on_train_batch_end(self, *args, **kwargs):
    self.ema_model(self.model)

@czczup Hi cz. It does run EMA for me in zero stage 3. But I didn’t check whether it behaves exactly the same as LitEMA.
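If someone wants to check that, a rough, untested sketch (compare_ema_buffers is just an illustrative helper) would be to compare the shadow buffers of a DSEma against a reference LitEma from a run where LitEma works (e.g. a single, unsharded device). On multi-GPU runs DSEma only updates its buffers on rank 0, so the comparison should be done there:

import torch

def compare_ema_buffers(ds_ema, lit_ema, atol=1e-6):
    # Both modules register one shadow buffer per trainable parameter under
    # the same '.'-stripped names; LitEma additionally registers 'decay' and
    # 'num_updates', which are skipped here. Returns the mismatching names.
    ds_buffers = dict(ds_ema.named_buffers())
    lit_buffers = {k: v for k, v in lit_ema.named_buffers()
                   if k not in ('decay', 'num_updates')}
    return [
        k for k, v in lit_buffers.items()
        if not torch.allclose(ds_buffers[k].detach().float().cpu(),
                              v.detach().float().cpu(), atol=atol)
    ]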

Hi, I tried to work out a usable EMA module with ZeRO Stage 3. See below:

import torch
import torch.nn as nn
import deepspeed
from deepspeed.runtime.zero import GatheredParameters

class DSEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1

        with GatheredParameters(model.parameters(), fwd_module=self):
            for name, p in model.named_parameters():
                if p.requires_grad:
                    # '.' is not allowed in buffer names, so strip it
                    s_name = name.replace('.', '')
                    self.m_name2s_name.update({name: s_name})
                    self.register_buffer(s_name, p.clone().detach().data)
        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay
        shadow_params = dict(self.named_buffers())

        with torch.no_grad():
            with GatheredParameters(model.parameters()):
                if deepspeed.comm.get_rank() == 0:
                    m_param = dict(model.named_parameters())

                    for key in m_param:
                        if m_param[key].requires_grad:
                            sname = self.m_name2s_name[key]
                            shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                            shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                        else:
                            assert not key in self.m_name2s_name

    def copy_to(self, model):
        shadow_params = dict(self.named_buffers())
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                m_param = dict(model.named_parameters())
                for key in m_param:
                    if m_param[key].requires_grad:
                        m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
                    else:
                        assert not key in self.m_name2s_name

    def store(self, model):
        """
        Save the current parameters for restoring later.
        Args:
          model: A model that parameters will be stored
        """
        with GatheredParameters(model.parameters()):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                self.collected_params = [param.clone() for param in parameters]

    def restore(self, model):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          model: A model that to restore its parameters.
        """
        with GatheredParameters(model.parameters(), modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                parameters = model.parameters()
                for c_param, param in zip(self.collected_params, parameters):
                    param.data.copy_(c_param.data)
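For completeness, here is a hedged sketch of how DSEma might be wired into a LightningModule, mirroring Parts 1-3 of the original LitEma usage. MyLitModule and build_unet are placeholders, and this has not been tested:

from contextlib import contextmanager
import pytorch_lightning as pl

class MyLitModule(pl.LightningModule):
    def __init__(self, use_ema: bool = True):
        super().__init__()
        self.model = build_unet()  # placeholder for the actual UNet constructor
        self.use_ema = use_ema
        if self.use_ema:
            # Shadow weights live as full (unpartitioned) buffers on this module.
            self.model_ema = DSEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)  # EMA update; gathers ZeRO-3 params internally

    @contextmanager
    def ema_scope(self, context: str = None):
        # Note: unlike LitEma, DSEma.store/copy_to/restore take the model itself.
        if self.use_ema:
            self.model_ema.store(self.model)
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model)
                if context is not None:
                    print(f"{context}: Restored training weights")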