pytorch-lightning: TypeError: can't pickle _thread.lock objects - Error while logging model into mlflow in multi gpu scenario

❓ Questions and Help

What is your question?

Trying to log the model into MLflow using mlflow.pytorch.log_model at the end of training. Getting the above error only in the multi-GPU (DDP) scenario.

Code

MNIST script file:

import pytorch_lightning as pl
import torch
from argparse import ArgumentParser
#from mlflow.pytorch.pytorch_autolog import __MLflowPLCallback
from pytorch_lightning.logging import MLFlowLogger
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        """
        Initializes the network
        """
        super(LightningMNISTClassifier, self).__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

        # transforms for images
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(
            "--batch-size",
            type=int,
            default=64,
            metavar="N",
            help="input batch size for training (default: 64)",
        )
        parser.add_argument(
            "--num-workers",
            type=int,
            default=0,
            metavar="N",
            help="number of workers (default: 0)",
        )
        parser.add_argument(
            "--lr",
            type=float,
            default=1e-3,
            metavar="LR",
            help="learning rate (default: 1e-3)",
        )
        return parser

    def forward(self, x):
        """
        Forward Function
        """
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        """
        Computes cross-entropy loss (negative log-likelihood on log-probabilities)
        """
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        """
        Trains the model on a batch and returns the training loss
        """
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {"loss": loss}

    def validation_step(self, val_batch, batch_idx):
        """
        Validates the model on a batch and returns the validation loss
        """
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        return {"val_loss": loss}

    def validation_epoch_end(self, outputs):
        """
        Computes the average validation loss
        """
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        return {"avg_val_loss": avg_loss, "log": tensorboard_logs}

    def test_step(self, test_batch, batch_idx):
        """
        Performs test and computes test accuracy
        """
        x, y = test_batch
        output = self.forward(x)
        a, y_hat = torch.max(output, dim=1)
        test_acc = accuracy_score(y_hat.cpu(), y.cpu())
        return {"test_acc": torch.tensor(test_acc)}

    def test_epoch_end(self, outputs):
        """
        Computes average test accuracy score
        """
        avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
        return {"avg_test_acc": avg_test_acc}

    def prepare_data(self):
        """
        Prepares the data. No-op here; the datasets are created in the dataloaders.
        """
        return {}

    def train_dataloader(self):
        """
        Loading training data as batches
        """
        mnist_train = datasets.MNIST(
            "dataset", download=True, train=True, transform=self.transform
        )
        return DataLoader(
            mnist_train,
            batch_size=64,
            num_workers=1
        )

    def val_dataloader(self):
        """
        Loading validation data as batches
        """
        mnist_train = datasets.MNIST(
            "dataset", download=True, train=True, transform=self.transform
        )
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        return DataLoader(
            mnist_val,
            batch_size=64,
            num_workers=1
        )

    def test_dataloader(self):
        """
        Loading test data as batches
        """
        mnist_test = datasets.MNIST(
            "dataset", download=True, train=False, transform=self.transform
        )
        return DataLoader(
            mnist_test,
            batch_size=64,
            num_workers=1
        )

    def configure_optimizers(self):
        """
        Creates and returns the optimizer and LR scheduler
        """
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        self.scheduler = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer,
                mode="min",
                factor=0.2,
                patience=2,
                min_lr=1e-6,
                verbose=True,
            )
        }
        return [self.optimizer], [self.scheduler]

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        second_order_closure=None,
        on_tpu=False,
        using_lbfgs=False,
        using_native_amp=False,
    ):
        self.optimizer.step()
        self.optimizer.zero_grad()


if __name__ == "__main__":
    from pytorch_autolog import autolog
    autolog()
    model = LightningMNISTClassifier()
    mlflow_logger = MLFlowLogger(
        experiment_name="Default", tracking_uri="http://localhost:5000/"
    )
    trainer = pl.Trainer(
        logger=mlflow_logger,
        gpus=2,
        distributed_backend="ddp",
        max_epochs=1
    )
    trainer.fit(model)
    trainer.test()

Sample code from autolog (the callback class):

    class __MLflowPLCallback(pl.Callback):

        def __init__(self):
            super().__init__()

        def on_train_end(self, trainer, pl_module):
            """
            Logs the trained model into the MLflow run's models artifact folder at the end of training
            """

            mlflow.set_tracking_uri(trainer.logger._tracking_uri)
            mlflow.set_experiment(trainer.logger._experiment_name)
            mlflow.start_run(trainer.logger.run_id)
            mlflow.pytorch.log_model(trainer.model, "models")
            mlflow.end_run()


Stack Trace

Traceback (most recent call last):
  File "mnist.py", line 231, in <module>
    trainer.fit(model)
  File "/home/ubuntu/mnist/pytorch_autolog.py", line 218, in fit
    return _run_and_log_function(self, original, args, kwargs)
  File "/home/ubuntu/mnist/pytorch_autolog.py", line 209, in _run_and_log_function
    result = original(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 992, in fit
    results = self.spawn_ddp_children(model)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 462, in spawn_ddp_children
    results = self.ddp_train(local_rank, q=None, model=model, is_master=True)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 560, in ddp_train
    results = self.run_pretrain_routine(model)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1213, in run_pretrain_routine
    self.train()
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 392, in train
    self.run_training_teardown()
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 872, in run_training_teardown
    self.on_train_end()
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/callback_hook.py", line 72, in on_train_end
    callback.on_train_end(self, self.get_model())
  File "/home/ubuntu/mnist/pytorch_autolog.py", line 120, in on_train_end
    mlflow.pytorch.log_model(trainer.model, "models")
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/mlflow/pytorch/__init__.py", line 179, in log_model
    signature=signature, input_example=input_example, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/mlflow/models/model.py", line 154, in log
    **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/mlflow/pytorch/__init__.py", line 300, in save_model
    torch.save(pytorch_model, model_path, pickle_module=pickle_module, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 370, in save
    _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/serialization.py", line 443, in _legacy_save
    pickler.dump(obj)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/cloudpickle/cloudpickle.py", line 491, in dump
    return Pickler.dump(self, obj)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 437, in dump
    self.save(obj)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 549, in save
    self.save_reduce(obj=obj, *rv)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 662, in save_reduce
    save(state)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 504, in save
    f(self, obj) # Call unbound method with explicit self
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 859, in save_dict
    self._batch_setitems(obj.items())
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 885, in _batch_setitems
    save(v)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 549, in save
    self.save_reduce(obj=obj, *rv)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 659, in save_reduce
    self._batch_setitems(dictitems)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 890, in _batch_setitems
    save(v)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 549, in save
    self.save_reduce(obj=obj, *rv)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 662, in save_reduce
    save(state)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 504, in save
    f(self, obj) # Call unbound method with explicit self
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 859, in save_dict
    self._batch_setitems(obj.items())
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 885, in _batch_setitems
    save(v)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 549, in save
    self.save_reduce(obj=obj, *rv)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 662, in save_reduce
    save(state)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 504, in save
    f(self, obj) # Call unbound method with explicit self
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 859, in save_dict
    self._batch_setitems(obj.items())
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 885, in _batch_setitems
    save(v)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 504, in save
    f(self, obj) # Call unbound method with explicit self
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 819, in save_list
    self._batch_appends(obj)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 846, in _batch_appends
    save(tmp[0])
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 549, in save
    self.save_reduce(obj=obj, *rv)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 662, in save_reduce
    save(state)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 504, in save
    f(self, obj) # Call unbound method with explicit self
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 859, in save_dict
    self._batch_setitems(obj.items())
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 885, in _batch_setitems
    save(v)
  File "/home/ubuntu/anaconda3/lib/python3.7/pickle.py", line 524, in save
    rv = reduce(self.proto)
TypeError: can't pickle _thread.lock objects


What have you tried?

Tried the possibilities mentioned in the similar thread: https://github.com/PyTorchLightning/pytorch-lightning/issues/2186

Tried wrapping the logging code inside a trainer.is_global_zero check, and also trainer.global_rank == 0. Also tried decorating the method with @rank_zero_only (see the sketch below). No luck; the same error persists.
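
For reference, the attempts looked roughly like this. It is a minimal sketch of the on_train_end hook with the three guards applied, assuming the same callback shape as the autolog snippet above; MLflowModelLoggingCallback is a hypothetical name, not the exact code that was run.

import mlflow
import mlflow.pytorch
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only


class MLflowModelLoggingCallback(pl.Callback):
    """Hypothetical stand-in for the autolog callback, with rank-zero guards applied."""

    @rank_zero_only  # attempt 3: run the hook on rank 0 only
    def on_train_end(self, trainer, pl_module):
        # attempts 1 and 2: explicit rank checks around the logging call
        if trainer.is_global_zero:  # equivalently: trainer.global_rank == 0
            mlflow.set_tracking_uri(trainer.logger._tracking_uri)
            mlflow.set_experiment(trainer.logger._experiment_name)
            mlflow.start_run(trainer.logger.run_id)
            mlflow.pytorch.log_model(trainer.model, "models")
            mlflow.end_run()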

What’s your environment?

  • OS: Ubuntu
  • Packages: torch, pytorch-lightning, torchvision, mlflow

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 15 (5 by maintainers)

Most upvoted comments

I ran into this as well. I wrote a little function to iterate through model.__dict__.items() and check which attributes caused pickle errors. It looks like there's a model attribute pointing to the trainer, which has that _thread.lock on it. Maybe there's a step that's supposed to clean this up that's being missed somehow? My quick work-around was to delattr(model, "trainer") before pickling the model (a rough sketch is below), but I haven't actually tried loading the model again, so this could cause other problems.
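
The diagnostic and work-around looked roughly like this. It is a minimal sketch with hypothetical helper names (find_unpicklable_attrs, log_model_without_trainer), not the commenter's exact code.

import pickle

import mlflow.pytorch


def find_unpicklable_attrs(model):
    """Return (attribute name, error) pairs for model attributes that fail to pickle."""
    bad = []
    for name, value in model.__dict__.items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # e.g. TypeError: can't pickle _thread.lock objects
            bad.append((name, repr(exc)))
    return bad


def log_model_without_trainer(model, artifact_path="models"):
    """Work-around: drop the trainer reference before pickling.

    Caveat: re-loading the logged model has not been verified, so this may
    cause other problems.
    """
    if hasattr(model, "trainer"):
        delattr(model, "trainer")
    mlflow.pytorch.log_model(model, artifact_path)

Calling find_unpicklable_attrs(pl_module) after training should surface the trainer attribute as the one holding the _thread.lock.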

What happens if you don't use the scheduler? Please try commenting out the scheduler definition and returning only the optimizer (as in the sketch below).
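
For clarity, the suggestion amounts to replacing configure_optimizers with something like this; a minimal sketch with the ReduceLROnPlateau scheduler removed.

import torch


# Drop-in replacement for LightningMNISTClassifier.configure_optimizers,
# with the ReduceLROnPlateau scheduler removed for debugging.
def configure_optimizers(self):
    """Create and return only the Adam optimizer; no LR scheduler."""
    self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return self.optimizer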

I removed the scheduler part and re-ran the script. Still experiencing the same error.