pytorch-lightning: Inconsistent behaviour and "AttributeError: _old_init" when using PyTorch Lightning with the ogb library.

🐛 Bug

The simple example from the documentation page https://pytorch-lightning.readthedocs.io/en/latest/notebooks/lightning_examples/mnist-hello-world.html works fine. However, when I add the import statement from ogb.nodeproppred import *, the behaviour becomes inconsistent: sometimes the code runs and sometimes it raises an "AttributeError: _old_init" exception:

 Traceback (most recent call last):
  File "plexample.py", line 54, in <module>
    trainer.fit(mnist_model, train_loader)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 700, in fit
    self._call_and_handle_interrupt(
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 654, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 741, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1282, in _run_train
    self.fit_loop.run()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 195, in run
    self.on_run_start(*args, **kwargs)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 210, in on_run_start
    self.trainer.reset_train_dataloader(self.trainer.lightning_module)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1811, in reset_train_dataloader
    self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 430, in _request_dataloader
    dataloader = source.dataloader()
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/contextlib.py", line 126, in __exit__
    next(self.gen)
  File "/home/schlyah/anaconda3/envs/pipinstallenv/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py", line 527, in _replace_init_method
    del cls._old_init
AttributeError: _old_init
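
The last frames show the error surfacing from contextlib's __exit__, that is, from the code that follows the yield in a generator-based context manager. That cleanup code runs when the with block ends, which is why the traceback points at the dataloader request rather than at user code. A minimal sketch of this behaviour, where patched and run are hypothetical names used only for illustration:

```python
from contextlib import contextmanager


@contextmanager
def patched(events):
    """Hypothetical stand-in for a patching context manager."""
    events.append("setup")      # runs on entering the with block
    yield
    events.append("teardown")   # runs on leaving the with block
    # Simulate a cleanup step that fails, as in the traceback above.
    raise AttributeError("_old_init")


def run(events):
    """Enter the context, run a body, and record where the error shows up."""
    try:
        with patched(events):
            events.append("body")
    except AttributeError as exc:
        events.append(f"error: {exc}")
    return events


if __name__ == "__main__":
    print(run([]))
```

The body completes normally; the exception is raised only afterwards, during teardown.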

To Reproduce

The documentation example code with the import instruction:

import os
import pandas as pd
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
import argparse
from ogb.nodeproppred import *

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


if __name__ == '__main__':
    # Init our model
    mnist_model = MNISTModel()

    # Init DataLoader from MNIST Dataset
    train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

    # Initialize a trainer
    trainer = Trainer(
        accelerator="auto",
        devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
        max_epochs=3,
        callbacks=[TQDMProgressBar(refresh_rate=20)],
    )

    # Train the model ⚡
    trainer.fit(mnist_model, train_loader)

Environment

  • PyTorch Lightning Version: 1.7.0
  • PyTorch Version: 1.12.0
  • Python version: 3.9
  • OS: Linux
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration:
  • How you installed PyTorch: pip

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

About this issue

  • State: closed
  • Created 2 years ago
  • Reactions: 3
  • Comments: 15 (9 by maintainers)

Most upvoted comments

Now, the problem is fixed. Thank you

Hi, everyone. I used @PurpleSand123's example with one extra test script:

import subprocess
import sys

import numpy as np


def main():
    """Run main.py up to 100 times; return the index of the first failing run, or 100."""
    for i in range(1, 101):
        try:
            subprocess.run([sys.executable, "main.py"], check=True, capture_output=True)
        except subprocess.CalledProcessError:
            print(f"Broke on {i}th try")
            return i
    # Only reached when all 100 runs succeeded.
    print("Didn't break")
    return 100


if __name__ == "__main__":
    # Repeat the experiment 10 times and report the mean number of runs.
    results = np.array([main() for _ in range(10)])
    print(results.mean())

When running this with PL 1.7.0, the output on my machine was:

Broke on 5th try
Broke on 4th try
Broke on 6th try
Broke on 3th try
Broke on 18th try
Broke on 6th try
Broke on 1th try
Broke on 2th try
Broke on 2th try
Broke on 2th try
4.9

When running with the change in the linked PR, the output on my machine was:

Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
Didn't break
100

Please try your examples with the changes in the linked PR and check whether they fix the issue for you.
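
The test script above counts any non-zero exit as a failure. When re-running it against your own example, it may help to confirm that a failing run actually hit this error and not something unrelated. A hedged sketch of such a check follows; first_old_init_failure is a hypothetical helper name, and the match on stderr is deliberately loose because the exact AttributeError message can differ between Python versions.

```python
import subprocess
import sys


def first_old_init_failure(cmd, max_runs=100):
    """Run cmd up to max_runs times.

    Returns the 1-based index of the first run whose stderr mentions an
    AttributeError involving _old_init, or None if every run succeeded.
    """
    for attempt in range(1, max_runs + 1):
        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode == 0:
            continue
        if "AttributeError" in proc.stderr and "_old_init" in proc.stderr:
            return attempt
        # Any other failure is unrelated; surface it instead of counting it.
        raise RuntimeError(f"unexpected failure on run {attempt}:\n{proc.stderr}")
    return None


# Example call, assuming the reproduction script is saved as main.py:
#     first_old_init_failure([sys.executable, "main.py"])
```

Raising on unrelated failures keeps a missing dataset or an import error from being miscounted as a reproduction of this bug.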

I am having the same problem when using torch_geometric.loader.DataLoader. For debugging, I changed the "_replace_init_method" code like this:

@contextmanager
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
    """This context manager is used to add support for re-instantiation of custom (subclasses) of "base_cls".
    It patches the "__init__" method.
    """
    classes = _get_all_subclasses(base_cls) | {base_cls}
    wrapped = set()
    print('before:', classes)
    for cls in classes:
        if cls.__init__ not in wrapped:
            print(cls.__name__)
            cls._old_init = cls.__init__
            cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
            wrapped.add(cls.__init__)
    yield
    print('after:', classes)
    for cls in classes:
        if hasattr(cls, "_old_init"):
            print("del", cls.__name__)
            cls.__init__ = cls._old_init
            del cls._old_init

Sometimes it works, and the output is:

before: {<class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>}
RandomNodeSampler
NeighborSampler
GraphSAINTNodeSampler
NeighborLoader
DenseDataLoader
GraphSAINTEdgeSampler
HGTLoader
DataLoader
ClusterLoader
BaseDataLoader
TemporalDataLoader
DataLoader
DataListLoader
GraphSAINTRandomWalkSampler
ShaDowKHopSampler
GraphSAINTSampler
before: {<class 'torch.utils.data.sampler.BatchSampler'>}
BatchSampler
after: {<class 'torch.utils.data.sampler.BatchSampler'>}
del BatchSampler
after: {<class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>}
del RandomNodeSampler
del NeighborSampler
del GraphSAINTNodeSampler
del NeighborLoader
del DenseDataLoader
del GraphSAINTEdgeSampler
del HGTLoader
del DataLoader
del ClusterLoader
del BaseDataLoader
del TemporalDataLoader
del DataLoader
del DataListLoader
del GraphSAINTRandomWalkSampler
del ShaDowKHopSampler
del GraphSAINTSampler

At other times the error occurs, and the output is:

before: {<class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>}
GraphSAINTNodeSampler
GraphSAINTSampler
TemporalDataLoader
ShaDowKHopSampler
NeighborSampler
NeighborLoader
ClusterLoader
RandomNodeSampler
DataListLoader
HGTLoader
DataLoader
BaseDataLoader
DataLoader
GraphSAINTRandomWalkSampler
DenseDataLoader
before: {<class 'torch.utils.data.sampler.BatchSampler'>}
BatchSampler
after: {<class 'torch.utils.data.sampler.BatchSampler'>}
del BatchSampler
after: {<class 'torch_geometric.loader.graph_saint.GraphSAINTNodeSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTSampler'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTEdgeSampler'>, <class 'torch_geometric.loader.temporal_dataloader.TemporalDataLoader'>, <class 'torch_geometric.loader.shadow.ShaDowKHopSampler'>, <class 'torch_geometric.loader.neighbor_sampler.NeighborSampler'>, <class 'torch_geometric.loader.neighbor_loader.NeighborLoader'>, <class 'torch_geometric.loader.cluster.ClusterLoader'>, <class 'torch_geometric.loader.random_node_sampler.RandomNodeSampler'>, <class 'torch_geometric.loader.data_list_loader.DataListLoader'>, <class 'torch_geometric.loader.hgt_loader.HGTLoader'>, <class 'torch_geometric.loader.dataloader.DataLoader'>, <class 'torch_geometric.loader.base.BaseDataLoader'>, <class 'torch.utils.data.dataloader.DataLoader'>, <class 'torch_geometric.loader.graph_saint.GraphSAINTRandomWalkSampler'>, <class 'torch_geometric.loader.dense_data_loader.DenseDataLoader'>}
del GraphSAINTNodeSampler
del GraphSAINTSampler
del GraphSAINTEdgeSampler
Traceback (most recent call last):
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 868, in test
    return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 654, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 915, in _test_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1249, in _run_stage
    return self._run_evaluate()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1288, in _run_evaluate
    self._evaluation_loop._reload_evaluation_dataloaders()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 234, in _reload_evaluation_dataloaders
    self.trainer.reset_test_dataloader()
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1941, in reset_test_dataloader
    self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 344, in _reset_eval_dataloader
    dataloaders = self._request_dataloader(mode)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 427, in _request_dataloader
    with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
  File "/anaconda/envs/pytorch/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/anaconda/envs/pytorch/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py", line 531, in _replace_init_method
    del cls._old_init
AttributeError: _old_init
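
Comparing the two logs hints at the mechanism. In the failing run, GraphSAINTEdgeSampler never appears in the wrapping phase, presumably because it inherits __init__ from GraphSAINTSampler, which had already been wrapped, so the "not in wrapped" check skips it. During restore, hasattr still finds an _old_init through a base class that has not been restored yet (torch.utils.data.DataLoader comes later in the set), and del then fails because the attribute is not in the class's own namespace. Since the iteration order of a set of classes can differ between processes, that would explain the intermittent behaviour. Below is a minimal sketch of this reading using plain Python classes. Base, Parent, Child, wrap, and the own_only guard are hypothetical names for illustration, and the guard is only one possible way to avoid the lookup through base classes, not necessarily what the linked PR does.

```python
def make_classes():
    """Build a fresh three-level hierarchy; Child inherits Parent.__init__."""

    class Base:
        def __init__(self):
            self.ready = True

    class Parent(Base):
        def __init__(self):
            super().__init__()

    class Child(Parent):
        pass

    return {"Base": Base, "Parent": Parent, "Child": Child}


def wrap(fn):
    """Stand-in for the real init wrapper: forwards the call unchanged."""
    def wrapper(self, *args, **kwargs):
        return fn(self, *args, **kwargs)
    return wrapper


def patch_and_restore(classes, own_only=False):
    """Patch __init__ on each class, then restore, in the given order.

    own_only=False mirrors the hasattr-based logic shown in the comment above.
    own_only=True is a hypothetical guard that only inspects each class's own
    namespace, so inherited attributes are never patched or deleted.
    """
    wrapped = set()
    for cls in classes:
        if own_only:
            should_patch = "__init__" in vars(cls)
        else:
            should_patch = cls.__init__ not in wrapped
        if should_patch:
            cls._old_init = cls.__init__
            cls.__init__ = wrap(cls.__init__)
            wrapped.add(cls.__init__)
    for cls in classes:
        if own_only:
            has_old = "_old_init" in vars(cls)
        else:
            has_old = hasattr(cls, "_old_init")
        if has_old:
            cls.__init__ = cls._old_init
            del cls._old_init


def try_order(names, own_only=False):
    """Return True if patch/restore succeeds for this order, False on AttributeError."""
    by_name = make_classes()
    try:
        patch_and_restore([by_name[n] for n in names], own_only=own_only)
    except AttributeError:
        return False
    # Instances must still be constructible after restoring.
    return all(cls().ready for cls in by_name.values())


if __name__ == "__main__":
    print(try_order(["Child", "Parent", "Base"]))                 # child first
    print(try_order(["Parent", "Child", "Base"]))                 # parent first
    print(try_order(["Parent", "Child", "Base"], own_only=True))  # guarded
```

In this sketch, the child-first order succeeds, the parent-first order raises AttributeError at the del, and the guarded variant succeeds for the same parent-first order.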

Hi @schlyah, I spent a while on this and am unable to reproduce. Would you mind sharing your full environment (i.e. pip freeze) as well? Thanks a lot!