pytorch-lightning: incorrect global_step with multiple optimizers and automatic_optimization=False

Bug description

Hello,

I encountered a bug when training with automatic_optimization = False and two optimizers.

In summary: the global_step attribute of the trainer and the lightning module is tracking the total number of calls to optimizer.step() (in my case, two per training_step), rather than the total number of iterations of the dataloader.

This conflicts with the notion of step in arguments like log_every_n_steps and val_check_interval in the trainer. Case in point, if we call

self.log("global_step", self.global_step)

inside training_step, with CSVLogger, log_every_n_steps=10, and two optimizer.step()s per training_step, the CSV logs show:

global_step,epoch,step
20.0,0,9
40.0,0,19
60.0,0,29
80.0,0,39
100.0,0,49

Note how global_step conflicts with step, and in fact is twice the expected value, since we have two optimizers.
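The relationship in the CSV above can be checked with a short sketch. The helper name expected_global_step is hypothetical, and the rule it encodes (one increment per optimizer.step() call) is inferred from the logged values rather than taken from Lightning's source.

```python
def expected_global_step(logged_step: int, optimizer_steps_per_batch: int) -> int:
    """Value of global_step if it counts optimizer.step() calls.

    `logged_step` is the zero-based `step` column written by the logger,
    so `logged_step + 1` training batches have completed.
    """
    return (logged_step + 1) * optimizer_steps_per_batch


# (global_step, step) pairs copied from the CSV output above.
rows = [(20, 9), (40, 19), (60, 29), (80, 39), (100, 49)]

for global_step, step in rows:
    assert expected_global_step(step, optimizer_steps_per_batch=2) == global_step
```

Every row satisfies global_step == 2 * (step + 1), which is what one would expect if each of the two optimizer.step() calls increments the counter.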

I have attached a complete code example that replicates the issue.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import pytorch_lightning as pl
import torch

from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import TensorDataset, IterableDataset, DataLoader


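# Compare only the numeric release components so that suffixes such as
# "rc0" or "post0" do not break the int() conversion.
SEMVER = tuple(int(x) for x in pl.__version__.split(".")[:3] if x.isdigit())
assert SEMVER >= (2, 0, 3)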


class LinearRegression(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(()))
        self.beta = torch.nn.Parameter(torch.zeros(()))
        self.automatic_optimization = False

    def forward(self, x):
        return self.gamma * x + self.beta

    def configure_optimizers(self):
        gamma_opt = torch.optim.SGD([self.gamma], lr=1e-2)
        beta_opt = torch.optim.SGD([self.beta], lr=1e-2)
        return gamma_opt, beta_opt

    def training_step(self, batch, batch_idx):

        # Two optimizers.
        gamma_opt, beta_opt = self.optimizers()

        # Forward pass with loss.
        inputs, targets = batch
        predictions = self(inputs)
        loss = torch.nn.functional.mse_loss(predictions, targets)

        # Backprop through entire graph but only update gamma.
        gamma_opt.zero_grad()
        self.manual_backward(loss, retain_graph=True)
        gamma_opt.step()

        # Backprop through partial graph and only update beta.
        beta_opt.zero_grad()
        self.manual_backward(loss, inputs=[self.beta])
        beta_opt.step()

        # Log the global step.
        self.log("global_step_train", self.global_step)

    def validation_step(self, batch, batch_idx):

        # Forward pass with loss.
        inputs, targets = batch
        predictions = self(inputs)
        loss = torch.nn.functional.mse_loss(predictions, targets)

        # Log the global step.
        self.log("global_step_val", self.global_step)


class IterableTensorDataset(IterableDataset):

    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __iter__(self):
        while True:
            i = torch.randint(self.inputs.shape[0], size=())
            yield self.inputs[i], self.targets[i]


def load_dataset(gamma=2.0, beta=-1.0, sigma=0.2):
    inputs = torch.linspace(-1, 1, 201)
    targets = gamma * inputs + beta
    targets += sigma * torch.randn_like(targets)
    indices = torch.randperm(inputs.shape[0])
    pivot = inputs.shape[0] // 2
    train_inds, val_inds = indices[:pivot], indices[pivot:]
    return (
        (inputs[train_inds], targets[train_inds]),
        (inputs[val_inds], targets[val_inds]))


def main():
    train_data, val_data = load_dataset()
    train_set = IterableTensorDataset(*train_data)
    val_set = TensorDataset(*val_data)
    train_loader = DataLoader(train_set, batch_size=4)
    test_loader = DataLoader(val_set, batch_size=1)
    trainer = pl.Trainer(
        log_every_n_steps=10,
        val_check_interval=100,
        logger=CSVLogger("./logs"),
        enable_progress_bar=False,
        max_steps=1000,
    )
    model = LinearRegression()
    trainer.fit(model, train_loader, test_loader)
    return (
        model.gamma.data.detach().cpu().item(),
        model.beta.data.detach().cpu().item())


if __name__ == "__main__":
    print(main())

Error messages and logs

global_step_train,epoch,step,global_step_val
20.0,0,9,
40.0,0,19,
60.0,0,29,
80.0,0,39,
100.0,0,49,
120.0,0,59,
140.0,0,69,
160.0,0,79,
180.0,0,89,
200.0,0,99,
,0,99,200.0
220.0,0,109,
240.0,0,119,
260.0,0,129,
280.0,0,139,
300.0,0,149,
320.0,0,159,
340.0,0,169,
360.0,0,179,
380.0,0,189,
400.0,0,199,
,0,199,400.0
420.0,0,209,
440.0,0,219,
460.0,0,229,
480.0,0,239,
500.0,0,249,
520.0,0,259,
540.0,0,269,
560.0,0,279,
580.0,0,289,
600.0,0,299,
,0,299,600.0
620.0,0,309,
640.0,0,319,
660.0,0,329,
680.0,0,339,
700.0,0,349,
720.0,0,359,
740.0,0,369,
760.0,0,379,
780.0,0,389,
800.0,0,399,
,0,399,800.0
820.0,0,409,
840.0,0,419,
860.0,0,429,
880.0,0,439,
900.0,0,449,
920.0,0,459,
940.0,0,469,
960.0,0,479,
980.0,0,489,
1000.0,0,499,
,0,499,1000.0

Environment

Current environment
  • CUDA:
    • GPU: None
    • available: False
    • version: None
  • Lightning:
    • lightning-utilities: 0.9.0
    • pytorch-lightning: 2.0.4
    • torch: 2.0.1
    • torchmetrics: 0.11.4
  • Packages:
    • aiohttp: 3.8.4
    • aiosignal: 1.3.1
    • async-timeout: 4.0.2
    • attrs: 23.1.0
    • certifi: 2023.5.7
    • charset-normalizer: 3.1.0
    • filelock: 3.12.2
    • frozenlist: 1.3.3
    • fsspec: 2023.6.0
    • idna: 3.4
    • jinja2: 3.1.2
    • lightning-utilities: 0.9.0
    • markupsafe: 2.1.3
    • mpmath: 1.3.0
    • multidict: 6.0.4
    • networkx: 3.1
    • numpy: 1.25.0
    • packaging: 23.1
    • pip: 23.0.1
    • pytorch-lightning: 2.0.4
    • pyyaml: 6.0
    • requests: 2.31.0
    • setuptools: 67.6.0
    • sympy: 1.12
    • torch: 2.0.1
    • torchmetrics: 0.11.4
    • tqdm: 4.65.0
    • typing-extensions: 4.7.0
    • urllib3: 2.0.3
    • wheel: 0.38.4
    • yarl: 1.9.2
  • System:
    • OS: Darwin
    • architecture:
      • 64bit
    • processor: i386
    • python: 3.10.11
    • release: 20.6.0
    • version: Darwin Kernel Version 20.6.0: Thu Mar 9 20:39:26 PST 2023; root:xnu-7195.141.49.700.6~1/RELEASE_X86_64

More info

If this is the intended behavior, it should be reconciled with the trainer’s notion of step. Arguments like log_every_n_steps and val_check_interval use a different definition of step.

About this issue

  • State: open
  • Created a year ago
  • Reactions: 8
  • Comments: 19 (4 by maintainers)

Most upvoted comments

This is a common issue for GAN and we should take a look.

+1 met the same thing

I’ve encountered the same problem and worked around it as described below. However, I’m not sure that this method does not introduce other problems. If anyone finds an edge case in my logic, please comment below.

[Background]

  • The trainer’s global_step is an alias of trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed.
  • When you call trainer.fit with manual optimization, the actual training logic (your LightningModule.training_step implementation) is executed in trainer.fit_loop.epoch_loop.manual_optimization.run(). (@jkyl mentioned the same thing in this thread.)
  • trainer.fit_loop.epoch_loop.manual_optimization.run() calls three methods:
    • trainer.fit_loop.epoch_loop.manual_optimization.on_run_start
    • trainer.fit_loop.epoch_loop.manual_optimization.advance
    • trainer.fit_loop.epoch_loop.manual_optimization.on_run_end
  • trainer.fit_loop.epoch_loop.manual_optimization.on_run_start overrides every optimizer’s _on_before_step and _on_after_step, so that each optimizer step increases trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed by one.
  • The code below is the manual optimization class. Its self.optim_step_progress.increment_completed() call is what increases trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed.
class _ManualOptimization(_Loop):
    """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
    entirely in the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` and therefore the user is
    responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~lightning.pytorch.core.module.LightningModule.training_step`) and passing through the output(s).

    """

    output_result_cls = ManualResult

    def __init__(self, trainer: "pl.Trainer") -> None:
        super().__init__(trainer)
        # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
        # `_OptimizationProgress`
        self.optim_step_progress = _Progress.from_defaults(_ReadyCompletedTracker)

        self._output: _OUTPUTS_TYPE = {}

    def run(self, kwargs: OrderedDict) -> _OUTPUTS_TYPE:
        self.on_run_start()
        with suppress(StopIteration):  # no loop to break at this level
            self.advance(kwargs)
        self._restarting = False
        return self.on_run_end()

    def on_run_start(self) -> None:
        # inject logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
            lightning_optimizer._on_before_step = self._on_before_step
            lightning_optimizer._on_after_step = self._on_after_step

    def advance(self, kwargs: OrderedDict) -> None:
        """Performs the training step for manual optimization.

        Args:
            kwargs: The kwargs passed down to the hooks.

        """
        trainer = self.trainer

        # manually capture logged metrics
        training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
        del kwargs  # release the batch from memory
        self.trainer.strategy.post_training_step()
        result = self.output_result_cls.from_training_step_output(training_step_output)

        self._output = result.asdict()

    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
        output, self._output = self._output, {}  # free memory
        # reset logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers:
            lightning_optimizer._on_before_step = do_nothing_closure
            lightning_optimizer._on_after_step = do_nothing_closure
        return output

    def _on_before_step(self) -> None:
        self.optim_step_progress.increment_ready()
        self.trainer.profiler.start("optimizer_step")

    def _on_after_step(self) -> None:
        self.trainer.profiler.stop("optimizer_step")
        self.optim_step_progress.increment_completed()
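
The hook injection in the class above can be mimicked without Lightning. The following is a minimal sketch, not Lightning's code: StepCounter, FakeOptimizer and run_training_steps are hypothetical stand-ins that only reproduce the counting behavior of on_run_start, _on_after_step and on_run_end.

```python
class StepCounter:
    """Stand-in for the `optim_step_progress.total.completed` counter."""

    def __init__(self) -> None:
        self.completed = 0


class FakeOptimizer:
    """Stand-in for a LightningOptimizer exposing an `_on_after_step` hook."""

    def __init__(self) -> None:
        self._on_after_step = lambda: None

    def step(self) -> None:
        # A real optimizer would update parameters here.
        self._on_after_step()


def run_training_steps(optimizers, counter, num_batches):
    """Mimic run(): inject hooks, call every optimizer once, reset hooks."""

    def increment():
        counter.completed += 1

    for _ in range(num_batches):
        for opt in optimizers:  # on_run_start: inject the counting hook
            opt._on_after_step = increment
        for opt in optimizers:  # training_step: one step() per optimizer
            opt.step()
        for opt in optimizers:  # on_run_end: reset the hook
            opt._on_after_step = lambda: None


counter = StepCounter()
run_training_steps([FakeOptimizer(), FakeOptimizer()], counter, num_batches=10)
print(counter.completed)  # 20: two increments per batch
```

Because every optimizer receives the same counting hook, the counter advances once per step() call rather than once per batch.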

[Solution]

  • Since the manual optimization loop overrides the optimizers’ hooks before training_step is called, we can override them again at the top of training_step.
  • Example:
...
    def training_step(self, batch, batch_idx):
        gamma_opt, beta_opt = self.optimizers()
        beta_opt._on_before_step = lambda : self.trainer.profiler.start("optimizer_step")
        beta_opt._on_after_step = lambda : self.trainer.profiler.stop("optimizer_step")
        ...

[Suggestion for PyTorch Lightning]

  • If this method seems safe, we could contribute a PR in one of two ways.
    1. Add this method as a guide in the PyTorch Lightning documentation (somewhere like PYTORCH LIGHTNING BASIC GAN TUTORIAL)
    2. Or, we can make the LightningModule’s configure_optimizers interface support an option like the following
...
def configure_optimizers(self):
    opt1 = Adam(...)
    opt2 = Adam(...)
    return (
        {"optimizer": opt1},
        {"optimizer": opt2, "do_not_count_global_step": True},
    )
    

Thought I’d provide a little more detail on my use case since other people have encountered this.

I’m training a GAN with multiple discriminator steps per generator step. My training step looks like this:

def training_step(self, batch, batch_idx):
    if batch_idx % self.n_critic == 0:
        self.update_generator_and_discriminator(batch)
    else:
        self.update_discriminator_only(batch)

This is more efficient than only updating one of the networks each iteration, because it allows one to re-use the generator outputs for the discriminator update. But, it also means that update_generator_and_discriminator makes two calls to optimizer.step.

As a workaround to this bug, I subclassed the trainer, like this:

class MyTrainer(pl.Trainer):

    def __init__(self, *, n_critic: int, **kwargs):
        super().__init__(**kwargs)
        self.n_critic = n_critic

    @property
    def global_step(self) -> int:
        return convert_global_step_to_current_iter(super().global_step, self.n_critic)

And I also implemented the following method:

def convert_global_step_to_current_iter(step: int, nc: int) -> int:
    return int(step * nc / (nc + 1))

This lets my callbacks run at the correct frequency, but is not a general solution. This only applies to the case where every n_critic number of steps, global_step is incremented by 2, and every other step it’s incremented by 1.
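
The conversion can be sanity-checked against a small simulation of the schedule described above. This is a sketch under the stated assumption that global_step grows by 2 when batch_idx % n_critic == 0 and by 1 otherwise; simulate_global_step is a hypothetical helper, and the conversion is restated here with integer arithmetic.

```python
def convert_global_step_to_current_iter(step: int, nc: int) -> int:
    # Same formula as above, written with integer arithmetic.
    return step * nc // (nc + 1)


def simulate_global_step(num_iters: int, nc: int) -> int:
    """Count optimizer.step() calls after `num_iters` training steps."""
    global_step = 0
    for batch_idx in range(num_iters):
        # Generator + discriminator update: two optimizer steps.
        # Discriminator-only update: one optimizer step.
        global_step += 2 if batch_idx % nc == 0 else 1
    return global_step


for nc in range(1, 8):
    for num_iters in range(0, 100):
        step = simulate_global_step(num_iters, nc)
        assert convert_global_step_to_current_iter(step, nc) == num_iters
```

Each block of n_critic iterations adds n_critic + 1 to global_step, which is where the factor nc / (nc + 1) comes from.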

(@ repo owners feel free to assign the issue to me)

When I set self.automatic_optimization = False, I saw the same behavior. It is caused by optimizer.step(), which increases self.global_step by 1 on each call.

My observations are as follows.

With 2 optimizers, global_step is 2 times larger than the actual step. With 3 optimizers, it is 3 times larger.

So, in this case, we need to figure out how global_step should be incremented when optimizer.step() is called.

The source of this behavior starts with the fact that trainer.global_step refers to the global_step property of training_epoch_loop.

In turn, that property derives its result from the optim_step_progress attribute of the _ManualOptimization loop object, whose total.completed attribute is incremented in _ManualOptimization._on_after_step.

Ultimately, _ManualOptimization._on_after_step is called via all of the LightningOptimizers created by the lightning module. All optimizers are injected with the method in _ManualOptimization.on_run_start.

One possible fix would be to inject only one of the optimizers with the total.completed incrementing behavior, rather than all.
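
A rough illustration of that idea, using stand-in objects rather than Lightning's classes: FakeOptimizer and count_global_steps are hypothetical names, and the choice of the last optimizer as the counted one is an arbitrary assumption.

```python
class FakeOptimizer:
    """Stand-in for a LightningOptimizer; not Lightning's real class."""

    def __init__(self) -> None:
        self._on_after_step = lambda: None

    def step(self) -> None:
        self._on_after_step()


def count_global_steps(num_optimizers: int, num_batches: int, inject_all: bool) -> int:
    completed = 0

    def increment() -> None:
        nonlocal completed
        completed += 1

    optimizers = [FakeOptimizer() for _ in range(num_optimizers)]
    # Inject the counting hook into every optimizer, or only into the last one.
    counted = optimizers if inject_all else optimizers[-1:]
    for opt in counted:
        opt._on_after_step = increment

    for _ in range(num_batches):
        for opt in optimizers:  # one step() per optimizer per training step
            opt.step()
    return completed


print(count_global_steps(2, 10, inject_all=True))   # 20
print(count_global_steps(2, 10, inject_all=False))  # 10
```

This only matches the iteration count when the counted optimizer steps exactly once per training step. If some optimizers are skipped on some iterations, as in GAN training with n_critic, the counted optimizer would have to be one that steps on every iteration.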

Just ran into this as well, thanks @yzslab for the quick fix. Considering this is still a problem and this GitHub issue looks like it is going stale, I’ll have a stab at getting a PR in.

@askerlee yes, the _on_before_step and _on_after_step functions get reassigned for each step, so you’ll have to overwrite them in each step

Separately, in the meantime if anybody needs a quick fix for any number of optimizers, update to this:

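optimizers = self.optimizers()
for i, opt in enumerate(optimizers):
    opt.zero_grad()
    # Replace the step-counting hooks on every optimizer except the last one,
    # so that only the last optimizer's step() increments global_step.
    if i + 1 < len(optimizers):
        opt._on_before_step = lambda: self.trainer.profiler.start("optimizer_step")
        opt._on_after_step = lambda: self.trainer.profiler.stop("optimizer_step")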

Maybe consider changing the definition of global step to the number of times training_step is called. But this would be a breaking change… Adding a flag to opt into it would be better.

Hello, I have just tried Lightning (2.0.6) and observed that my global_step is also out of sync with the actual number of steps. This is also reflected in TensorBoard, and in my case it leaves the learning rate unchanged as training progresses:

trainer = pl.Trainer(
    devices=2,
    accelerator='gpu',
    strategy='ddp',
    max_epochs=EPOCHS,
    logger=True,
    log_every_n_steps=50,
    check_val_every_n_epoch=1,
    callbacks=checkpoint_callback,
    accumulate_grad_batches=16,
)

These trainer settings override my NeMo configuration, which follows:

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: 200000 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: bf16 # 16, 32, or bf16
  log_every_n_steps: 100 # Interval of logging.
  enable_progress_bar: True
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training

So far, my training progress looks like this:

Epoch 254: 23%|██▎ | 201/883 [02:12<07:29, 1.52it/s, v_num=44, loss_step=49.80, loss_epoch=51.10]

From this, I think I have already run 253*883 steps. However, my TensorBoard shows something different:

[image]

It tells me that only 14k global steps have been run, which is obviously wrong.

This is annoying, since Lightning changes the learning rate according to global_step, and global_step is now miscalculated. Besides, training does not stop as expected. For instance, I set max_steps to 200000, the actual number of steps run is already over 223399, and training has not stopped.

[image]
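
One possible explanation for the 14k figure, not confirmed anywhere in this thread, is gradient accumulation rather than multiple optimizers. The arithmetic below is a sketch that assumes the accumulate_grad_batches=16 passed to pl.Trainer is the value in effect, that global_step counts optimizer steps, and that a final partial group of batches at the end of each epoch also triggers one optimizer step.

```python
import math

batches_per_epoch = 883        # from the progress bar
completed_epochs = 253         # epochs the commenter counts as finished
accumulate_grad_batches = 16   # value passed to pl.Trainer above

# Training batches seen so far (the count the commenter has in mind).
batches_seen = completed_epochs * batches_per_epoch

# Optimizer steps per epoch, assuming one step per 16 batches plus one
# step for the final partial group at the end of each epoch.
optimizer_steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
optimizer_steps = completed_epochs * optimizer_steps_per_epoch

print(batches_seen)     # 223399
print(optimizer_steps)  # 14168
```

Under these assumptions the optimizer-step count lands at roughly 14k, which would also explain why max_steps=200000 has not been reached.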

Why this matters:

  • The training process will terminate earlier than the user expects, with respect to the max_steps that they specify.
  • User-specified callbacks that refer to global_step will be out-of-sync with the true current iteration.
  • In the case of variable numbers of optimizer steps per training step (as is the case for GAN training with variable n_critic), the discrepancy between global_step and the true current iteration will not be easily correctible, i.e. n_critic would need to be propagated into any callbacks or stop criteria.
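
The first and third points can be made concrete with a small simulation. This is a sketch, assuming training stops once the optimizer-step count reaches max_steps; iterations_until_stop is a hypothetical helper, not a Lightning function.

```python
from typing import Callable


def iterations_until_stop(max_steps: int, steps_in_iter: Callable[[int], int]) -> int:
    """Training iterations run before the optimizer-step count reaches max_steps."""
    global_step = 0
    iteration = 0
    while global_step < max_steps:
        global_step += steps_in_iter(iteration)
        iteration += 1
    return iteration


# Two optimizer steps in every training step, as in the reproduction script.
print(iterations_until_stop(1000, lambda i: 2))  # 500

# GAN schedule with n_critic = 5: two steps every fifth iteration, else one.
n_critic = 5
print(iterations_until_stop(1000, lambda i: 2 if i % n_critic == 0 else 1))  # 833
```

The first result agrees with the reproduction logs, where training with max_steps=1000 ends at step 499. In the second case the stopping iteration depends on n_critic, so it cannot be recovered from max_steps alone.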