pytorch-lightning: AttributeError: '_ResultMetric' object has no attribute '_forward_pre_hooks'

Bug description

Using self.log with torch.compile results in a failure. The patch from https://github.com/pytorch/pytorch/pull/103621 was applied to work around the “failed to reach fixed point” error on Python 3.8.

PyTorch Lightning 2.0.5 was used for the experiment.

What version are you seeing the problem on?

v2.0

How to reproduce the bug

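import torch
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


def test_compiled_model_to_log_metric_with_cpu(tmp_path):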
    class MyModel(BoringModel):
        def training_step(self, batch, batch_idx):
            loss = self.step(batch)
            self.log("loss", loss)
            return loss

    model = MyModel()
    compiled_model = torch.compile(model)

    trainer = Trainer(
        default_root_dir=tmp_path,
        accelerator="cpu",
        fast_dev_run=True,
        devices=1,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer.fit(compiled_model)

    assert set(trainer.callback_metrics) == {"loss"}

Error messages and logs

>       raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))
E       AttributeError: '_ResultMetric' object has no attribute '_forward_pre_hooks'
E
E       from user code:
E          File "/home/janand/anaconda3/envs/pylight/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 188, in __init__
E           super().__init__()
E
E       Set torch._dynamo.config.verbose=True for more information
E
E
E       You can suppress this exception and fall back to eager by setting:
E           torch._dynamo.config.suppress_errors = True

../../../../anaconda3/envs/pylight/lib/python3.8/site-packages/torch/nn/modules/module.py:1617: AttributeError




cc @carmocca @Blaizzy

About this issue

  • State: closed
  • Created a year ago
  • Comments: 15 (15 by maintainers)

Most upvoted comments

I found the difference and extracted a raw PyTorch example:

import torch
import torch.nn as nn

class BinaryStatScores(nn.Module):
    def __init__(self):
        super().__init__()


class Accuracy(nn.Module):
    def __new__(cls):
        return BinaryStatScores()


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        y = Accuracy()  # instantiate a module mid-forward, as Lightning does when self.log() is called
        return self.layer(x)


def overwrite_torch_functions():
    module_set_attr_orig = torch.nn.Module.__setattr__

    def wrap_set_attr(self, name, value):
        if isinstance(value, torch.nn.Module):
            print(value)  # <-- calls `__repr__` on the module
        module_set_attr_orig(self, name, value)

    torch.nn.Module.__setattr__ = wrap_set_attr

overwrite_torch_functions()

model = Model()
model = torch.compile(model)
model(torch.rand(2, 2))

Tested with torch==2.1.0.dev20230818+cu118.

We can report this to PyTorch and see whether it can be addressed. However, patching __setattr__ on all nn.Modules is quite an exotic use case and probably not future-proof.
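The Accuracy class in the example above is a factory: its __new__ returns an instance of a different class (in torchmetrics, Accuracy.__new__ returns a BinaryAccuracy). The plain-Python sketch below, with no PyTorch involved, illustrates the language rule at play: when __new__ returns an object that is not an instance of the class being called, Python skips that class's __init__, so all initialization comes from the returned object's own constructor. The class names mirror the example; the init_calls counter is a hypothetical addition for illustration only.

```python
class BinaryStatScores:
    def __init__(self):
        self.initialized = True


class Accuracy:
    init_calls = 0  # hypothetical counter: how often Accuracy.__init__ actually runs

    def __new__(cls):
        # Factory pattern: hand back an instance of a *different* class.
        return BinaryStatScores()

    def __init__(self):
        # Skipped: Python only calls cls.__init__ when __new__ returns an instance of cls.
        Accuracy.init_calls += 1


obj = Accuracy()
print(type(obj).__name__)  # BinaryStatScores
print(Accuracy.init_calls)  # 0
```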

@jerome-habana My recommendation is to make your setattr override explicitly compatible with torch.compile by making it aware of Dynamo's internal root module (FakeRootModule). For example, you could handle it like this in your override:

if isinstance(value, torch.nn.Module) and not isinstance(self, torch._dynamo.output_graph.FakeRootModule):
    print(value)

This way, print is skipped whenever setattr is called on Dynamo's FakeRootModule, so repr is never invoked on the half-constructed module that caused the error.
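Because the guard relies on Dynamo internals, it may help to see the idea in isolation. The sketch below is plain Python with hypothetical stand-ins, not the real torch API: Module loosely mimics torch.nn.Module (its __repr__ reads _modules, which only exists after __init__ runs, and its __getattr__ raises AttributeError for unknown names), and FakeRoot plays the role of Dynamo's FakeRootModule. To make the result checkable, it collects repr strings in a list instead of printing them.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __init__(self):
        # _modules only exists once __init__ has run.
        object.__setattr__(self, "_modules", {})

    def __getattr__(self, name):
        # Called only when normal lookup fails, similar to nn.Module.__getattr__.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self):
        return f"{type(self).__name__}({', '.join(self._modules)})"


class FakeRoot(Module):
    """Hypothetical stand-in for Dynamo's internal FakeRootModule."""


class Metric(Module):
    pass


printed = []


def guarded_setattr(self, name, value):
    # Skip the repr when the target is the internal root module, whose
    # children may still be half-constructed.
    if isinstance(value, Module) and not isinstance(self, FakeRoot):
        printed.append(repr(value))
    object.__setattr__(self, name, value)


Module.__setattr__ = guarded_setattr

root = FakeRoot()
root.metric = Metric.__new__(Metric)  # allocated, __init__ never ran: repr would fail
user_module = Module()
user_module.metric = Metric()  # fully constructed: repr is safe
print(printed)  # ['Metric()']
```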

I was able to reduce it to a torchmetrics-only example that raises an equivalent error. This is the same error we are seeing with Lightning's _ResultMetric, which inherits from torchmetrics' Metric class:

import torch
import torch.nn as nn
from torchmetrics import Accuracy


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 2)

    def forward(self, x):
        y = Accuracy(task="binary")  # instantiate a metric mid-forward, as Lightning does when self.log() is called
        return self.layer(x)


def overwrite_torch_functions():
    module_set_attr_orig = torch.nn.Module.__setattr__

    def wrap_set_attr(self, name, value):
        if isinstance(value, torch.nn.Module):
            print(value)  # <-- calls `__repr__` on the module
        module_set_attr_orig(self, name, value)

    torch.nn.Module.__setattr__ = wrap_set_attr

overwrite_torch_functions()

model = Model()
model = torch.compile(model)
model(torch.rand(2, 2))

Traceback:

Traceback (most recent call last):
  File "/home/adrian/repositories/lightning/examples/habana_compile2.py", line 30, in <module>
    model(torch.rand(2, 2))
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 333, in _fn
    return fn(*args, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/adrian/repositories/lightning/examples/habana_compile2.py", line 12, in forward
    y = Accuracy(task="binary")
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torchmetrics/classification/accuracy.py", line 491, in __new__
    return BinaryAccuracy(threshold, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 493, in catch_errors
    return callback(frame, cache_size, hooks, frame_state)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 624, in _convert_frame
    result = inner_convert(frame, cache_size, hooks, frame_state)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 132, in _fn
    return fn(*args, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 370, in _convert_frame_assert
    return _compile(
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 571, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(e.__traceback__) from None
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 554, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 180, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 465, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 432, in transform
    tracer.run()
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2071, in run
    super().run()
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 724, in run
    and self.step()
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 688, in step
    getattr(self, inst.opname)(inst)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 439, in wrapper
    self.output.compile_subgraph(self, reason=reason)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 775, in compile_subgraph
    root = FakeRootModule(self.nn_modules)
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 169, in __init__
    setattr(self, k, v)
  File "/home/adrian/repositories/lightning/examples/habana_compile2.py", line 21, in wrap_set_attr
    print(value)  # <-- calls `__repr__` on the module
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2493, in __repr__
    for key, module in self._modules.items():
  File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
torch._dynamo.exc.InternalTorchDynamoError: 'BinaryAccuracy' object has no attribute '_modules'

from user code:
   File "/home/adrian/.conda/envs/lightning/lib/python3.10/site-packages/torchmetrics/classification/stat_scores.py", line 161, in __init__
    super(_AbstractStatScores, self).__init__(**kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Instantiating the metric in the middle of forward() simulates what happens in Lightning when self.log() is called.

During execution, the compiled code calls setattr while that module is still being created: when Dynamo compiles the subgraph, it builds a FakeRootModule and calls setattr on it for each tracked module (see output_graph.py in the traceback). That call goes through the patched function, where print(value) is evaluated. This in turn calls repr(value), which accesses attributes of the object (here _modules) that don’t exist yet, because the object is still being constructed.

This recursive dependency is the problem, but I don’t yet know how to solve it. The next step is to reduce this even further to a raw PyTorch example and then check whether PyTorch can address this use case.
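The chain of calls described above can be reproduced in plain Python. In this sketch, Module and BinaryAccuracy are hypothetical stand-ins rather than PyTorch or torchmetrics code: Module's __repr__ reads _modules, which only its __init__ sets, and its __getattr__ raises AttributeError for missing names. A global setattr patch with the same shape as the one in the issue then fails as soon as a module whose __init__ has not completed is assigned, mirroring how FakeRootModule's setattr receives the half-built BinaryAccuracy in the traceback.

```python
class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __init__(self):
        # _modules only exists once __init__ has run.
        object.__setattr__(self, "_modules", {})

    def __getattr__(self, name):
        # Called only when normal lookup fails, similar to nn.Module.__getattr__.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self):
        return f"{type(self).__name__}({', '.join(self._modules)})"


class BinaryAccuracy(Module):
    pass


def wrap_set_attr(self, name, value):
    # Same shape as the patch in the issue: print (and thus repr) every assigned module.
    if isinstance(value, Module):
        print(value)
    object.__setattr__(self, name, value)


Module.__setattr__ = wrap_set_attr

root = Module()
half_built = BinaryAccuracy.__new__(BinaryAccuracy)  # Module.__init__ has not run yet
try:
    root.metric = half_built
    error = None
except AttributeError as exc:
    error = str(exc)

print(error)  # 'BinaryAccuracy' object has no attribute '_modules'
```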

cc @carmocca for visibility

@Borda @carmocca Any updates on this issue?