onnx-tensorrt: Gather in Upsample problem

Hi! I can't export my model from ONNX to TensorRT.

```
----------------------------------------------------------------
Input filename:   model.onnx
ONNX IR version:  0.0.4
Opset version:    9
Producer name:    pytorch
Producer version: 1.1
Domain:
Model version:    0
Doc string:
----------------------------------------------------------------
WARNING: ONNX model has a newer ir_version (0.0.4) than this parser was built against (0.0.3). Parsing model
WARNING: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
Successfully casted down to INT32.
While parsing node number 69 [Gather -> "208"]:
ERROR: /home/alex/tools/onnx-tensorrt/onnx2trt_utils.hpp:335 In function convert_axis:
[8] Assertion failed: axis >= 0 && axis < nbDims

%206 : Long() = onnx::Constant[value={2}](), scope: ResNet18_OneConvDecoder/DecoderBlock[center]/Sequential[block]/Upsample[0]
%207 : Tensor = onnx::Shape(%205), scope: ResNet18_OneConvDecoder/DecoderBlock[center]/Sequential[block]/Upsample[0]
%208 : Long() = onnx::Gather[axis=0](%207, %206), scope: ResNet18_OneConvDecoder/DecoderBlock[center]/Sequential[block]/Upsample[0]
%209 : Tensor = onnx::Constant[value={2}]()
%210 : Tensor = onnx::Mul(%208, %209)
```

About this issue

  • Original URL
  • State: closed
  • Created 5 years ago
  • Comments: 19 (2 by maintainers)

Most upvoted comments

I also had this issue with a model coming from PyTorch. Here’s an explanation of what I did to work around the problem:

This PyTorch model, when exported to ONNX, fails when importing in TensorRT because of the Gather operation:

```python
import torch.nn as nn

class ShapeModel(nn.Module):
    def __init__(self):
        super(ShapeModel, self).__init__()

    def forward(self, x):
        return x.shape
```

When the exported graph is imported into TensorRT, it fails with `Assertion failed: axis >= 0 && axis < nbDims`.

However, this model works:

```python
import torch
import torch.nn as nn

class ShapeModel(nn.Module):
    def __init__(self):
        super(ShapeModel, self).__init__()

    def forward(self, x):
        return torch.tensor(x.shape)
```

with just a warning during export to ONNX that the trace might not generalize to other inputs.
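The difference between the two models can be sanity-checked outside of tracing (a minimal check, not from the original thread): `x.shape` is a `torch.Size`, i.e. a tuple of Python ints, while `torch.tensor(x.shape)` materializes the shape as a regular tensor, which is what lets the tracer record it without emitting Shape/Gather ops.

```python
import torch

x = torch.randn(1, 3, 8, 8)

# x.shape is a torch.Size, which is a tuple subclass of Python ints
assert isinstance(x.shape, tuple)

# torch.tensor(x.shape) is an ordinary 1-D integer tensor
sh = torch.tensor(x.shape)
assert isinstance(sh, torch.Tensor)
assert sh.tolist() == [1, 3, 8, 8]
```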

A single PyTorch upsampling by a factor of 2 gets traced like this:

```python
import torch.nn as nn
import torch.nn.functional as F

class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=(2, 2), mode='nearest')
```

The traced graph does a lot of work to determine the desired size of the output tensor, and this is where the Gather op appears.

PyTorch's interpolate function will also work if you supply the already-known output size of your tensor instead of the upsampling factor. Below is an upsampler for (batch_size x channels x H x W) tensors:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResizeModel(nn.Module):
    def __init__(self):
        super(ResizeModel, self).__init__()

    def forward(self, x):
        sh = torch.tensor(x.shape)
        return F.interpolate(x, size=(sh[2] * 2, sh[3] * 2), mode='nearest')
```

This version gets traced to ONNX without the Gather op, and the resulting graph imports successfully into TensorRT.
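As a quick eager-mode sanity check (my addition, not from the thread), passing the output size as plain ints produces exactly the same result as the scale_factor version for nearest-neighbor upsampling, so the workaround is numerically a no-op:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)

# Baseline: upsample by a factor of 2 (the version that traces to Shape/Gather)
a = F.interpolate(x, scale_factor=(2, 2), mode='nearest')

# Workaround: supply the already-known output size instead
h, w = x.shape[2], x.shape[3]
b = F.interpolate(x, size=(h * 2, w * 2), mode='nearest')

assert a.shape == (1, 3, 16, 16)
assert torch.equal(a, b)  # nearest-neighbor outputs match exactly
```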

@jinfagang try

```python
up3 = F.interpolate(output3, size=(int(output2.size(2)), int(output2.size(3))), mode="nearest")
```
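A sketch of how that call behaves in eager mode. `output2` and `output3` here are stand-in tensors I made up for illustration, assuming an FPN-style setup where `output2` has twice the spatial resolution of `output3`; the `int(...)` casts are what make the sizes trace as constants rather than Gather ops.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for the feature maps in the snippet above
output3 = torch.randn(1, 256, 10, 10)
output2 = torch.randn(1, 256, 20, 20)

# int(...) turns the traced sizes into Python ints, i.e. constants in the graph
up3 = F.interpolate(output3, size=(int(output2.size(2)), int(output2.size(3))), mode="nearest")
assert up3.shape == (1, 256, 20, 20)
```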