torchdynamo: Dynamo can not optimize a model with MaxPool2d on XLA devices

🐛 Describe the bug

I found this bug while integrating dynamo with torchxla for a resnet model. If we move the model and its inputs to the XLA device before running dynamo, we hit this bug. See the minimal repro below.

cc @jansel @wconstab @jackcaog

Error logs

  File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
  File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
  File "/pytorch/torch/_refs/__init__.py", line 45, in <module>
    from torch.fx.experimental.symbolic_shapes import sym_float, sym_int
  File "/pytorch/torch/fx/experimental/symbolic_shapes.py", line 17, in <module>
    import sympy # type: ignore[import]
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/__init__.py", line 51, in <module>
    from .core import (sympify, SympifyError, cacheit, Basic, Atom,
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/__init__.py", line 4, in <module>
    from .sympify import sympify, SympifyError
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/sympify.py", line 9, in <module>
    from .compatibility import iterable
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/core/compatibility.py", line 11, in <module>
    from sympy.external import import_module
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/__init__.py", line 18, in <module>
    from sympy.external.importtools import import_module
  File "/root/anaconda3/envs/pytorch/lib/python3.7/site-packages/sympy/external/importtools.py", line 4, in <module>
    from distutils.version import LooseVersion
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 963, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 906, in _find_spec
    return fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
    r = func(*args, **kwargs)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
    frame,
  File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
    out_code = transform_code_object(code, transform)
  File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
    tracer.run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
    super().run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
    and self.step()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
    getattr(self, inst.opname)(inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
    return inner_fn(self, inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
    **options,
  File "/pytorch/torch/_dynamo/convert_frame.py", line 118, in _fn
    return fn(*args, **kwargs)
  File "/pytorch/torch/_dynamo/utils.py", line 92, in time_wrapper
    r = func(*args, **kwargs)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 356, in _convert_frame_assert
    frame,
  File "/pytorch/torch/_dynamo/convert_frame.py", line 402, in _compile
    out_code = transform_code_object(code, transform)
  File "/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/pytorch/torch/_dynamo/convert_frame.py", line 390, in transform
    tracer.run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 1468, in run
    super().run()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 352, in run
    and self.step()
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 322, in step
    getattr(self, inst.opname)(inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 174, in wrapper
    return inner_fn(self, inst)
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 766, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/pytorch/torch/_dynamo/symbolic_convert.py", line 264, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/pytorch/torch/_dynamo/variables/nn_module.py", line 209, in call_function
    **options,
  File "/pytorch/torch/_dynamo/variables/tensor.py", line 201, in create
    example_value = _get_fake_value(proxy.node, tx)
  File "/pytorch/torch/_dynamo/variables/tensor.py", line 145, in _get_fake_value
    raise TorchRuntimeError() from e
torch._dynamo.exc.TorchRuntimeError:

from user code:
  File "myscripts/repro_maxpool.py", line 14, in forward
    out = self.pool(out)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting: torch._dynamo.config.suppress_errors = True
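For reference, the two settings named in the log could be applied near the top of the repro script, roughly as sketched below. This assumes a PyTorch build that still exposes both flags on `torch._dynamo.config`, as the log above suggests. Note that `suppress_errors` only hides the failure by falling back to eager; it does not fix the MaxPool2d problem.

```python
import torch._dynamo as dynamo

# Print more detail when dynamo fails to trace a frame.
dynamo.config.verbose = True

# Fall back to eager execution instead of raising; this masks the bug
# reported here rather than fixing it.
dynamo.config.suppress_errors = True
```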

Minified repro

repro_maxpool.py

from torch import nn
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm

class MaxPoolModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        out = self.conv(x)
        out = self.pool(out)
        return out

    def get_random_inputs(self):
        return (torch.rand(2, 3, 10, 10),)

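# Move the model and its inputs to the XLA device before running dynamo.
xla_dev = xm.xla_device()
model = MaxPoolModule().to(device=xla_dev)
# Materialize the inputs as a tuple; a bare map() iterator is exhausted after one use.
inputs = tuple(x.to(device=xla_dev) for x in model.get_random_inputs())
# Compile with a trivial backend that returns the captured graph unchanged.
dynamo.optimize(lambda gm, _: gm)(lambda: model(*inputs))()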

Command:

GPU_NUM_DEVICES=1 python repro_maxpool.py

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 26 (22 by maintainers)

Most upvoted comments

Training for resnet18 now works with the aot_eager or aot_torchxla_trivial backend if MaxPool is replaced with AvgPool. aot_torchxla_trace_once still has some problems, but that's not related to this issue. Overall, replacing MaxPool with AvgPool can temporarily unblock the project until the real fix is in.
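A minimal sketch of that temporary workaround, in case it helps others: the helper below (`swap_maxpool_for_avgpool` is a hypothetical name, not something from torch or torch_xla) walks a model and replaces each `nn.MaxPool2d` with an `nn.AvgPool2d` using the same kernel size, stride, padding, and ceil mode. It runs on CPU here for illustration. Average pooling changes the numerics, and `AvgPool2d` has no `dilation` argument, so this is only meant to unblock experiments.

```python
import torch
from torch import nn


def swap_maxpool_for_avgpool(module: nn.Module) -> nn.Module:
    """Recursively replace every nn.MaxPool2d child with an nn.AvgPool2d."""
    for name, child in list(module.named_children()):
        if isinstance(child, nn.MaxPool2d):
            # AvgPool2d has no dilation/return_indices; those are dropped.
            replacement = nn.AvgPool2d(
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                ceil_mode=child.ceil_mode,
            )
            setattr(module, name, replacement)
        else:
            swap_maxpool_for_avgpool(child)
    return module


class MaxPoolModule(nn.Module):
    """Same layers as the repro module in this issue."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

    def forward(self, x):
        return self.pool(self.conv(x))


model = swap_maxpool_for_avgpool(MaxPoolModule())
out = model(torch.rand(2, 3, 10, 10))
```

The same helper could be applied to a resnet18 instance before moving it to the XLA device, since it only touches `nn.MaxPool2d` children and leaves everything else alone.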

Oh, I'm not at my laptop, but this should be fixable.

Instead of redispatching directly to the Python key, we need to make sure that we can hit any other dispatch keys that need to run first, like functionalization.