pytorch-lightning: Cannot use `torch.jit.trace` to trace `LightningModule` in Lightning v1.7

๐Ÿ› Bug

When I use torch.jit.trace to trace a LightningModule, I get `RuntimeError: XXX is not attached to a Trainer.`, where XXX is the LightningModule class name.

This happens because, as of Lightning 1.7.0, the `trainer` property raises a RuntimeError if the module is not attached to a Trainer:

https://github.com/Lightning-AI/lightning/blob/12a061f2aaefaa9ed9ccf81ab6f378835b675a7e/src/pytorch_lightning/core/module.py#L179

However, torch.jit checks each attribute with `hasattr`:

https://github.com/pytorch/pytorch/blob/de0e03001d31523ef86c3d7852c87cdad6d96632/torch/_jit_internal.py#L749

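and, according to its docstring, `hasattr` returns "whether the object has an attribute with the given name. This is done by calling getattr(obj, name) and catching AttributeError." Because only AttributeError is caught, the RuntimeError raised by the `trainer` property escapes `hasattr` and aborts tracing.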
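The `hasattr` behavior can be checked in isolation with plain Python. The class `Probe` below is a hypothetical illustration, not Lightning code: one property raises AttributeError, the other raises RuntimeError.

```python
class Probe:
    """Hypothetical class with two properties that fail in different ways."""

    @property
    def raises_attribute_error(self):
        raise AttributeError("not available")

    @property
    def raises_runtime_error(self):
        raise RuntimeError("not attached to a Trainer")


probe = Probe()

# hasattr() swallows AttributeError and simply reports False.
print(hasattr(probe, "raises_attribute_error"))

# Any other exception escapes hasattr() and reaches the caller.
try:
    hasattr(probe, "raises_runtime_error")
except RuntimeError as err:
    print(f"hasattr raised: {err}")
```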

To Reproduce

Initialize any LightningModule under Lightning v1.7.0 and trace it with torch.jit.trace without attaching a Trainer to it.
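For readers without these exact versions installed, the failure mode can be sketched in plain Python. This is an approximation under stated assumptions: `FakeLightningModule` is a hypothetical stand-in that mimics the 1.7.0 `trainer` property described above, and `scan_attributes` is a hypothetical stand-in for the attribute probing in torch/_jit_internal.py. Neither is the real code.

```python
class FakeLightningModule:
    """Hypothetical stand-in for a LightningModule under Lightning 1.7.0."""

    def __init__(self):
        self._trainer = None  # no Trainer attached

    @property
    def trainer(self):
        # Mimics the behavior described in the report: RuntimeError when unattached.
        if self._trainer is None:
            raise RuntimeError(
                f"{type(self).__qualname__} is not attached to a `Trainer`."
            )
        return self._trainer


def scan_attributes(obj):
    """Hypothetical tracer step: probe every public class attribute with hasattr()."""
    found = []
    for name in dir(type(obj)):
        if name.startswith("_"):
            continue
        if hasattr(obj, name):  # lets the RuntimeError from `trainer` escape
            found.append(name)
    return found


try:
    scan_attributes(FakeLightningModule())
except RuntimeError as err:
    print(f"tracing would fail: {err}")
```

Once `_trainer` is set to any object, the same scan completes and reports `["trainer"]`.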

To Fix

Replace the RuntimeError with an AttributeError.
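A minimal sketch of the proposed change, again on a hypothetical `PatchedModule` rather than the real LightningModule: the property raises AttributeError when no Trainer is attached, so `hasattr` and `getattr` with a default treat it as absent instead of aborting the caller.

```python
class PatchedModule:
    """Hypothetical module whose trainer property raises AttributeError instead."""

    def __init__(self):
        self._trainer = None

    @property
    def trainer(self):
        if self._trainer is None:
            # AttributeError lets hasattr()/getattr(obj, name, default) report
            # the attribute as missing rather than propagating an error.
            raise AttributeError(
                f"{type(self).__qualname__} is not attached to a `Trainer`."
            )
        return self._trainer


module = PatchedModule()
print(hasattr(module, "trainer"))
print(getattr(module, "trainer", "no trainer"))
```

One side effect worth checking before adopting this: when a property getter raises AttributeError on a class that also defines `__getattr__` (torch.nn.Module does), Python falls back to `__getattr__`, so a direct `model.trainer` access would likely show nn.Module's generic "object has no attribute" message instead of the custom one. Code that catches RuntimeError from `.trainer` would also need updating.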

This fix works for me, but I don't know whether it causes other problems.

Environment

  • Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow): LightningModule
  • PyTorch Lightning Version (e.g., 1.5.0): 1.7.0
  • PyTorch Version (e.g., 1.10): 1.10
  • Python version (e.g., 3.9): 3.7.12

cc @carmocca @justusschock @awaelchli @borda @ananthsub @ninginthecloud @jjenniferdai @rohitgr7

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 15 (5 by maintainers)

Most upvoted comments

UPDATE: Solved in 1.7.7

The same happens with me.

I am still having this issue. I am using the Lightning Flash ObjectDetector with a YOLOv5 backbone, and neither scripting nor tracing the model works. The error says:

ModelAdapter is not attached to a Trainer.

Hi @carmocca, I am getting the same error when using torch.jit.trace to trace a LightningModule. Unfortunately, I cannot use the above workarounds, because torch.jit.trace is called internally by a library I am using with my LightningModule. Do you have any suggestions for making torch.jit.trace work with a LightningModule?

FYI, I am using pytorch-lightning version 1.7.3

I am also still seeing this error on the most recent versions: pytorch-lightning 1.8.3.post0, PyTorch 1.13.0.

My workaround is to give it a dummy trainer: `model._trainer = pl.Trainer()`
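The dummy-trainer workaround can be wrapped in a context manager so the dummy is detached again after tracing. This is a sketch under assumptions: `temporarily_attached` and `StubModule` are hypothetical names, and the stub only mimics the 1.7.0 property. With Lightning, the dummy would be `pl.Trainer()` and the body would call torch.jit.trace.

```python
from contextlib import contextmanager


@contextmanager
def temporarily_attached(module, trainer):
    """Hypothetical helper: set module._trainer for the block, then restore it."""
    previous = module._trainer
    module._trainer = trainer
    try:
        yield module
    finally:
        module._trainer = previous  # restored even if tracing raises


class StubModule:
    """Hypothetical stand-in mimicking the 1.7.0 trainer property."""

    def __init__(self):
        self._trainer = None

    @property
    def trainer(self):
        if self._trainer is None:
            raise RuntimeError("StubModule is not attached to a `Trainer`.")
        return self._trainer


model = StubModule()
with temporarily_attached(model, trainer="dummy trainer"):
    # With Lightning, torch.jit.trace(model, example_inputs) would run here.
    print(hasattr(model, "trainer"))
print(model._trainer)
```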

Can we just create a dummy object and attach it to the LightningModule? @carmocca

@laclouis5 No, and most likely never, since PyTorch is no longer actively developing TorchScript after the release of torch.compile.

Hi! Unfortunately, this is caused by a bug in PyTorch where properties are not correctly ignored: https://github.com/pytorch/pytorch/issues/67146

As a workaround, you can use model.to_torchscript(method="trace")
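Our reading (an assumption, not confirmed in this thread) is that `to_torchscript` avoids the error by letting the `trainer` property know that a JIT export is in progress, so it returns None instead of raising. The sketch below mimics that idea with a hypothetical class-level flag `_exporting` and a hypothetical helper `jit_export_mode`; it is not Lightning's implementation.

```python
from contextlib import contextmanager


class SketchModule:
    """Hypothetical stand-in: the property consults a class-level export flag."""

    _exporting = False  # hypothetical flag, True only while a JIT export runs

    def __init__(self):
        self._trainer = None

    @property
    def trainer(self):
        if not self._exporting and self._trainer is None:
            raise RuntimeError("SketchModule is not attached to a `Trainer`.")
        return self._trainer


@contextmanager
def jit_export_mode(cls):
    """Hypothetical helper: set the flag on the class, restore it afterwards."""
    previous = cls._exporting
    cls._exporting = True
    try:
        yield
    finally:
        cls._exporting = previous


model = SketchModule()
with jit_export_mode(SketchModule):
    # An exporter would call torch.jit.trace(...) here; the property returns None.
    print(hasattr(model, "trainer"))
print(SketchModule._exporting)
```

With Lightning itself, the call is the one-liner above; with `method="trace"`, `to_torchscript` also accepts `example_inputs` and otherwise falls back to the module's `example_input_array`.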