sam: model does not converge after adding disable_bn and enable_bn

Hi. Following the suggestion posted in the README, I disabled BatchNorm2d as shown below, and now my model no longer converges. Did you literally mean BatchNorm2d, or SyncBN? Hopefully you can spot what is wrong with this modification. Many thanks!

    for epoch in range(args.epochs):
        model.train()
        log.train(len_dataset=len(dataset.train))

        for batch in dataset.train:
            inputs, targets = (b.to(device) for b in batch)
            # first forward-backward step
            enable_bn(model)    # update BN running stats at the current point w
            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            loss.mean().backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward step
            disable_bn(model)   # do not update BN stats, since we are at the perturbed point w_adv
            smooth_crossentropy(model(inputs), targets).mean().backward()
            optimizer.second_step(zero_grad=True)
            train_total_number += targets.shape[0]
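
For context, a minimal sketch of what the enable_bn/disable_bn helpers might look like if they are implemented with bn.train()/bn.eval(), which the discussion below suggests was the originally proposed approach (the actual README implementation may differ):

    import torch.nn as nn

    def enable_bn(model):
        # put every BatchNorm2d back in training mode: batch statistics are
        # used and the running statistics are updated on each forward pass
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.train()

    def disable_bn(model):
        # switch every BatchNorm2d to inference mode: running statistics are
        # used and no longer updated
        for module in model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()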

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 16 (2 by maintainers)

Most upvoted comments

Here are my two cents on this issue. When working with BatchNorm, there are two sets of variables to monitor. The first set has gamma (aka weight) and beta (aka bias), while the second set has running_mean and running_var. When this issue was raised, it was concerned only with running_mean and running_var, without regard for gamma and beta. Accordingly, the proposed solution added insult to injury, and SAM no longer converges.

The initially proposed (non-working) solution is to use bn.eval(). This freezes running_mean and running_var, but it also freezes beta and gamma. These are two learnable params that should keep updating and remain aligned with the other learnable params. When beta and gamma are frozen using bn.eval(), they diverge from the rest of the params. This divergence is small, with minimal impact, during the first iterations. Yet as the number of iterations increases, the divergence grows and the loss eventually goes to nan.

Accordingly, I propose the following solution. By temporarily setting momentum to zero, running_mean and running_var are effectively frozen: PyTorch updates each running statistic as x_new = (1 - momentum) * x_old + momentum * x_batch, so a momentum of zero leaves it unchanged. Yet the learnable params beta and gamma remain learnable. The two-step SAM optimizer then updates beta and gamma in the same direction as the rest of the params, and SAM no longer diverges to nan.

    import torch.nn as nn

    def _disable_running_stats(self, m):
        # freeze running_mean / running_var by zeroing the EMA momentum;
        # gamma (weight) and beta (bias) remain learnable
        if isinstance(m, nn.BatchNorm2d):
            m.backup_momentum = m.momentum
            m.momentum = 0

    def _enable_running_stats(self, m):
        # restore the original momentum so the running stats update again
        if isinstance(m, nn.BatchNorm2d) and hasattr(m, "backup_momentum"):
            m.momentum = m.backup_momentum

These methods can be called via model.apply(self._disable_running_stats) and model.apply(self._enable_running_stats), as in the sketch below.
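
For concreteness, a sketch of how those calls slot into the two-step loop from the question (model, optimizer, inputs, targets, and smooth_crossentropy are the names from that snippet; the helpers are assumed here to be free functions taking only the module, so the self prefix is dropped):

    # first forward-backward step at w: BN running stats update normally
    model.apply(_enable_running_stats)
    loss = smooth_crossentropy(model(inputs), targets)
    loss.mean().backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward step at w_adv: momentum is zero, so the running
    # stats keep their first-pass values while gamma and beta still get gradients
    model.apply(_disable_running_stats)
    smooth_crossentropy(model(inputs), targets).mean().backward()
    optimizer.second_step(zero_grad=True)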

One final note for the astute: these two operations (Op1 and Op2) probably introduce some precision issues.

I hit the same problem. When I save the running mean and var at the first pass and restore them at the second pass, the training accuracy is as normal as with vanilla SGD, but the validation accuracy is almost 0.
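
For reference, a sketch (with hypothetical helper names) of one way to read that save-and-restore approach: clone the BN statistics after the first pass and copy them back after the second, perturbed pass, so the perturbed batch does not contribute to the running estimates:

    import torch.nn as nn

    def save_bn_stats(model):
        # clone running_mean / running_var of every BN layer (after the first pass)
        return {name: (m.running_mean.clone(), m.running_var.clone())
                for name, m in model.named_modules()
                if isinstance(m, nn.BatchNorm2d)}

    def restore_bn_stats(model, stats):
        # copy the saved statistics back (after the second pass), undoing the
        # update made at the perturbed point w_adv
        for name, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean.copy_(stats[name][0])
                m.running_var.copy_(stats[name][1])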

@ahmdtaha

When using bn.eval(), the batchnorm layer enters the inference mode. Accordingly, neither set 1 nor set 2 is updated!

To the best of my knowledge, set 1 won't be updated, but set 2 would still be updated if you do a backward pass.

For example, most object detectors call batchnorm.eval() during training (https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py#L657), mainly because of the small batch size, but the weight and bias of batchnorm are still updated when calling backward.

If you want to freeze weight and bias, you have to set requires_grad to False explicitly (https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py#L623).
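
A small self-contained check of this point (plain torch, no mmdetection needed): bn.eval() freezes the running statistics, yet the affine weight (gamma) and bias (beta) still receive gradients until requires_grad is switched off.

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(8)
    bn.eval()                              # inference mode: running stats frozen
    x = torch.randn(4, 8, 16, 16)
    bn(x).sum().backward()
    print(bn.weight.grad is not None)      # True: gamma/beta still get gradients

    # to freeze gamma/beta as well, requires_grad must be disabled explicitly
    bn.weight.requires_grad_(False)
    bn.bias.requires_grad_(False)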

Hi Ming-Hsuan-Tu @twmht,

Sorry for the late reply. As mentioned previously, batchnorm has two sets of variables: Set 1: {running_mean and running_var}; Set 2: {beta and gamma}.

Set 1 is updated during every forward pass. In contrast, set 2 is updated during every backward pass. Think of {beta and gamma} as {weights and biases} of a typical layer. Basically, {beta and gamma} are used passively during forward passes.

When using bn.eval(), the batchnorm layer enters inference mode. Accordingly, neither set 1 nor set 2 is updated! Of course, this is undesired behavior, and it has major negative implications with SAM. If you try bn.eval(), your training will diverge, i.e., loss = nan. On a high level, this happens because almost every layer's {weights and biases} are updated normally during the backward pass, while the batchnorm layers' {beta and gamma} are not updated!

With a typical optimizer (e.g., SGD or Adam), this scenario won’t lead to loss=nan. But with SAM, things are more entangled 😃 I hope this clarifies things a bit.
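
A quick standalone illustration of the forward-pass behavior of set 1 in plain torch:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(8)
    x = torch.randn(4, 8, 16, 16)

    bn.train()
    before = bn.running_mean.clone()
    bn(x)                                          # a forward pass alone updates set 1
    print(torch.equal(before, bn.running_mean))    # False: running stats moved

    bn.eval()
    before = bn.running_mean.clone()
    bn(x)                                          # in eval mode set 1 stays fixed
    print(torch.equal(before, bn.running_mean))    # True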

@pengbohua I want to revise my previous comment. According to [1], “Empirically, the degree of improvement negatively correlates with the level of inductive biases built into the architecture.” Indeed, when I evaluate SAM with a compact architecture, SAM brings marginal improvement, if any. Yet when I evaluate SAM with a huge architecture, SAM delivers significant improvements!

[1] When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations


Thank you very much for sharing these bugfixes with us! Using momentum to bypass the running statistics is very clever 😃 I’ve pushed a new commit that should correct both issues.