apex: amp + checkpoint loading = problems

Hi, as you know I have been experimenting with amp for a while now. Today I stumbled upon very unexpected behavior. My FP16 models (trained with amp) do just as well than the FP32 models by themselves. But I usually also ensemble my models by doing something like this:

results = []
for c in checkpoints:
    network.load(checkpoint)
    results.append(network(data))

Interestingly, the performance drops quite a bit if I am doing that with amp enabled. To illustrate this, I created a minimalistic example with mnist:

from copy import deepcopy
import torch
import matplotlib
matplotlib.use("agg")
from torch.backends import cudnn
from apex import amp
import argparse
from torch import cuda
from torch import nn
from urllib import request
import gzip
import pickle
import os
import numpy as np


def load(mnist_file):
    init()
    with open(mnist_file, 'rb') as f:
        mnist = pickle.load(f)
    data_tr = mnist["training_images"].reshape(60000, 1, 28, 28)
    data_te = mnist["test_images"].reshape(10000, 1, 28, 28)
    return data_tr, mnist["training_labels"], data_te, mnist["test_labels"]


filename = [
["training_images","train-images-idx3-ubyte.gz"],
["test_images","t10k-images-idx3-ubyte.gz"],
["training_labels","train-labels-idx1-ubyte.gz"],
["test_labels","t10k-labels-idx1-ubyte.gz"]
]


def download_mnist():
    base_url = "http://yann.lecun.com/exdb/mnist/"
    for name in filename:
        print("Downloading "+name[1]+"...")
        request.urlretrieve(base_url+name[1], name[1])
    print("Download complete.")


def save_mnist():
    mnist = {}
    for name in filename[:2]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28*28)
    for name in filename[-2:]:
        with gzip.open(name[1], 'rb') as f:
            mnist[name[0]] = np.frombuffer(f.read(), np.uint8, offset=8)
    with open("mnist.pkl", 'wb') as f:
        pickle.dump(mnist,f)
    print("Save complete.")


def init():
    if not os.path.isfile("mnist.pkl"):
        download_mnist()
        save_mnist()


def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
    return initial_lr * (1 - epoch / max_epochs)**exponent


class GlobalAveragePool(nn.Module):
    def forward(self, x):
        axes = range(2, len(x.shape))
        for a in axes[::-1]:
            x = x.mean(a, keepdim=False)
        return x


def get_default_network_config():
    """
    returns a dictionary that contains pointers to conv, nonlin and norm ops and the default kwargs I like to use
    :return:
    """
    props = {}
    props['conv_op'] = nn.Conv2d
    props['conv_op_kwargs'] = {'stride': 1, 'dilation': 1, 'bias': True} # kernel size will be set by network!
    props['nonlin'] = nn.LeakyReLU
    props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
    props['norm_op'] = nn.BatchNorm2d
    props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
    props['dropout_op'] = nn.Dropout2d
    props['dropout_op_kwargs'] = {'p': 0.0, 'inplace': True}
    return props


class ConvDropoutNormReLU(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(ConvDropoutNormReLU, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.

        self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
                                             padding=[(i - 1) // 2 for i in kernel_size],
                                             **network_props['conv_op_kwargs'])

        # maybe dropout
        if network_props['dropout_op'] is not None:
            self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
        else:
            self.do = lambda x: x

        if network_props['norm_op'] is not None:
            self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
        else:
            self.norm = lambda x: x

        self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])

        self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)

    def forward(self, x):
        return self.all(x)


class StackedConvLayers(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
        """
        if network_props['dropout_op'] is None then no dropout
        if network_props['norm_op'] is None then no norm
        :param input_channels:
        :param output_channels:
        :param kernel_size:
        :param network_props:
        """
        super(StackedConvLayers, self).__init__()

        network_props = deepcopy(network_props)  # network_props is a dict and mutable, so we deepcopy to be safe.
        network_props_first = deepcopy(network_props)

        if first_stride is not None:
            network_props_first['conv_op_kwargs']['stride'] = first_stride

        self.convs = nn.Sequential(
            ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
            *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props) for _ in range(num_convs - 1)]
        )

    def forward(self, x):
        return self.convs(x)


