pytorch-lightning: incorrect global_step with multiple optimizers and automatic_optimization=False
Bug description
Hello,
I encountered a bug when training with automatic_optimization = False and two optimizers.
In summary: the global_step attribute of the trainer and the lightning module is tracking the total number of calls to optimizer.step() (in my case, two per training_step), rather than the total number of iterations of the dataloader.
This conflicts with the notion of step in arguments like log_every_n_steps and val_check_interval in the trainer. Case in point, if we call
self.log("global_step", self.global_step)
inside training_step, with CSVLogger, log_every_n_steps=10, and two optimizer.step()s per training_step, the CSV logs show:
global_step,epoch,step
20.0,0,9
40.0,0,19
60.0,0,29
80.0,0,39
100.0,0,49
Note how global_step conflicts with step, and in fact is twice the expected value, since we have two optimizers.
I have attached a complete code example that replicates the issue.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import TensorDataset, IterableDataset, DataLoader

SEMVER = tuple(int(x) for x in pl.__version__.split("."))
assert SEMVER >= (2, 0, 3)


class LinearRegression(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(()))
        self.beta = torch.nn.Parameter(torch.zeros(()))
        self.automatic_optimization = False

    def forward(self, x):
        return self.gamma * x + self.beta

    def configure_optimizers(self):
        gamma_opt = torch.optim.SGD([self.gamma], lr=1e-2)
        beta_opt = torch.optim.SGD([self.beta], lr=1e-2)
        return gamma_opt, beta_opt

    def training_step(self, batch, batch_idx):
        # Two optimizers.
        gamma_opt, beta_opt = self.optimizers()
        # Forward pass with loss.
        inputs, targets = batch
        predictions = self(inputs)
        loss = torch.nn.functional.mse_loss(predictions, targets)
        # Backprop through entire graph but only update gamma.
        gamma_opt.zero_grad()
        self.manual_backward(loss, retain_graph=True)
        gamma_opt.step()
        # Backprop through partial graph and only update beta.
        beta_opt.zero_grad()
        self.manual_backward(loss, inputs=[self.beta])
        beta_opt.step()
        # Log the global step.
        self.log("global_step_train", self.global_step)

    def validation_step(self, batch, batch_idx):
        # Forward pass with loss.
        inputs, targets = batch
        predictions = self(inputs)
        loss = torch.nn.functional.mse_loss(predictions, targets)
        # Log the global step.
        self.log("global_step_val", self.global_step)


class IterableTensorDataset(IterableDataset):
    def __init__(self, inputs, targets):
        self.inputs, self.targets = inputs, targets

    def __iter__(self):
        while True:
            i = torch.randint(self.inputs.shape[0], size=())
            yield self.inputs[i], self.targets[i]


def load_dataset(gamma=2.0, beta=-1.0, sigma=0.2):
    inputs = torch.linspace(-1, 1, 201)
    targets = gamma * inputs + beta
    targets += sigma * torch.randn_like(targets)
    indices = torch.randperm(inputs.shape[0])
    pivot = inputs.shape[0] // 2
    train_inds, val_inds = indices[:pivot], indices[pivot:]
    return (
        (inputs[train_inds], targets[train_inds]),
        (inputs[val_inds], targets[val_inds]))


def main():
    train_data, val_data = load_dataset()
    train_set = IterableTensorDataset(*train_data)
    val_set = TensorDataset(*val_data)
    train_loader = DataLoader(train_set, batch_size=4)
    test_loader = DataLoader(val_set, batch_size=1)
    trainer = pl.Trainer(
        log_every_n_steps=10,
        val_check_interval=100,
        logger=CSVLogger("./logs"),
        enable_progress_bar=False,
        max_steps=1000,
    )
    model = LinearRegression()
    trainer.fit(model, train_loader, test_loader)
    return (
        model.gamma.data.detach().cpu().item(),
        model.beta.data.detach().cpu().item())


if __name__ == "__main__":
    print(main())
Error messages and logs
global_step_train,epoch,step,global_step_val
20.0,0,9,
40.0,0,19,
60.0,0,29,
80.0,0,39,
100.0,0,49,
120.0,0,59,
140.0,0,69,
160.0,0,79,
180.0,0,89,
200.0,0,99,
,0,99,200.0
220.0,0,109,
240.0,0,119,
260.0,0,129,
280.0,0,139,
300.0,0,149,
320.0,0,159,
340.0,0,169,
360.0,0,179,
380.0,0,189,
400.0,0,199,
,0,199,400.0
420.0,0,209,
440.0,0,219,
460.0,0,229,
480.0,0,239,
500.0,0,249,
520.0,0,259,
540.0,0,269,
560.0,0,279,
580.0,0,289,
600.0,0,299,
,0,299,600.0
620.0,0,309,
640.0,0,319,
660.0,0,329,
680.0,0,339,
700.0,0,349,
720.0,0,359,
740.0,0,369,
760.0,0,379,
780.0,0,389,
800.0,0,399,
,0,399,800.0
820.0,0,409,
840.0,0,419,
860.0,0,429,
880.0,0,439,
900.0,0,449,
920.0,0,459,
940.0,0,469,
960.0,0,479,
980.0,0,489,
1000.0,0,499,
,0,499,1000.0
Environment
Current environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning-utilities: 0.9.0
- pytorch-lightning: 2.0.4
- torch: 2.0.1
- torchmetrics: 0.11.4
- Packages:
- aiohttp: 3.8.4
- aiosignal: 1.3.1
- async-timeout: 4.0.2
- attrs: 23.1.0
- certifi: 2023.5.7
- charset-normalizer: 3.1.0
- filelock: 3.12.2
- frozenlist: 1.3.3
- fsspec: 2023.6.0
- idna: 3.4
- jinja2: 3.1.2
- lightning-utilities: 0.9.0
- markupsafe: 2.1.3
- mpmath: 1.3.0
- multidict: 6.0.4
- networkx: 3.1
- numpy: 1.25.0
- packaging: 23.1
- pip: 23.0.1
- pytorch-lightning: 2.0.4
- pyyaml: 6.0
- requests: 2.31.0
- setuptools: 67.6.0
- sympy: 1.12
- torch: 2.0.1
- torchmetrics: 0.11.4
- tqdm: 4.65.0
- typing-extensions: 4.7.0
- urllib3: 2.0.3
- wheel: 0.38.4
- yarl: 1.9.2
- System:
- OS: Darwin
- architecture:
- 64bit
- processor: i386
- python: 3.10.11
- release: 20.6.0
- version: Darwin Kernel Version 20.6.0: Thu Mar 9 20:39:26 PST 2023; root:xnu-7195.141.49.700.6~1/RELEASE_X86_64
More info
If this is the intended behavior, it should be reconciled with the trainer’s notion of step. Arguments like log_every_n_steps and val_check_interval use a different definition of step.
About this issue
- State: open
- Created a year ago
- Reactions: 8
- Comments: 19 (4 by maintainers)
This is a common issue for GANs, and we should take a look.
+1 met the same thing
I’ve encountered the same problem and solved it as described below. However, I’m not sure this method doesn’t introduce another problem. If someone finds a possible edge case in my logic, please comment below.
[Background]
- The reported global step comes from trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed.
- In trainer.fit with manual optimization, the actual training logic (your LightningModule.training_step implementation) is executed in trainer.fit_loop.epoch_loop.manual_optimization.run() (@jkyl mentioned the same thing above).
- Inside trainer.fit_loop.epoch_loop.manual_optimization.run(), three methods are called:
  - trainer.fit_loop.epoch_loop.manual_optimization.on_run_start
  - trainer.fit_loop.epoch_loop.manual_optimization.advance
  - trainer.fit_loop.epoch_loop.manual_optimization.on_run_end
- In trainer.fit_loop.epoch_loop.manual_optimization.on_run_start, every optimizer's _on_before_step and _on_after_step are overridden, so each optimizer's step() increases trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed by one; the self.optim_step_progress.increment_completed() method performs the increment.
[Solution]
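A minimal sketch of a workaround along these lines (not necessarily the commenter's original fix), using the LinearRegression module from the reproduction script above and assuming the private PL 2.0.x LightningOptimizer hooks _on_before_step / _on_after_step named in [Background]. Neutralizing those hooks on the second optimizer means only the first optimizer.step() of each iteration increments optim_step_progress.total.completed, and hence global_step:

def training_step(self, batch, batch_idx):
    gamma_opt, beta_opt = self.optimizers()
    # The manual-optimization loop re-injects its step-counting hooks in
    # on_run_start before every iteration, so they must be overwritten here,
    # inside training_step, rather than once during setup.
    beta_opt._on_before_step = lambda: None
    beta_opt._on_after_step = lambda: None
    inputs, targets = batch
    loss = torch.nn.functional.mse_loss(self(inputs), targets)
    gamma_opt.zero_grad()
    self.manual_backward(loss, retain_graph=True)
    gamma_opt.step()  # still increments global_step
    beta_opt.zero_grad()
    self.manual_backward(loss, inputs=[self.beta])
    beta_opt.step()  # no longer increments global_step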
[Suggestion for PyTorch Lightning]
- Extend the configure_optimizers interface to support options like this.

Thought I’d provide a little more detail on my use case, since other people have encountered this.
I’m training a GAN with multiple discriminator steps per generator step. My training step looks like this:

This is more efficient than only updating one of the networks each iteration, because it allows one to re-use the generator outputs for the discriminator update. But it also means that update_generator_and_discriminator makes two calls to optimizer.step.

As a workaround to this bug, I subclassed the trainer, like this:

And I also implemented the following method:

This lets my callbacks run at the correct frequency, but it is not a general solution. It only applies to the case where every n_critic number of steps, global_step is incremented by 2, and every other step it is incremented by 1.

(@ repo owners, feel free to assign the issue to me)
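The subclassed trainer and the extra method are not shown above, so the following is only a rough sketch of that style of workaround, not the commenter's actual code. It assumes N_CRITIC discriminator updates per generator update, with the double optimizer.step() happening on the last iteration of each cycle, and maps the raw optimizer-step count back to an iteration count:

import pytorch_lightning as pl

N_CRITIC = 5  # hypothetical: discriminator steps per generator step

class CorrectedGlobalStepTrainer(pl.Trainer):
    @property
    def global_step(self) -> int:
        raw = super().global_step  # counts optimizer.step() calls
        # Each cycle of N_CRITIC iterations performs N_CRITIC + 1 optimizer
        # steps (one extra for the generator update); invert that mapping.
        cycles, rem = divmod(raw, N_CRITIC + 1)
        return cycles * N_CRITIC + min(rem, N_CRITIC)

Callbacks that read trainer.global_step then see iteration counts, but code that reads the loop's internal counter directly is presumably unaffected, which is why this is not a general fix.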
When I set self.automatic_optimization = False, I got the same error. It is caused by optimizer.step() incrementing self.global_step by 1 each time it is called.
My observation is as follows:
When I use 2 optimizers, I get a global step 2 times larger than the actual step. When I use 3 optimizers, I get a global step 3 times larger than the actual step.
So, in this case, we need to figure out how to handle the global_step increase when optimizer.step() is called, for proper training.
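For example, the observation can be checked with something like the following (a hypothetical check placed at the end of training_step, valid during the first epoch of a manual-optimization module in which every optimizer steps once per iteration):

optimizers = self.optimizers()
n_opt = len(optimizers) if isinstance(optimizers, (list, tuple)) else 1
# global_step advances by n_opt for every completed training_step
assert self.global_step == n_opt * (batch_idx + 1)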
The source of this behavior starts with the fact that trainer.global_step refers to the global_step property of training_epoch_loop. In turn, that property derives its result from the optim_step_progress attribute of the _ManualOptimization loop object, whose total.completed attribute is incremented in _ManualOptimization._on_after_step.

Ultimately, _ManualOptimization._on_after_step is called via all of the LightningOptimizers created by the lightning module here. All optimizers are injected with the method here.

One possible fix would be to inject only one of the optimizers with the total.completed incrementing behavior, rather than all.
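As a rough illustration of that idea (not an official API; the loop attribute, strategy._lightning_optimizers, and the hook names are private PL 2.0.x internals and are assumptions here), one could wrap on_run_start on the loop instance and strip the hooks from every optimizer except the first:

def count_steps_on_first_optimizer_only(trainer):
    loop = trainer.fit_loop.epoch_loop.manual_optimization
    original_on_run_start = loop.on_run_start

    def patched_on_run_start():
        # Let Lightning wire its step-counting hooks into every optimizer,
        # then remove them from all but the first one.
        original_on_run_start()
        for opt in list(trainer.strategy._lightning_optimizers)[1:]:
            opt._on_before_step = lambda: None
            opt._on_after_step = lambda: None

    loop.on_run_start = patched_on_run_start

This would be called once on the trainer, before trainer.fit(...).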
Just ran into this as well; thanks @yzslab for the quick fix. Considering this is still a problem and this GitHub issue looks like it's going stale, I'll have a stab at getting a PR in.
@askerlee yes, the _on_before_step and _on_after_step functions get reassigned on each step, so you'll have to overwrite them in each step.

Separately, in the meantime, if anybody needs a quick fix for any number of optimizers, update to this:
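The snippet itself is not included above; a minimal sketch of that kind of quick fix, under the same private-hook assumption as earlier, is a small helper called at the top of training_step (it has to run every iteration, as noted above, because the hooks get reassigned each step):

def count_only_first_optimizer(optimizers):
    # Pass in self.optimizers(); strips the step-counting hooks from every
    # optimizer except the first, so global_step advances once per iteration.
    opts = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
    for opt in opts[1:]:
        opt._on_before_step = lambda: None
        opt._on_after_step = lambda: None
    return opts

Usage would be opts = count_only_first_optimizer(self.optimizers()) as the first line of training_step.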
Maybe consider changing the definition of global step to the number of times training_step is called. But this would be a breaking change… Adding a flag to enable it would be better.
Hello, I have just tried to use Lightning (2.0.6), and I observed that my global_step is also out of sync with the actual steps. This is also reflected on TensorBoard, and in my case it leaves the learning rate unchanged as training continues:

trainer = pl.Trainer(devices=2, accelerator='gpu', strategy='ddp', max_epochs=EPOCHS, logger=True, log_every_n_steps=50, check_val_every_n_epoch=1, callbacks=checkpoint_callback, accumulate_grad_batches=16)

These trainer settings have overwritten my NeMo configuration, which is the following:

trainer:
  devices: -1 # number of GPUs, -1 would use all available GPUs
  num_nodes: 1
  max_epochs: 1000
  max_steps: 200000 # computed at runtime if not set
  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: bf16 # 16, 32, or bf16
  log_every_n_steps: 100 # Interval of logging.
  enable_progress_bar: True
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training; restores the whole state including the epoch, step, LR schedulers, apex, etc.
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before starting the training; setting to 0 disables it
  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # Provided by exp_manager
  logger: false # Provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input as it slows down training

So far, my training progress looks like:

Epoch 254: 23%|██▎ | 201/883 [02:12<07:29, 1.52it/s, v_num=44, loss_step=49.80, loss_epoch=51.10]

From this, I think I have already run 253*883 steps. However, this is what my TensorBoard is displaying:

This is annoying, since Lightning changes the learning rate according to global_step, and the global steps are now miscalculated. Besides, training cannot stop normally: for instance, I set max_steps to 200000, the actual number of steps run is already over 223399, and training has not stopped as expected.
Why this matters:
- Users do not get the number of max_steps that they specify.
- global_step will be out of sync with the true current iteration.
- The discrepancy between global_step and the true current iteration will not be easily correctible, i.e. n_critic would need to be propagated into any callbacks or stop criteria.