pytorch-lightning: Lightning is very slow between epochs, compared to PyTorch.
I converted some PyTorch code to Lightning. The dataset is loaded lazily by the train & eval dataloaders.
However, when moving the code to Lightning, I noticed a huge slowdown. After digging around, I found that there is a ~10-second delay between each epoch. For comparison, in my vanilla PyTorch code an entire epoch takes ~4 s.
I first thought it was a data loading problem, but during the 10 s delay no data is loaded (at least, that’s what my prints tell me).
I think the issue is related to the number of workers, because setting num_workers=0 makes the delay go away (but is slower in the end, since loading without workers is not fast enough). I know that starting workers is slow; however, I have persistent_workers=True, and this does not happen in plain PyTorch. My data loaders also have pin_memory=True (removing pin_memory does not solve the problem).
Since this is company code, I cannot disclose the before/after, but I’ll try to “anonymize” some code if necessary. Here is the lightning module:
import torch
import torch.nn.functional as F
from torch import nn
import pytorch_lightning as pl


class RawModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder1 = nn.Sequential(...)
        self.encoder2 = nn.Sequential(...)

    def forward(self, data1, data2):
        result1 = self.encoder1(data1)
        result2 = self.encoder2(data2)
        result1 = result1.view(result1.size(0), -1)
        result2 = result2.view(result2.size(0), -1)
        result1 = F.normalize(result1, p=2, dim=1)
        result2 = F.normalize(result2, p=2, dim=1)
        return result1, result2

    def calculate_loss(self, batch):
        x, r, y = batch
        a, v = self.forward(r, x)
        d = nn.functional.cosine_similarity(a, v)
        # logloss is defined elsewhere in the (undisclosed) codebase
        loss = logloss(d.unsqueeze(1), y)
        return loss


class Module(RawModule):
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer
if __name__ == '__main__':
    # stuff...
    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=256, shuffle=True,
        num_workers=5, persistent_workers=True,
        pin_memory=True,
    )
    val_loader = data_utils.DataLoader(
        test_dataset, batch_size=256,
        num_workers=2, persistent_workers=True,
        pin_memory=True,
    )

    # Model
    load_from_pytorch = True
    if checkpoint_path is None:
        model = Module()
        if load_from_pytorch:
            if not checkpoint_path:
                raise ValueError("Please provide a checkpoint path")
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    else:
        model = Module.load_from_checkpoint(checkpoint_path)

    trainer = pl.Trainer(
        gpus=1,
        max_epochs=5,
        check_val_every_n_epoch=10,
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)
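The profiler output below was most likely produced by enabling Lightning's built-in simple profiler on the Trainer, along these lines:

# Assumption: the table that follows comes from re-running with the simple profiler enabled.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    check_val_every_n_epoch=10,
    log_every_n_steps=5,
    profiler="simple",
)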
Here is the result of profiler="simple":
Action                        | Mean duration (s) | Num calls | Total time (s) | Percentage %
------------------------------|-------------------|-----------|----------------|--------------
Total                         | -                 | -         | 48.813         | 100
run_training_epoch            | 27.922            | 1         | 27.922         | 57.202
fetch_next_sanity_check_batch | 4.4013            | 3         | 13.204         | 27.05
get_sanity_check_batch        | 4.4013            | 3         | 13.204         | 27.05
fetch_next_train_batch        | 1.2734            | 10        | 12.734         | 26.087
get_train_batch               | 1.2734            | 10        | 12.734         | 26.087
run_training_batch            | 0.47733           | 9         | 4.296          | 8.8009
optimizer_step_with_closure_0 | 0.40089           | 9         | 3.608          | 7.3915
validation_step               | 0.664             | 2         | 1.328          | 2.7206
evaluation_step_and_end       | 0.664             | 2         | 1.328          | 2.7206
training_step_and_backward    | 0.12644           | 9         | 1.138          | 2.3313
backward                      | 0.096889          | 9         | 0.872          | 1.7864
training_step                 | 0.029556          | 9         | 0.266          | 0.54494
model_forward                 | 0.029556          | 9         | 0.266          | 0.54494
on_train_start                | 0.016             | 1         | 0.016          | 0.032778
Here is the result of profiler="advanced": https://pastebin.com/q3C5P826.
Finally, here is a video demonstrating the problem. I’m printing each piece of data loading, to prove it’s not the issue. https://user-images.githubusercontent.com/30944236/140587623-ae184fa3-370a-42be-8593-200026d11ba4.mp4
Random information:
- OS: Windows 10
- CPU: AMD Ryzen 5 5600X 6 Core
- GPU: Nvidia RTX 3070
- PyTorch version: 1.10.0
- PyTorch Lightning version: 1.5.0
- CUDA version: 11.5
- How did I install PyTorch: conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
- Python: 3.8
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Comments: 49 (24 by maintainers)
TL;DR: I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. My code runs ~20x faster. It probably isn’t a correct way to fix things, but my trainings work as well as they did before.

After a bunch of fiddling around, I decided to create a custom DataLoader and overload the __iter__ method. I discovered that the _iterator property of the DataLoader was always set to None somewhere between epochs. When _iterator is None, the DataLoader is reset and needs to start everything from scratch. As you can see in the normal DataLoader, having self._iterator set to None causes a call to self._get_iterator(), which reloads everything.

I decided to override _iterator with a custom property (getter & setter) to print the stack trace when _iterator is set to None (I could also have used the debugger for this).
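A minimal sketch of that property override, assuming a custom DataLoader subclass (the class name TracingDataLoader and the backing attribute name are mine, not the original code):

import traceback

from torch.utils.data import DataLoader


class TracingDataLoader(DataLoader):
    # Prints a stack trace whenever anything resets the private _iterator
    # attribute back to None.

    @property
    def _iterator(self):
        return self.__dict__.get("_iterator_backing")

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            print("DataLoader._iterator was set to None from:")
            traceback.print_stack()
        self.__dict__["_iterator_backing"] = value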
This leads to 2 different yet very similar stack traces (respectively, for the evaluation and training loaders):

Well… we’re nearly there. It looks like advancing one epoch calls self.reset() on the DataFetcher itself, which then resets the DataLoader and leads to our problem. Indeed, when checking AbstractDataFetcher, we have this:

So… I guess we found the problem? Each time a new epoch runs, it calls self.advance on the FitLoop/EvalLoop, which then calls self.on_run_start. on_run_start creates the _dataloader_iter (which is normal behavior), which itself calls _update_dataloader_iter. This function, through enumerate, calls AbstractDataFetcher.__iter__, which calls self.reset(), entirely reloading the DataLoader. Notice that self.reset() is also called in the __init__ method of AbstractDataFetcher, for setup purposes.

I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. While it speeds up the code a lot, and the entire training seems to work, it probably breaks a bunch of things? I’d wait until the Lightning team fixes the problem before trying anything serious.

What a ride.
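For anyone who prefers not to edit the installed package, here is a rough sketch of an equivalent user-side monkeypatch. It mirrors the "comment out self.reset()" hack described above (and is just as likely to break other Lightning behaviour), assuming the AbstractDataFetcher class path mentioned in this thread:

from pytorch_lightning.utilities.fetching import AbstractDataFetcher

_original_iter = AbstractDataFetcher.__iter__

def _iter_without_reset(self):
    # Temporarily turn reset() into a no-op while the original __iter__ runs,
    # so the underlying DataLoader iterator is not torn down at every epoch.
    original_reset = self.reset
    self.reset = lambda: None
    try:
        return _original_iter(self)
    finally:
        self.reset = original_reset

AbstractDataFetcher.__iter__ = _iter_without_reset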
Hi @isvogor-foi, I had the same issue as reported here, and after updating pytorch-lightning I didn’t see any improvement either. After reading the OP’s blog post (https://medium.com/@florian-ernst/finding-why-pytorch-lightning-made-my-training-4x-slower-ae64a4720bd1), I noticed that I was missing the persistent_workers=True flag on my DataLoader:
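Roughly like this (placeholder dataset and illustrative values, not the exact loader from my project):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))  # placeholder data

train_loader = DataLoader(
    dataset,
    batch_size=64,            # illustrative value
    shuffle=True,
    num_workers=4,            # illustrative value
    persistent_workers=True,  # keeps worker processes alive between epochs
    pin_memory=True,
)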
Hopefully this will help you! Performance was much improved for me.
Is this still fixed on 2.0.4? I’m still seeing this behavior. This code runs lightning fast (albeit with warnings about not having any workers):
Adding the workers makes the warnings go away, but training freezes for ~15 seconds before validation or a new epoch:
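Roughly, the difference amounts to something like the following (placeholder dataset and illustrative values, not the exact code):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))  # placeholder data

# Runs without inter-epoch freezes, but Lightning warns about having no workers:
loader_no_workers = DataLoader(dataset, batch_size=64, num_workers=0)

# Makes the warning go away, but stalls ~15 s before validation or a new epoch
# while the worker processes are restarted:
loader_with_workers = DataLoader(dataset, batch_size=64, num_workers=8)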
This still seems to be the case, especially with the dataloader startup time.
Hi,
This issue still persists in 2.1.3. I’m directly passing the dataloaders to the Trainer:

train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    collate_fn=collate_fn_train,
    drop_last=True,
    shuffle=True,
    pin_memory=True,
    num_workers=8,
    prefetch_factor=8,
    persistent_workers=True,
)

but the training freezes for a few seconds after each epoch. When I ran the advanced profiler in run_training_epoch, I saw that reset is called every epoch (dataloader.py:1086(_reset) and training_epoch_loop.py:143(reset)). The training is extremely slow if I set num_workers=0. Is there a fix/workaround for this?
Tested the new 1.5.1 release today, looks like performance is back on track. Thanks to everyone!
After a lot of digging around, I managed to pin down the line causing the problem.
It’s line 142 in loops/epoch/training_epoch_loop.py:

Therefore, the culprit is:

The on_run_start hook is called from loops/base.py:

And this run method is called from loops/fit_loop.py:

The problem is the data_fetcher. It’s indeed related to the dataloaders, as expected. I still don’t know the root cause, but I’ll try to find it.

I had a similar observation where the data_fetcher caused unusually long run times. For me, it was indeed fixed by completely disabling multiprocess data loading (num_workers=0), although I have not tried setting reload_dataloaders_every_epoch=False. Interesting to see that others have the same issue/observation. Funnily enough, setting num_workers=0 led me to open #10182. Perhaps there is something more to this?
Hi everyone… This topic is very interesting, as I’m hitting the same issue. I’m comparing the same implementation in Torch and Lightning. When I came across this post, I noticed I was using the ancient Lightning 1.4.7, so I updated to 1.5.9. I repeated my test and nothing changed… Torch was still significantly faster than Lightning. So, as @TheMrZZ suggested, I commented out the reset in the __iter__ function and repeated the test. Sure enough, Lightning was… lightning fast now! I’m loading images, so I got the following:

Apparently, the performance issue was fixed in 1.5.1; however, it seems that with 1.5.9 the reset line is still there. So I’m curious: why do we need to reset the data fetcher after each epoch?
Dear @TheMrZZ,
Thanks for your investigation, and I’m happy we solved this ugly bug.
Best, T.C
Is there any update on this? It seems like I’m not the only one facing this issue. I started using PTL with 2.x and would prefer not to downgrade to 1.5.1 to make this issue go away. Is there a 2.x version that doesn’t have this issue? @awaelchli @Borda @carmocca
@rohitgr7 I think I’ve found the issue. Previously, I was passing the datasets into the LightningDataModule, running save_hyperparameters, and then accessing them via self.hparams["train_dataset"], self.hparams["validation_dataset"], etc. It seems the entire dataset was being cached in this case. The fix is to explicitly make these datasets instance variables and then run self.save_hyperparameters(ignore=["train_dataset", "validation_dataset", "test_dataset"]).
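A hedged sketch of that pattern (the class name, dataset arguments, and batch size are illustrative, not from the original project):

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class IllustrativeDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, validation_dataset, test_dataset, batch_size=64):
        super().__init__()
        # Keep the datasets as plain instance attributes...
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.test_dataset = test_dataset
        # ...and keep them out of the saved hyperparameters so they are not
        # duplicated/cached through self.hparams.
        self.save_hyperparameters(ignore=["train_dataset", "validation_dataset", "test_dataset"])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True)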
I wouldn’t know how to make this reproducible either, but I have pinned PL to 1.6.5 in my project because I observe the same behavior as @lminer on 1.7.* with very slow epochs. I also see very large memory usage, although I am using a single GPU and am not calling save_hyperparameters at all. Could there be another underlying factor?
Hm… very interesting. I did the same with the mnist example that comes with the Lightning examples, and noticed that Torch was only slightly better in my case, a second or so. However, I’ve just dismantled the advance call in training_epoch_loop.py. Rather than explaining the details, I’ve added some arrows to indicate the execution timeline.

So the training loop is this:
and before it, this part takes a lot of time:
and after it there is another long call:
Therefore excessive logging can slow it down.
So, in my particular case, I was using Lightning 1.5.9 with one V100, and the aforementioned hook calls seem to add up over time: in my long experiment with 100 epochs, a batch size of 256, and 35k images, Torch performs better. This was on just 1 GPU; I didn’t try multiple GPUs. I think we can leave it at that.
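If logging overhead is a contributor, one possible mitigation (standard Lightning options, illustrative values) is simply to log less often:

import pytorch_lightning as pl

# Inside a LightningModule.training_step, aggregate instead of logging every step:
#     self.log("train_loss", loss, on_step=False, on_epoch=True)

# And/or raise the logging interval on the Trainer (the default is 50 steps):
trainer = pl.Trainer(log_every_n_steps=200)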
@carmocca thanks!
Here are some simple time results:
This was using 1 NVIDIA GeForce RTX 3090 and PyTorch Lightning 1.6.0dev, commit 8394770d4afa5480f881229b150ac44eaa8c41b0, torch==1.10.1, torchmetrics==0.7.0, torchvision==0.11.2, Python 3.8.12.

PyTorch Lightning
I used https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py verbatim, called with
PyTorch
I used https://github.com/pytorch/examples/blob/master/imagenet/main.py with the following changes applied to disable validation, and stop at 100 batches:
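The gist of the second change is simply to break out of the training loop early; a self-contained stand-in (not the actual diff) looks roughly like this:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in loader; the real run used the ImageNet training set.
loader = DataLoader(TensorDataset(torch.randn(51200, 8)), batch_size=256)

for i, (batch,) in enumerate(loader):
    if i >= 100:  # stop the epoch after 100 batches
        break
    # forward / backward / optimizer step would go here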
So basically the same speed.
@carmocca Sure, vanilla, meaning it’s the one taken from the official implementations:
Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py Torch: https://github.com/pytorch/examples/blob/master/imagenet/main.py
Try running those on the same device, with the same image dataset, batch size, etc. I’ve tried it on my machine, an AWS instance, and even Colab, and Torch is always much better.
@carmocca Hi, well, I didn’t try 1.5.1; I tried only 1.4.7 and 1.5.9. I’ll see whether I can try 1.5.1 and get back to you on this!