pytorch-lightning: Lightning is very slow between epochs, compared to PyTorch.

I converted some Pytorch code to Lightning. The dataset is loaded lazily by the train & eval dataloaders.

However, when moving the code to Lightning, I noticed a huge slowdown. After digging around, I noticed there was a ~10 second delay between epochs. For comparison, with my vanilla PyTorch code an epoch takes ~4s.

I first thought it was a data loading problem, but during the 10s delay no data is loaded (at least that’s what my prints tell me).

I think the issue is related to the number of workers, because setting num_workers=0 removes the delay (but training ends up slower overall, since a single process can’t load data fast enough). I know starting workers is slow, but I have persistent_workers=True, and this does not happen in plain PyTorch. My data loaders also have pin_memory=True (removing pin_memory does not solve the problem).
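
For reference, the delay is easy to measure with a small callback like this (just a sketch to illustrate what I timed, not part of my actual code; it only prints the gap between one epoch ending and the next one starting):

import time
import pytorch_lightning as pl

class EpochGapTimer(pl.Callback):
    """Prints the wall-clock gap between the end of one training epoch and the start of the next."""
    def __init__(self):
        self._last_epoch_end = None

    def on_train_epoch_end(self, trainer, pl_module):
        self._last_epoch_end = time.monotonic()

    def on_train_epoch_start(self, trainer, pl_module):
        if self._last_epoch_end is not None:
            print(f"Gap since previous epoch ended: {time.monotonic() - self._last_epoch_end:.1f}s")

# passed to the trainer as: pl.Trainer(..., callbacks=[EpochGapTimer()])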

Since this is company code, I cannot disclose the before/after, but I’ll try to “anonymize” some code if necessary. Here is the lightning module:

class RawModule(pl.LightningModule):
    def __init__(self):
        super(RawModule, self).__init__()

        self.encoder1 = nn.Sequential(...)
        self.encoder2 = nn.Sequential(...)

    def forward(self, data1, data2):
        result1 = self.encoder1(data1)
        result2 = self.encoder2(data2)

        result1 = result1.view(result1.size(0), -1)
        result2 = result2.view(result2.size(0), -1)

        result1 = F.normalize(result1, p=2, dim=1)
        result2 = F.normalize(result2, p=2, dim=1)


        return result1, result2

    
    def calculate_loss(self, batch):
        x, r, y = batch
        a, v = self.forward(r, x)

        d = nn.functional.cosine_similarity(a, v)
        loss = logloss(d.unsqueeze(1), y)

        return loss


class Module(RawModule):
    def training_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.calculate_loss(batch)
        self.log("validation_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer


if __name__ == '__main__':
    # stuff...

    train_loader = data_utils.DataLoader(
        train_dataset, batch_size=256, shuffle=True,
        num_workers=5, persistent_workers=True,
        pin_memory=True,
    )

    val_loader = data_utils.DataLoader(
        test_dataset, batch_size=256,
        num_workers=2, persistent_workers=True,
        pin_memory=True,
    )

    # Model
    load_from_pytorch = True

    if checkpoint_path is None:
        model = Module()

        if load_from_pytorch:
            if not checkpoint_path:
                raise ValueError("Please provide a checkpoint path")
            model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    else:
        model = Module.load_from_checkpoint(checkpoint_path)


    trainer = pl.Trainer(
        gpus=1,
        max_epochs=5,
        check_val_every_n_epoch=10,
        log_every_n_steps=5,
    )
    trainer.fit(model, train_loader, val_loader)

Here is the result of profiler="simple":

Action                                  |  Mean duration (s)    |Num calls              |  Total time (s)       |  Percentage %         |
----------------------------------------------------------------------------------------------------------------------------------------
Total                                   |  -                    |_                      |  48.813               |  100 %                |
----------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                      |  27.922               |1                      |  27.922               |  57.202               |
fetch_next_sanity_check_batch           |  4.4013               |3                      |  13.204               |  27.05                |
get_sanity_check_batch                  |  4.4013               |3                      |  13.204               |  27.05                |
fetch_next_train_batch                  |  1.2734               |10                     |  12.734               |  26.087               |
get_train_batch                         |  1.2734               |10                     |  12.734               |  26.087               |
run_training_batch                      |  0.47733              |9                      |  4.296                |  8.8009               |
optimizer_step_with_closure_0           |  0.40089              |9                      |  3.608                |  7.3915               |
validation_step                         |  0.664                |2                      |  1.328                |  2.7206               |
evaluation_step_and_end                 |  0.664                |2                      |  1.328                |  2.7206               |
training_step_and_backward              |  0.12644              |9                      |  1.138                |  2.3313               |
backward                                |  0.096889             |9                      |  0.872                |  1.7864               |
training_step                           |  0.029556             |9                      |  0.266                |  0.54494              |
model_forward                           |  0.029556             |9                      |  0.266                |  0.54494              |
on_train_start                          |  0.016                |1                      |  0.016                |  0.032778             |

Here is the result of profiler="advanced": https://pastebin.com/q3C5P826.
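
(For reference, both profiles were generated simply by passing the profiler argument to the Trainer shown above, e.g.:)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    check_val_every_n_epoch=10,
    log_every_n_steps=5,
    profiler="simple",  # or profiler="advanced"
)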

Finally, here is a video demonstrating the problem. I’m printing each piece of data loading, to prove it’s not the issue. https://user-images.githubusercontent.com/30944236/140587623-ae184fa3-370a-42be-8593-200026d11ba4.mp4

Some random information:

  • OS: Windows 10
  • CPU: AMD Ryzen 5 5600X 6 Core
  • GPU: Nvidia RTX 3070
  • Pytorch version: 1.10.0
  • Pytorch Lightning version: 1.5.0
  • Cuda version: 11.5
  • How did I install Pytorch: conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
  • Python 3.8

cc @tchaton @rohitgr7 @borda @akihironitta

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 49 (24 by maintainers)

Most upvoted comments

TL;DR: I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. My code runs ~20x faster. It probably isn’t the correct way to fix things, but my trainings work as well as they did before.

After a bunch of fiddling around, I decided to create a custom DataLoader and override the __iter__ method. I discovered that the problem was the _iterator attribute of the DataLoader being set back to None somewhere between epochs. When _iterator is None, the DataLoader is reset and has to start everything from scratch.

# Original DataLoader
class DataLoader(Generic[T_co]):
    def __iter__(self):
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()

# Custom DataLoader
class CustomDataLoader(DataLoader):
    def __iter__(self) -> '_BaseDataLoaderIter':
        print(f'\n> DataLoader __iter__ with {self._iterator=} starting.\n')
        return super().__iter__()

As you can see in the normal DataLoader, having self._iterator set to None causes a call to self._get_iterator(), which reloads everything.
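
(To plug the subclass in, you would just build the loaders with it instead of the stock DataLoader, e.g. with the same arguments as the training loader above:)

train_loader = CustomDataLoader(
    train_dataset, batch_size=256, shuffle=True,
    num_workers=5, persistent_workers=True,
    pin_memory=True,
)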

I decided to override _iterator with a custom property (getter & setter), to print the stack trace when ._iterator is set to None:

class CustomDataLoader(DataLoader):
    ...
    @property
    def _iterator(self):
        return self.__iterator

    @_iterator.setter
    def _iterator(self, value):
        if value is None:
            print('\nSetting __iterator to None. Stack trace:')
            import traceback
            traceback.print_stack()
        self.__iterator = value

(I could also use the debugger for this)

This leads to 2 different yet very similar stack traces (respectively, evaluation & training loaders):

  File "env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1370, in _run_sanity_check
    self._evaluation_loop.run()
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 144, in run
    self.advance(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\dataloader\evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\epoch\evaluation_epoch_loop.py", line 87, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_progress.current.ready)
  File "env\lib\site-packages\pytorch_lightning\loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 214, in reset
    CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None

  File "env\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 144, in run
    self.advance(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "env\lib\site-packages\pytorch_lightning\loops\base.py", line 139, in run
    self.on_run_start(*args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 142, in on_run_start
    self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)
  File "env\lib\site-packages\pytorch_lightning\loops\utilities.py", line 121, in _update_dataloader_iter
    dataloader_iter = enumerate(data_fetcher, batch_idx)
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 198, in __iter__
    self.reset()
  File "env\lib\site-packages\pytorch_lightning\utilities\fetching.py", line 212, in reset
    self.dataloader.reset()
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 504, in reset
    apply_to_collection(self.loaders, DataLoader, self._shutdown_workers_and_reset_iterator)
  File "env\lib\site-packages\pytorch_lightning\utilities\apply_func.py", line 92, in apply_to_collection
    return function(data, *args, **kwargs)
  File "env\lib\site-packages\pytorch_lightning\trainer\supporters.py", line 498, in _shutdown_workers_and_reset_iterator
    dataloader._iterator = None

Well… We’re nearly there. It looks like advancing 1 epoch calls self.reset() on the DataFetcher itself, which then resets the DataLoader and leads to our problem.

Indeed, when checking AbstractDataFetcher, we have this:

class AbstractDataFetcher(...):
    def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
        if self.dataloader is None:
            raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
-->     self.reset()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        self.prefetching(self.prefetch_batches)
        return self

# And iter(AbstractDataFetcher) is called here, in utilities.py:
def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
-->     dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter

So… I guess we found the problem? Each time a new epoch runs, advance is called on the FitLoop/EvaluationLoop, which runs the corresponding epoch loop; its on_run_start creates the _dataloader_iter (which is normal behavior) by calling _update_dataloader_iter. That function, through enumerate, calls AbstractDataFetcher.__iter__, which calls self.reset(), entirely reloading the DataLoader.

Notice that self.reset() is also called in the __init__ method of AbstractDataFetcher, for setup purposes.

I just commented out the self.reset() line of AbstractDataFetcher, located at line 198 of pytorch_lightning/utilities/fetching.py. While it speeds up the code a lot and the entire training seems to work, it probably breaks a bunch of things? I’d wait until the Lightning team fixes the problem before trying anything serious.
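
If you’d rather not edit the installed package, an equivalent runtime monkey-patch can be written directly from the __iter__ body quoted above (a sketch for PL 1.5.x only, with the same caveats: I’m assuming nothing else depends on the per-epoch reset):

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher

def _iter_without_reset(self):
    # Same body as AbstractDataFetcher.__iter__ in 1.5.0, minus the reset() call.
    if self.dataloader is None:
        raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
    # self.reset()  # <-- the line that tears down the persistent workers every epoch
    self.dataloader_iter = iter(self.dataloader)
    self._apply_patch()
    self.prefetching(self.prefetch_batches)
    return self

AbstractDataFetcher.__iter__ = _iter_without_reset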

What a ride.

Hi @isvogor-foi, I had the same issue reported here, and after updating pytorch-lightning I didn’t see any improvement either. After reading the OP’s blog post (https://medium.com/@florian-ernst/finding-why-pytorch-lightning-made-my-training-4x-slower-ae64a4720bd1), I noticed I was missing the persistent_workers=True flag on my DataLoader:

# My data Loader parameters
DataLoader(
  train_dataset, batch_size=64, shuffle=True, num_workers=n_workers,
  persistent_workers=True, pin_memory=True,
)

Hopefully this will help you! Performance was much improved for me.

Is this really fixed as of 2.0.4? I’m still seeing this behavior. This code runs lightning fast (albeit with warnings about not having any workers):

    train_data_loader = DataLoader(
        NameDataset("data/training.csv", char_to_int),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        # num_workers=4,
    )
    eval_data_loader = DataLoader(
        NameDataset("data/eval.csv", char_to_int, debug=False),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=False,
        # num_workers=4,
    )

    lightning_model = PlContactEncoder(
        model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
    )
    lightning_model.to(device)

    trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_data_loader,
        val_dataloaders=eval_data_loader,
    )

Adding the workers makes the warnings go away, but training freezes for ~15 seconds before validation or a new epoch:

    train_data_loader = DataLoader(
        NameDataset("data/training.csv", char_to_int),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=4,
    )
    eval_data_loader = DataLoader(
        NameDataset("data/eval.csv", char_to_int, debug=False),
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=False,
        num_workers=4,
    )

    lightning_model = PlContactEncoder(
        model, criterion, SIMILARITY_METRIC(0.5, return_distance=True), LEARNING_RATE
    )
    lightning_model.to(device)

    trainer = pl.Trainer(max_epochs=N_EPOCHS, max_steps=50)
    trainer.fit(
        model=lightning_model,
        train_dataloaders=train_data_loader,
        val_dataloaders=eval_data_loader,
    )

This still seems to be the case, especially with dataloader startup time

Hi,

This issue still persists in 2.1.3. I’m passing the dataloaders directly to the Trainer:

train_dataloader = DataLoader(
    train_dataset, batch_size=64, collate_fn=collate_fn_train,
    drop_last=True, shuffle=True, pin_memory=True,
    num_workers=8, prefetch_factor=8, persistent_workers=True,
)

but training freezes for a few seconds after each epoch. When I ran the advanced profiler on run_training_epoch, I saw that reset is called every epoch (dataloader.py:1086(_reset) and training_epoch_loop.py:143(reset)). Training is extremely slow if I set num_workers=0.

Is there a fix/workaround for this?
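
In the meantime, one way to confirm that the workers really are being torn down and respawned is a worker_init_fn that logs the worker PIDs (a sketch; train_dataset and collate_fn_train stand in for my actual objects, and worker_init_fn is a standard DataLoader argument):

import os
from torch.utils.data import DataLoader

def log_worker_start(worker_id):
    # Runs once each time a worker process is created, so fresh PIDs showing up
    # at an epoch boundary mean the persistent workers were shut down and restarted.
    print(f"worker {worker_id} started in pid {os.getpid()}", flush=True)

train_dataloader = DataLoader(
    train_dataset, batch_size=64, collate_fn=collate_fn_train,
    drop_last=True, shuffle=True, pin_memory=True,
    num_workers=8, prefetch_factor=8, persistent_workers=True,
    worker_init_fn=log_worker_start,
)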

Tested the new 1.5.1 release today, looks like performance is back on track. Thanks to everyone!

After a lot of digging around, I managed to pin down the line causing the problem.

It’s line 142 in loops/epoch/training_epoch_loop.py:

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    ...
    def on_run_start(self, data_fetcher: AbstractDataFetcher, **kwargs: Any) -> None:
        # hook
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start")
        self.trainer.call_hook("on_train_epoch_start")
        self.trainer.fit_loop.epoch_progress.increment_started()

        self._reload_dataloader_state_dict(data_fetcher)
-->     self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)

Therefore, the culprit is:

def _update_dataloader_iter(data_fetcher: AbstractDataFetcher, batch_idx: int) -> Iterator:
    """Attach the dataloader."""
    if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
        # restore iteration
        dataloader_iter = enumerate(data_fetcher, batch_idx)
    else:
        dataloader_iter = iter(data_fetcher)
    return dataloader_iter

The on_run_start hook is called from loops/base.py:

class Loop(ABC, Generic[T]):
    ...
    def run(self, *args: Any, **kwargs: Any) -> T:
        if self.skip:
            return self.on_skip()

        self.reset()

-->     self.on_run_start(*args, **kwargs)
        ...

And this run method is called from loops/fit_loop.py:

class FitLoop(Loop):
    ...
    def advance(self) -> None:
        """Runs one whole epoch."""
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

        with self.trainer.profiler.profile("run_training_epoch"):
-->         self.epoch_loop.run(data_fetcher)
        ...

The problem is the data_fetcher. It’s indeed related to dataloaders, as expected. I still don’t know the root cause, but I’ll try to find it.

I had a similar observation where data_fetcher caused unusually long run times. For me it was indeed fixed by completely disabling multiprocess data loading (num_workers=0), although I have not tried setting reload_dataloaders_every_epoch=False. Interesting to see that others have the same issue/observation. Funnily enough, setting num_workers=0 is what led me to open #10182. Perhaps there is something more to this?

Hi everyone… This topic is very interesting, as I’m hitting the same issue. I’m comparing the same implementation in Torch and Lightning. When I came across this post I noticed I was using the ancient Lightning 1.4.7, so I updated to 1.5.9. I repeated my test and nothing changed… Torch was still significantly faster than Lightning. So, as @TheMrZZ suggested, I commented out reset in the __iter__ function and repeated the test. Sure enough, Lightning was… lightning fast now! I’m loading images, so I got the following:

  • lightning (vanilla 1.5.9): 11 images/s, total runtime 407s
  • lightning (1.5.9 with the reset commented out): 108 images/s, total runtime 56s
  • torch (unmodified): 40 images/s, total runtime 373s

Apparently the performance issue was fixed in 1.5.1; however, in 1.5.9 the reset line is still there. So I’m curious: why do we need to reset the data fetcher after each epoch?

Dear @TheMrZZ,

Thanks for your investigation and happy we solved this ugly bug.

Best, T.C

Is there any update on this? It seems like I’m not the only one facing this issue. I’ve been using PTL since 2.x and would prefer not to downgrade to 1.5.1 to make this issue go away. Is there a 2.x version that doesn’t have this issue? @awaelchli @Borda @carmocca

@rohitgr7 I think I’ve found the issue. Previously, I was passing the datasets into the LightningDataModule, running save_hyperparameters, and then accessing them via self.hparams["train_dataset"], self.hparams["validation_dataset"], etc. It seems the entire dataset was being cached as part of the hyperparameters in this case.

The fix is to explicitly make these datasets plain instance variables and then call self.save_hyperparameters(ignore=["train_dataset", "validation_dataset", "test_dataset"]).
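
Roughly, the pattern that works for me looks like this (a trimmed sketch with hypothetical names, not my real module):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class MyDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, validation_dataset, test_dataset, batch_size=64, num_workers=4):
        super().__init__()
        # Keep the datasets as plain attributes so save_hyperparameters() never captures them.
        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.test_dataset = test_dataset
        self.save_hyperparameters(ignore=["train_dataset", "validation_dataset", "test_dataset"])

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, batch_size=self.hparams.batch_size, shuffle=True,
            num_workers=self.hparams.num_workers, persistent_workers=True, pin_memory=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset, batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers, persistent_workers=True, pin_memory=True,
        )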

I wouldn’t know how to make this reproducible either, but I have pinned PL to 1.6.5 in my project because I observe the same behavior as @lminer on 1.7.*: very slow epochs. I also see very large memory usage, although I am using a single GPU and am not calling save_hyperparameters at all. Could there be another underlying factor?

Hm… very interesting. I did the same with the MNIST example that ships with the Lightning examples and noticed that Torch was only slightly faster in my case, by a second or so. However, I’ve just dismantled the advance call in training_epoch_loop.py. Without going into the details, I’ve added some arrows to indicate the execution timeline.

[screenshot: annotated execution timeline of training_epoch_loop.advance]

So the training loop is this:

with self.trainer.profiler.profile("run_training_batch"):
    batch_output = self.batch_loop.run(batch, batch_idx)

and before it, this part takes a lot of time:

  response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
  if response == -1:
      self.batch_progress.increment_processed()
      raise StopIteration

and after it there is another long call:

  self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
  self.trainer.call_hook("on_batch_end")
  self.trainer.logger_connector.on_batch_end()

Therefore excessive logging can slow it down.
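
If logging really is the overhead, one easy mitigation (just a generic sketch using the standard self.log() arguments; I haven’t benchmarked it on this exact setup) is to aggregate the metric per epoch instead of emitting it every step:

def training_step(self, batch, batch_idx):
    loss = self.calculate_loss(batch)
    # Accumulate and log once per epoch instead of on every training step.
    self.log("train_loss", loss, on_step=False, on_epoch=True)
    return loss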

So, in my particular case (Lightning 1.5.9 on a single V100), the aforementioned hook calls seem to build up over time, and in my long experiment (100 epochs, batch size 256, 35k images) Torch performs better. I used just 1 GPU and didn’t try multiple GPUs. I think we can leave it at that.

@carmocca thanks!

Here are some simple time results:

This was using 1 NVIDIA GeForce RTX 3090 and PyTorch Lightning 1.6.0dev, commit 8394770d4afa5480f881229b150ac44eaa8c41b0, torch==1.10.1, torchmetrics==0.7.0, torchvision==0.11.2, Python 3.8.12.

PyTorch Lightning

real    1m28.276s
user    5m29.551s
sys     0m34.544s

I used https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py verbatim, called with

time python imagenet.py fit --model.data_path /home/imagenet/data --trainer.limit_train_batches=100 --trainer.limit_val_batches=0 --trainer.max_epochs=2

PyTorch

real    1m29.662s
user    5m15.496s
sys     0m31.530s

I used https://github.com/pytorch/examples/blob/master/imagenet/main.py with the following changes applied to disable validation and stop at 100 batches:

249c249
<         acc1 = validate(val_loader, model, criterion, args)
---
>         #acc1 = validate(val_loader, model, criterion, args)
252,253c252,253
<         is_best = acc1 > best_acc1
<         best_acc1 = max(acc1, best_acc1)
---
>         is_best = False
>         best_acc1 = best_acc1
281a282,283
>         if i == 100:
>             break
time python torch_imagenet.py /home/imagenet/data --epochs=2 --gpu=0

So basically the same speed.

@carmocca Sure, vanilla meaning the ones taken from the official implementations:

Lightning: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/imagenet.py Torch: https://github.com/pytorch/examples/blob/master/imagenet/main.py

Try running those on the same device, with the same image dataset, batch size, etc. I’ve tried on my machine, on an AWS instance, even on Colab, and Torch is always much better.

@carmocca Hi, well, I didn’t try 1.5.1; I only tried 1.4.7 and 1.5.9. I’ll see whether I can try 1.5.1 and get back to you on this!