pytorch-lightning: load_from_checkpoint: TypeError: __init__() missing 1 required positional argument

❓ Questions and Help

What is your question?

load_from_checkpoint: TypeError: __init__() missing 1 required positional argument

I have read the related issues before, but the difference here is that my LightningModule inherits from a self-defined LightningModule.

How can I solve this problem, and what is the best practice better suited to my needs?

Code

To reproduce the error:

import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer

from argparse import Namespace

class _LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, hparams.classes)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

class LitModel(_LitModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--classes', type=int, default=10)
parser.add_argument('--checkpoint', type=str, default=None)
hparams = parser.parse_args()

mnist_train = MNIST(os.getcwd(), train=True, download=True,
                    transform=transforms.ToTensor())
mnist_train = DataLoader(mnist_train, num_workers=1)
mnist_val = MNIST(os.getcwd(), train=False, download=False,
                  transform=transforms.ToTensor())
mnist_val = DataLoader(mnist_val, num_workers=1)

# A bit contrived here, but I just want to show that `load_from_checkpoint` will fail.
if hparams.checkpoint is None:
    model = LitModel(hparams)
else:
    model = LitModel.load_from_checkpoint(hparams.checkpoint)

trainer = Trainer(max_epochs=2, limit_train_batches=2,
                  limit_val_batches=2, progress_bar_refresh_rate=0)
trainer.fit(model, mnist_train, mnist_val)

Error msg

Traceback (most recent call last):
  File "main.py", line 64, in <module>
    model = LitModel.load_from_checkpoint(hparams.checkpoint)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 138, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/home/siahuat0727/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 174, in _load_model_state
    model = cls(*cls_args, **cls_kwargs)
  File "main.py", line 46, in __init__
    super().__init__(*args, **kwargs)
TypeError: __init__() missing 1 required positional argument: 'hparams'

How to run to get the error

$ python3 main.py 
$ python3 main.py --checkpoint lightning_logs/version_0/checkpoints/epoch\=1.ckpt

What’s your environment?

  • OS: Linux
  • Packaging: pip
  • Version: 0.9.0rc12

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Comments: 16 (7 by maintainers)

Most upvoted comments

@pietz in this case you can instantiate your model normally, model = YourModel(...) and then load the state dict from the checkpoint:

checkpoint = torch.load(path)
model.load_state_dict(checkpoint["state_dict"])
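
Applied to the example in this issue, a minimal sketch of that workaround might look like the following (hparams here is the parsed argument namespace from the script above; map_location is optional and just avoids device mismatches):

# Rebuild the model from known constructor arguments, bypassing load_from_checkpoint,
# then restore only the weights from the Lightning checkpoint dict.
model = LitModel(hparams)
checkpoint = torch.load(hparams.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])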

But we sometimes have to load other parameters as well, such as the optimizer state, so we end up calling several load_* functions. That is not ideal!

Did you try calling self.save_hyperparameters() in _LitModel? It looks like the hparams were not saved to the checkpoint.
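
For reference, a minimal sketch of what that change could look like in the original _LitModel (only __init__ changes; whether this alone also fixes the subclass case depends on the Lightning version):

class _LitModel(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        # save_hyperparameters() records the init arguments in the checkpoint
        # and exposes them as self.hparams, so load_from_checkpoint can
        # reconstruct the model without being given hparams again.
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, hparams.classes)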

What can I do if I already trained my models without calling self.save_hyperparameters() explicitly?
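
One option in that case (assuming you still know the original constructor arguments) is to supply them yourself, since extra arguments passed to load_from_checkpoint are forwarded to the model's __init__; the manual state_dict route shown above also works. A sketch for the example in this issue:

# The checkpoint has no saved hparams, so pass them explicitly;
# extra keyword arguments are forwarded to LitModel.__init__.
model = LitModel.load_from_checkpoint(hparams.checkpoint, hparams=hparams)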

The same issue appears with version 1.0.5 (0.9.0 is fine). Can you help with it?

(also track_grad_norm doesn’t work in 0.9.0, so I have to switch to 1.0.5…)

Bump

@stathius yes, the “old” hparams method is not yet deprecated, but it simply has conceptual flaws in terms of typing that cannot be fixed with a “bugfix”. The solution we came up with here is to decouple two things:

  1. Saving hyperparameters into the checkpoint
  2. Making hyperparameters accessible through a convenient self.hparams “namespace”.

And the code you posted does exactly that; this is the recommended way today.

What solved it for me was passing the hparams as kwargs instead of as a positional argument. So in your class use:

class my_pl_module(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()

Still a bug though, because the hparams method is not yet deprecated.
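
With the hyperparameters stored that way, loading should no longer need the constructor arguments to be repeated; roughly (assuming a checkpoint produced by this class):

# Hypothetical usage: the hyperparameters were written into the checkpoint
# by save_hyperparameters(), so no arguments are needed here.
model = my_pl_module.load_from_checkpoint("path/to/checkpoint.ckpt")
print(model.hparams)  # restored hyperparameters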

@awaelchli Ah, thank you. Looking back I should have been able to figure this one out myself 😃

@awaelchli Hihi, the result is the same. It works if I directly use _LitModel instead of LitModel, so I think it's something about the inheritance.
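
For what it's worth, one possible workaround along those lines (not from this thread, just an illustration): give the subclass the same explicit signature as the parent, so the stored hparams can be matched to a named argument when the checkpoint is loaded:

# Hypothetical workaround: mirror the parent's explicit signature instead of
# hiding `hparams` behind *args/**kwargs, which checkpoint loading cannot inspect.
class LitModel(_LitModel):
    def __init__(self, hparams):
        super().__init__(hparams)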