class SimpleNetwork(nn.Module):
    def __init__(self, props=None):
        super(SimpleNetwork, self).__init__()
        if props is None:
            props = get_default_network_config()
        self.stage1 = StackedConvLayers(1, 16, (3, 3), props, 2, 1)
        self.stage2 = StackedConvLayers(16, 32, (3, 3), props, 2, 2)
        self.stage3 = StackedConvLayers(32, 64, (3, 3), props, 3, 2)
        self.stage4 = StackedConvLayers(64, 128, (3, 3), props, 3, 2)
        self.pool = GlobalAveragePool()
        self.fc = nn.Linear(128, 10, False)

    def forward(self, x):
        return self.fc(self.pool(self.stage4(self.stage3(self.stage2(self.stage1(x))))))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, required=False, default=None)
    parser.add_argument("--test_only", action="store_true", default=False)
    parser.add_argument("-s", help="output filename for trained model")
    parser.add_argument("-test_fnames", required=False, nargs='+')

    args = parser.parse_args()
    seed = args.seed
    test_only = args.test_only

    # seeding
    np.random.seed(seed)
    cuda.manual_seed(np.random.randint(10000))
    cuda.manual_seed_all(np.random.randint(10000))
    cudnn.deterministic = True
    cudnn.benchmark = False

    amp_handle = amp.init()

    data_tr, target_tr, data_te, target_te = load("mnist.pkl")

    data_tr = torch.from_numpy(data_tr).float().cuda()
    target_tr = torch.from_numpy(target_tr).long().cuda()
    data_te = torch.from_numpy(data_te).float().cuda()
    target_te = torch.from_numpy(target_te).long().cuda()

    network = SimpleNetwork().cuda()

    batch_size = 512

    if not test_only:
        optimizer = torch.optim.Adam(network.parameters(), 1e-3, amsgrad=True, weight_decay=1e-5)

        epochs = 30

        loss = torch.nn.CrossEntropyLoss()

        network.train()
        for epoch in range(epochs):
            print(epoch)
            optimizer.param_groups[0]['lr'] = poly_lr(epoch, epochs, 1e-3, 0.9)

            for _ in range(60000 // batch_size):
                optimizer.zero_grad()
                idxs = np.random.choice(60000, batch_size)
                data = data_tr[idxs]
                target = target_tr[idxs]

                out = network(data)

                l = loss(out, target)

                with amp_handle.scale_loss(l, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

        torch.save(network.state_dict(), args.s)

        with torch.no_grad():
            network.eval()
            out = network(data_te)

            _, amax = out.max(dim=1)
            acc = (amax == target_te).float().mean()
            print("accuracy on test: ", acc)
    else:
        if not isinstance(args.test_fnames, list):
            args.test_fnames = [args.test_fnames]

        for f in args.test_fnames:
            network.load_state_dict(torch.load(f, map_location=torch.device('cuda', torch.cuda.current_device())))

            with torch.no_grad():
                network.eval()
                out = network(data_te)

                _, amax = out.max(dim=1)
                acc = (amax == target_te).float().mean()
                print("file", f, "accuracy on test: ", acc)

I just hacked this together, so please ignore any potential ugliness in the code.

Here is how you can reproduce the problem: First, train the network several times and save to different output files:

python train_mnist.py --seed 1 -s mnist_seed1.model

accuracy on test: tensor(0.9959, device=‘cuda:0’)

python train_mnist.py --seed 2 -s mnist_seed2.model

accuracy on test: tensor(0.9955, device=‘cuda:0’)

python train_mnist.py --seed 3 -s mnist_seed3.model

accuracy on test: tensor(0.9949, device=‘cuda:0’)

Now that you have the trained models, you can run the testing by passing the filenames to the script like this: python train_mnist.py --test_only -test_fnames mnist_seed1.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device=‘cuda:0’)

python train_mnist.py --test_only -test_fnames mnist_seed2.model

file mnist_seed2.model accuracy on test: tensor(0.9955, device=‘cuda:0’)

python train_mnist.py --test_only -test_fnames mnist_seed3.model

file mnist_seed3.model accuracy on test: tensor(0.9949, device=‘cuda:0’)

The script also supports giving it several model checkpoints at once and it will test all of them one after the other. Although I am not ensembling here, this is the same procedure that I do in my ensembling code and the same issue appears here as well: python train_mnist.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device=‘cuda:0’) file mnist_seed2.model accuracy on test: tensor(0.1135, device=‘cuda:0’) file mnist_seed3.model accuracy on test: tensor(0.1029, device=‘cuda:0’)

If you look into the script (line 240+), it is doing nothing different than before, except loading new checkpoints with network.load_state_dict between test set predictions. We are seeing a big drop in performance from the second checkpoint onwards.

To demonstrate that this is not a problem with the files themselves, I ran it in a different order with the same result: python train_mnist.py --test_only -test_fnames mnist_seed3.model mnist_seed1.model mnist_seed2.model

file mnist_seed3.model accuracy on test: tensor(0.9949, device=‘cuda:0’) file mnist_seed1.model accuracy on test: tensor(0.1036, device=‘cuda:0’) file mnist_seed2.model accuracy on test: tensor(0.1010, device=‘cuda:0’)

I can fix this issue in this particular script by not initializing amp when I am running just the testing (replace amp_handle = amp.init() with

    if not test_only:
        amp_handle = amp.init()

). After replacing that, testing multiple checkpoints runs nicely:

python1 train_mnist.py --test_only -test_fnames mnist_seed1.model mnist_seed2.model mnist_seed3.model

file mnist_seed1.model accuracy on test: tensor(0.9959, device=‘cuda:0’) file mnist_seed2.model accuracy on test: tensor(0.9954, device=‘cuda:0’) file mnist_seed3.model accuracy on test: tensor(0.9949, device=‘cuda:0’)

I am not sure what is going on here, but I think this it would be rather important to understand what is going on. It took me a good 3 hours to finally figure out what was causing my severe performance regression today. Do you have any idea how this issue could be solved? I need to be able to load checkpoints during and after my trainings and rely on this to work 😃

Best, Fabian

About this issue

  • Original URL
  • State: open
  • Created 5 years ago
  • Reactions: 2
  • Comments: 41 (16 by maintainers)

Most upvoted comments

So, i load model, then amp state dick and still have problems with loss spikes, any ideas?

Is there any update yet? I’m running into the same problem and cannot figure out what is causing this. I tried to restore the model and continued to train it but it seemed that I was training from scratch not from checkpoint.

@ptrblck I’m observing the same issue (restarting training on O2 spikes loss, but I can load O2-trained checkpoint in O1 and continue training without problem), could you broadcast on this thread when you guys merge the PR?