onnxscript: Unsupported FX Nodes: {'call_function': ['aten.roll.default', 'aten.var.correction']}

Hello,

First of all, sorry for this post; I'm still kind of lost on how ONNX opset 18 works and how TorchDynamo exports a model to an ONNX protobuf. I trained my model and am now trying to export it to ONNX. Using torch.export.export, I can generate an ExportedProgram with the following signature:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1, 1, 3, 3], arg1_1: f32[1, 1, 3, 3], arg2_1: f32[1, 1, 3, 3], arg3_1: f32[1, 1, 3, 3], arg4_1: f32[1, 1, 9, 9], arg5_1: f32[1, 1, 9, 9], arg6_1: f32[64, 3, 7, 7], arg7_1: f32[64], arg8_1: f32[64], arg9_1: f32[64, 64, 3, 3], arg10_1: f32[64], arg11_1: f32[64], arg12_1: f32[64, 64, 3, 3], arg13_1: f32[64], arg14_1: f32[64], arg15_1: f32[64, 64, 3, 3], arg16_1: f32[64], arg17_1: f32[64], arg18_1: f32[64, 64, 3, 3], arg19_1: f32[64], arg20_1: f32[64], arg21_1: f32[128, 64, 3, 3], arg22_1: f32[128], arg23_1: f32[128], arg24_1: f32[128, 128, 3, 3], arg25_1: f32[128], arg26_1: f32[128], arg27_1: f32[128, 64, 1, 1], arg28_1: f32[128], arg29_1: f32[128], arg30_1: f32[128, 128, 3, 3], arg31_1: f32[128], arg32_1: f32[128], arg33_1: f32[128, 128, 3, 3], arg34_1: f32[128], arg35_1: f32[128], arg36_1: f32[256, 128, 3, 3], arg37_1: f32[256], arg38_1: f32[256], arg39_1: f32[256, 256, 3, 3], arg40_1: f32[256], arg41_1: f32[256], arg42_1: f32[256, 128, 1, 1], arg43_1: f32[256], arg44_1: f32[256], arg45_1: f32[256, 256, 3, 3], arg46_1: f32[256], arg47_1: f32[256], arg48_1: f32[256, 256, 3, 3], arg49_1: f32[256], arg50_1: f32[256], arg51_1: f32[512, 256, 3, 3], arg52_1: f32[512], arg53_1: f32[512], arg54_1: f32[512, 512, 3, 3], arg55_1: f32[512], arg56_1: f32[512], arg57_1: f32[512, 256, 1, 1], arg58_1: f32[512], arg59_1: f32[512], arg60_1: f32[512, 512, 3, 3], arg61_1: f32[512], arg62_1: f32[512], arg63_1: f32[512, 512, 3, 3], arg64_1: f32[512], arg65_1: f32[512], arg66_1: f32[512, 256, 2, 2], arg67_1: f32[256], arg68_1: f32[256, 512, 3, 3], arg69_1: f32[256], arg70_1: f32[256], arg71_1: f32[256, 256, 3, 3], arg72_1: f32[256], arg73_1: f32[256], arg74_1: f32[256, 128, 2, 2], arg75_1: f32[128], arg76_1: f32[128, 256, 3, 3], arg77_1: f32[128], arg78_1: f32[128], arg79_1: f32[128, 128, 3, 3], arg80_1: f32[128], arg81_1: f32[128], arg82_1: f32[128, 64, 2, 2], arg83_1: f32[64], arg84_1: f32[64, 128, 3, 3], arg85_1: f32[64], arg86_1: f32[64], arg87_1: 
f32[64, 64, 3, 3], arg88_1: f32[64], arg89_1: f32[64], arg90_1: f32[134, 64, 1, 1], arg91_1: f32[134], arg92_1: f32[64], arg93_1: f32[64], arg94_1: i64[], arg95_1: f32[64], arg96_1: f32[64], arg97_1: i64[], arg98_1: f32[64], arg99_1: f32[64], arg100_1: i64[], arg101_1: f32[64], arg102_1: f32[64], arg103_1: i64[], arg104_1: f32[64], arg105_1: f32[64], arg106_1: i64[], arg107_1: f32[128], arg108_1: f32[128], arg109_1: i64[], arg110_1: f32[128], arg111_1: f32[128], arg112_1: i64[], arg113_1: f32[128], arg114_1: f32[128], arg115_1: i64[], arg116_1: f32[128], arg117_1: f32[128], arg118_1: i64[], arg119_1: f32[128], arg120_1: f32[128], arg121_1: i64[], arg122_1: f32[256], arg123_1: f32[256], arg124_1: i64[], arg125_1: f32[256], arg126_1: f32[256], arg127_1: i64[], arg128_1: f32[256], arg129_1: f32[256], arg130_1: i64[], arg131_1: f32[256], arg132_1: f32[256], arg133_1: i64[], arg134_1: f32[256], arg135_1: f32[256], arg136_1: i64[], arg137_1: f32[512], arg138_1: f32[512], arg139_1: i64[], arg140_1: f32[512], arg141_1: f32[512], arg142_1: i64[], arg143_1: f32[512], arg144_1: f32[512], arg145_1: i64[], arg146_1: f32[512], arg147_1: f32[512], arg148_1: i64[], arg149_1: f32[512], arg150_1: f32[512], arg151_1: i64[], arg152_1: f32[256], arg153_1: f32[256], arg154_1: i64[], arg155_1: f32[256], arg156_1: f32[256], arg157_1: i64[], arg158_1: f32[128], arg159_1: f32[128], arg160_1: i64[], arg161_1: f32[128], arg162_1: f32[128], arg163_1: i64[], arg164_1: f32[64], arg165_1: f32[64], arg166_1: i64[], arg167_1: f32[64], arg168_1: f32[64], arg169_1: i64[], arg170_1: f32[1, 1, 512, 512]):
            # 
            arange: i64[512] = torch.ops.aten.arange.start_step(0, 512, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt: b8[512] = torch.ops.aten.lt.Scalar(arange, 256.0)
            _to_copy: f32[512] = torch.ops.aten._to_copy.default(arange, dtype = torch.float32)
            mul: f32[512] = torch.ops.aten.mul.Tensor(_to_copy, 0.0019569471624266144);  _to_copy = None
            add: f32[512] = torch.ops.aten.add.Tensor(mul, -0.5);  mul = None
            sub: i64[512] = torch.ops.aten.sub.Tensor(511, arange);  arange = None
            _to_copy_1: f32[512] = torch.ops.aten._to_copy.default(sub, dtype = torch.float32);  sub = None
            mul_1: f32[512] = torch.ops.aten.mul.Tensor(_to_copy_1, 0.0019569471624266144);  _to_copy_1 = None
            sub_1: f32[512] = torch.ops.aten.sub.Tensor(0.5, mul_1);  mul_1 = None
            where: f32[512] = torch.ops.aten.where.self(lt, add, sub_1);  lt = add = sub_1 = None
            arange_1: i64[512] = torch.ops.aten.arange.start_step(0, 512, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_1: b8[512] = torch.ops.aten.lt.Scalar(arange_1, 256.0)
            _to_copy_2: f32[512] = torch.ops.aten._to_copy.default(arange_1, dtype = torch.float32)
            mul_2: f32[512] = torch.ops.aten.mul.Tensor(_to_copy_2, 0.0019569471624266144);  _to_copy_2 = None
            add_1: f32[512] = torch.ops.aten.add.Tensor(mul_2, -0.5);  mul_2 = None
            sub_2: i64[512] = torch.ops.aten.sub.Tensor(511, arange_1);  arange_1 = None
            _to_copy_3: f32[512] = torch.ops.aten._to_copy.default(sub_2, dtype = torch.float32);  sub_2 = None
            mul_3: f32[512] = torch.ops.aten.mul.Tensor(_to_copy_3, 0.0019569471624266144);  _to_copy_3 = None
            sub_3: f32[512] = torch.ops.aten.sub.Tensor(0.5, mul_3);  mul_3 = None
            where_1: f32[512] = torch.ops.aten.where.self(lt_1, add_1, sub_3);  lt_1 = add_1 = sub_3 = None
            view: f32[512, 1] = torch.ops.aten.view.default(where, [-1, 1]);  where = None
            expand: f32[512, 512] = torch.ops.aten.expand.default(view, [512, 512]);  view = None
            view_1: f32[1, 512] = torch.ops.aten.view.default(where_1, [1, -1]);  where_1 = None
            expand_1: f32[512, 512] = torch.ops.aten.expand.default(view_1, [512, 512]);  view_1 = None
            pow_1: f32[512, 512] = torch.ops.aten.pow.Tensor_Scalar(expand_1, 2);  expand_1 = None
            pow_2: f32[512, 512] = torch.ops.aten.pow.Tensor_Scalar(expand, 2);  expand = None
            add_2: f32[512, 512] = torch.ops.aten.add.Tensor(pow_1, pow_2);  pow_1 = pow_2 = None
            sqrt: f32[512, 512] = torch.ops.aten.sqrt.default(add_2);  add_2 = None
            add_3: f32[512, 512] = torch.ops.aten.add.Tensor(sqrt, 1e-06);  sqrt = None
            mul_4: f32[512, 512] = torch.ops.aten.mul.Tensor(add_3, 6.283185307179586);  add_3 = None
            mul_5: f32[512, 512] = torch.ops.aten.mul.Tensor(mul_4, 2.5);  mul_4 = None
            pow_3: f32[512, 512] = torch.ops.aten.pow.Tensor_Scalar(mul_5, 4);  mul_5 = None
            add_4: f32[512, 512] = torch.ops.aten.add.Tensor(pow_3, 1);  pow_3 = None
            reciprocal: f32[512, 512] = torch.ops.aten.reciprocal.default(add_4);  add_4 = None
            mul_6: f32[512, 512] = torch.ops.aten.mul.Tensor(reciprocal, 1.0);  reciprocal = None
            unsqueeze: f32[1, 512, 512] = torch.ops.aten.unsqueeze.default(mul_6, 0);  mul_6 = None
            unsqueeze_1: f32[1, 1, 512, 512] = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            convolution: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(arg170_1, arg0_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
            convolution_1: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(arg170_1, arg1_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
            pow_4: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution, 2);  convolution = None
            pow_5: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_1, 2);  convolution_1 = None
            add_5: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(pow_4, pow_5);  pow_4 = pow_5 = None
            sqrt_1: f32[1, 1, 512, 512] = torch.ops.aten.sqrt.default(add_5);  add_5 = None
            add_6: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sqrt_1, 1e-06);  sqrt_1 = None
            _to_copy_4: c64[1, 1, 512, 512] = torch.ops.aten._to_copy.default(add_6, dtype = torch.complex64);  add_6 = None
            _fft_c2c: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(_to_copy_4, [2, 3], 0, True);  _to_copy_4 = None
            roll: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(_fft_c2c, [256, 256], [2, 3]);  _fft_c2c = None
            mul_7: c64[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(roll, unsqueeze_1);  roll = None
            roll_1: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(mul_7, [256, 256], [2, 3]);  mul_7 = None
            _fft_c2c_1: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(roll_1, [2, 3], 2, False);  roll_1 = None
            view_as_real: f32[1, 1, 512, 512, 2] = torch.ops.aten.view_as_real.default(_fft_c2c_1);  _fft_c2c_1 = None
            select: f32[1, 1, 512, 512] = torch.ops.aten.select.int(view_as_real, 4, 0);  view_as_real = None
            _to_copy_5: c64[1, 1, 512, 512] = torch.ops.aten._to_copy.default(arg170_1, dtype = torch.complex64)
            _fft_c2c_2: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(_to_copy_5, [2, 3], 0, True);  _to_copy_5 = None
            roll_2: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(_fft_c2c_2, [256, 256], [2, 3]);  _fft_c2c_2 = None
            mul_8: c64[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(roll_2, unsqueeze_1);  roll_2 = None
            roll_3: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(mul_8, [256, 256], [2, 3]);  mul_8 = None
            _fft_c2c_3: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(roll_3, [2, 3], 2, False);  roll_3 = None
            view_as_real_1: f32[1, 1, 512, 512, 2] = torch.ops.aten.view_as_real.default(_fft_c2c_3);  _fft_c2c_3 = None
            select_1: f32[1, 1, 512, 512] = torch.ops.aten.select.int(view_as_real_1, 4, 0);  view_as_real_1 = None
            convolution_2: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(select_1, arg0_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg0_1 = None
            convolution_3: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(select_1, arg1_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg1_1 = None
            pow_6: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_2, 2);  convolution_2 = None
            pow_7: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_3, 2);  convolution_3 = None
            add_7: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(pow_6, pow_7);  pow_6 = pow_7 = None
            sqrt_2: f32[1, 1, 512, 512] = torch.ops.aten.sqrt.default(add_7);  add_7 = None
            add_8: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sqrt_2, 1e-06);  sqrt_2 = None
            _to_copy_6: c64[1, 1, 512, 512] = torch.ops.aten._to_copy.default(add_8, dtype = torch.complex64);  add_8 = None
            _fft_c2c_4: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(_to_copy_6, [2, 3], 0, True);  _to_copy_6 = None
            roll_4: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(_fft_c2c_4, [256, 256], [2, 3]);  _fft_c2c_4 = None
            mul_9: c64[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(roll_4, unsqueeze_1);  roll_4 = unsqueeze_1 = None
            roll_5: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(mul_9, [256, 256], [2, 3]);  mul_9 = None
            _fft_c2c_5: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(roll_5, [2, 3], 2, False);  roll_5 = None
            view_as_real_2: f32[1, 1, 512, 512, 2] = torch.ops.aten.view_as_real.default(_fft_c2c_5);  _fft_c2c_5 = None
            select_2: f32[1, 1, 512, 512] = torch.ops.aten.select.int(view_as_real_2, 4, 0);  view_as_real_2 = None
            sub_4: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(select, select_2);  select_2 = None
            abs_1: f32[1, 1, 512, 512] = torch.ops.aten.abs.default(select);  select = None
            gt: b8[1, 1, 512, 512] = torch.ops.aten.gt.Scalar(abs_1, 1)
            clamp_min: f32[1, 1, 512, 512] = torch.ops.aten.clamp_min.default(abs_1, 1e-06);  abs_1 = None
            div: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(sub_4, clamp_min);  clamp_min = None
            full_like: f32[1, 1, 512, 512] = torch.ops.aten.full_like.default(sub_4, 0, pin_memory = False, memory_format = torch.preserve_format);  sub_4 = None
            where_2: f32[1, 1, 512, 512] = torch.ops.aten.where.self(gt, div, full_like);  gt = div = full_like = None
            sub_5: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(where_2, 0.3);  where_2 = None
            div_1: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(sub_5, 0.39999999999999997);  sub_5 = None
            clamp: f32[1, 1, 512, 512] = torch.ops.aten.clamp.default(div_1, 0, 1);  div_1 = None
            mul_10: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(clamp, select_1);  select_1 = None
            sub_6: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(1, clamp);  clamp = None
            mul_11: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(sub_6, arg170_1);  sub_6 = None
            add_9: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(mul_10, mul_11);  mul_10 = mul_11 = None
            sub_7: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(arg170_1, add_9);  arg170_1 = add_9 = None
            add_10: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sub_7, 20);  sub_7 = None
            mul_12: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(add_10, 255);  add_10 = None
            div_2: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(mul_12, 40);  mul_12 = None
            clamp_1: f32[1, 1, 512, 512] = torch.ops.aten.clamp.default(div_2, 0, 255);  div_2 = None
            mean: f32[1, 1, 1, 1] = torch.ops.aten.mean.dim(clamp_1, [1, 2, 3], True)
            var: f32[1, 1, 1, 1] = torch.ops.aten.var.correction(clamp_1, [1, 2, 3], correction = 1, keepdim = True)
            sub_8: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(clamp_1, mean)
            pow_8: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(sub_8, 2);  sub_8 = None
            mul_13: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(pow_8, 1);  pow_8 = None
            clamp_min_1: f32[1, 1, 1, 1] = torch.ops.aten.clamp_min.default(var, 1e-06);  var = None
            div_3: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(mul_13, clamp_min_1);  mul_13 = clamp_min_1 = None
            sqrt_3: f32[1, 1, 512, 512] = torch.ops.aten.sqrt.default(div_3);  div_3 = None
            gt_1: b8[1, 1, 512, 512] = torch.ops.aten.gt.Tensor(clamp_1, mean);  mean = None
            add_11: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sqrt_3, 0)
            sub_9: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(0, sqrt_3);  sqrt_3 = None
            where_3: f32[1, 1, 512, 512] = torch.ops.aten.where.self(gt_1, add_11, sub_9);  gt_1 = add_11 = sub_9 = None
            convolution_4: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(where_3, arg2_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = None
            convolution_5: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(where_3, arg3_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg3_1 = None
            pow_9: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_4, 2)
            convolution_6: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(pow_9, arg4_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  pow_9 = None
            pow_10: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_5, 2)
            convolution_7: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(pow_10, arg4_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  pow_10 = None
            neg: f32[1, 1, 512, 512] = torch.ops.aten.neg.default(convolution_4);  convolution_4 = None
            mul_14: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(neg, convolution_5);  neg = convolution_5 = None
            convolution_8: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(mul_14, arg4_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  mul_14 = arg4_1 = None
            mul_15: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(convolution_8, 2);  convolution_8 = None
            convolution_9: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(mul_15, arg5_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  mul_15 = None
            sub_10: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(convolution_6, convolution_7);  convolution_6 = convolution_7 = None
            convolution_10: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(sub_10, arg5_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  sub_10 = arg5_1 = None
            cat: f32[1, 3, 512, 512] = torch.ops.aten.cat.default([where_3, convolution_9, convolution_10], 1);  where_3 = convolution_9 = convolution_10 = None
            convolution_11: f32[1, 64, 256, 256] = torch.ops.aten.convolution.default(cat, arg6_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1);  cat = arg6_1 = None
            _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_11, arg7_1, arg8_1, arg92_1, arg93_1, 0.1, 1e-05);  convolution_11 = arg7_1 = arg8_1 = arg92_1 = arg93_1 = None
            getitem: f32[1, 64, 256, 256] = _native_batch_norm_legit_no_training[0];  _native_batch_norm_legit_no_training = None
            relu: f32[1, 64, 256, 256] = torch.ops.aten.relu.default(getitem);  getitem = None
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(relu, [3, 3], [2, 2], [1, 1]);  relu = None
            getitem_1: f32[1, 64, 128, 128] = max_pool2d_with_indices[0];  max_pool2d_with_indices = None
            convolution_12: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(getitem_1, arg9_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg9_1 = None
            _native_batch_norm_legit_no_training_1 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_12, arg10_1, arg11_1, arg95_1, arg96_1, 0.1, 1e-05);  convolution_12 = arg10_1 = arg11_1 = arg95_1 = arg96_1 = None
            getitem_2: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_1[0];  _native_batch_norm_legit_no_training_1 = None
            relu_1: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(getitem_2);  getitem_2 = None
            convolution_13: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(relu_1, arg12_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_1 = arg12_1 = None
            _native_batch_norm_legit_no_training_2 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_13, arg13_1, arg14_1, arg98_1, arg99_1, 0.1, 1e-05);  convolution_13 = arg13_1 = arg14_1 = arg98_1 = arg99_1 = None
            getitem_3: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_2[0];  _native_batch_norm_legit_no_training_2 = None
            add_12: f32[1, 64, 128, 128] = torch.ops.aten.add.Tensor(getitem_3, getitem_1);  getitem_3 = getitem_1 = None
            relu_2: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(add_12);  add_12 = None
            convolution_14: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(relu_2, arg15_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg15_1 = None
            _native_batch_norm_legit_no_training_3 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_14, arg16_1, arg17_1, arg101_1, arg102_1, 0.1, 1e-05);  convolution_14 = arg16_1 = arg17_1 = arg101_1 = arg102_1 = None
            getitem_4: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_3[0];  _native_batch_norm_legit_no_training_3 = None
            relu_3: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(getitem_4);  getitem_4 = None
            convolution_15: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(relu_3, arg18_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_3 = arg18_1 = None
            _native_batch_norm_legit_no_training_4 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_15, arg19_1, arg20_1, arg104_1, arg105_1, 0.1, 1e-05);  convolution_15 = arg19_1 = arg20_1 = arg104_1 = arg105_1 = None
            getitem_5: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_4[0];  _native_batch_norm_legit_no_training_4 = None
            add_13: f32[1, 64, 128, 128] = torch.ops.aten.add.Tensor(getitem_5, relu_2);  getitem_5 = relu_2 = None
            relu_4: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(add_13);  add_13 = None
            convolution_16: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_4, arg21_1, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  arg21_1 = None
            _native_batch_norm_legit_no_training_5 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_16, arg22_1, arg23_1, arg107_1, arg108_1, 0.1, 1e-05);  convolution_16 = arg22_1 = arg23_1 = arg107_1 = arg108_1 = None
            getitem_6: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_5[0];  _native_batch_norm_legit_no_training_5 = None
            relu_5: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(getitem_6);  getitem_6 = None
            convolution_17: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_5, arg24_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_5 = arg24_1 = None
            _native_batch_norm_legit_no_training_6 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_17, arg25_1, arg26_1, arg110_1, arg111_1, 0.1, 1e-05);  convolution_17 = arg25_1 = arg26_1 = arg110_1 = arg111_1 = None
            getitem_7: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_6[0];  _native_batch_norm_legit_no_training_6 = None
            convolution_18: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_4, arg27_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  arg27_1 = None
            _native_batch_norm_legit_no_training_7 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_18, arg28_1, arg29_1, arg113_1, arg114_1, 0.1, 1e-05);  convolution_18 = arg28_1 = arg29_1 = arg113_1 = arg114_1 = None
            getitem_8: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_7[0];  _native_batch_norm_legit_no_training_7 = None
            add_14: f32[1, 128, 64, 64] = torch.ops.aten.add.Tensor(getitem_7, getitem_8);  getitem_7 = getitem_8 = None
            relu_6: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(add_14);  add_14 = None
            convolution_19: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_6, arg30_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg30_1 = None
            _native_batch_norm_legit_no_training_8 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_19, arg31_1, arg32_1, arg116_1, arg117_1, 0.1, 1e-05);  convolution_19 = arg31_1 = arg32_1 = arg116_1 = arg117_1 = None
            getitem_9: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_8[0];  _native_batch_norm_legit_no_training_8 = None
            relu_7: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(getitem_9);  getitem_9 = None
            convolution_20: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_7, arg33_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_7 = arg33_1 = None
            _native_batch_norm_legit_no_training_9 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_20, arg34_1, arg35_1, arg119_1, arg120_1, 0.1, 1e-05);  convolution_20 = arg34_1 = arg35_1 = arg119_1 = arg120_1 = None
            getitem_10: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_9[0];  _native_batch_norm_legit_no_training_9 = None
            add_15: f32[1, 128, 64, 64] = torch.ops.aten.add.Tensor(getitem_10, relu_6);  getitem_10 = relu_6 = None
            relu_8: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(add_15);  add_15 = None
            convolution_21: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_8, arg36_1, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  arg36_1 = None
            _native_batch_norm_legit_no_training_10 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_21, arg37_1, arg38_1, arg122_1, arg123_1, 0.1, 1e-05);  convolution_21 = arg37_1 = arg38_1 = arg122_1 = arg123_1 = None
            getitem_11: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_10[0];  _native_batch_norm_legit_no_training_10 = None
            relu_9: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(getitem_11);  getitem_11 = None
            convolution_22: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_9, arg39_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_9 = arg39_1 = None
            _native_batch_norm_legit_no_training_11 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_22, arg40_1, arg41_1, arg125_1, arg126_1, 0.1, 1e-05);  convolution_22 = arg40_1 = arg41_1 = arg125_1 = arg126_1 = None
            getitem_12: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_11[0];  _native_batch_norm_legit_no_training_11 = None
            convolution_23: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_8, arg42_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  arg42_1 = None
            _native_batch_norm_legit_no_training_12 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_23, arg43_1, arg44_1, arg128_1, arg129_1, 0.1, 1e-05);  convolution_23 = arg43_1 = arg44_1 = arg128_1 = arg129_1 = None
            getitem_13: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_12[0];  _native_batch_norm_legit_no_training_12 = None
            add_16: f32[1, 256, 32, 32] = torch.ops.aten.add.Tensor(getitem_12, getitem_13);  getitem_12 = getitem_13 = None
            relu_10: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(add_16);  add_16 = None
            convolution_24: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_10, arg45_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg45_1 = None
            _native_batch_norm_legit_no_training_13 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_24, arg46_1, arg47_1, arg131_1, arg132_1, 0.1, 1e-05);  convolution_24 = arg46_1 = arg47_1 = arg131_1 = arg132_1 = None
            getitem_14: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_13[0];  _native_batch_norm_legit_no_training_13 = None
            relu_11: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(getitem_14);  getitem_14 = None
            convolution_25: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_11, arg48_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_11 = arg48_1 = None
            _native_batch_norm_legit_no_training_14 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_25, arg49_1, arg50_1, arg134_1, arg135_1, 0.1, 1e-05);  convolution_25 = arg49_1 = arg50_1 = arg134_1 = arg135_1 = None
            getitem_15: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_14[0];  _native_batch_norm_legit_no_training_14 = None
            add_17: f32[1, 256, 32, 32] = torch.ops.aten.add.Tensor(getitem_15, relu_10);  getitem_15 = relu_10 = None
            relu_12: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(add_17);  add_17 = None
            convolution_26: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_12, arg51_1, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  arg51_1 = None
            _native_batch_norm_legit_no_training_15 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_26, arg52_1, arg53_1, arg137_1, arg138_1, 0.1, 1e-05);  convolution_26 = arg52_1 = arg53_1 = arg137_1 = arg138_1 = None
            getitem_16: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_15[0];  _native_batch_norm_legit_no_training_15 = None
            relu_13: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(getitem_16);  getitem_16 = None
            convolution_27: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_13, arg54_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_13 = arg54_1 = None
            _native_batch_norm_legit_no_training_16 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_27, arg55_1, arg56_1, arg140_1, arg141_1, 0.1, 1e-05);  convolution_27 = arg55_1 = arg56_1 = arg140_1 = arg141_1 = None
            getitem_17: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_16[0];  _native_batch_norm_legit_no_training_16 = None
            convolution_28: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_12, arg57_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  arg57_1 = None
            _native_batch_norm_legit_no_training_17 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_28, arg58_1, arg59_1, arg143_1, arg144_1, 0.1, 1e-05);  convolution_28 = arg58_1 = arg59_1 = arg143_1 = arg144_1 = None
            getitem_18: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_17[0];  _native_batch_norm_legit_no_training_17 = None
            add_18: f32[1, 512, 16, 16] = torch.ops.aten.add.Tensor(getitem_17, getitem_18);  getitem_17 = getitem_18 = None
            relu_14: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(add_18);  add_18 = None
            convolution_29: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_14, arg60_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg60_1 = None
            _native_batch_norm_legit_no_training_18 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_29, arg61_1, arg62_1, arg146_1, arg147_1, 0.1, 1e-05);  convolution_29 = arg61_1 = arg62_1 = arg146_1 = arg147_1 = None
            getitem_19: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_18[0];  _native_batch_norm_legit_no_training_18 = None
            relu_15: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(getitem_19);  getitem_19 = None
            convolution_30: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_15, arg63_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_15 = arg63_1 = None
            _native_batch_norm_legit_no_training_19 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_30, arg64_1, arg65_1, arg149_1, arg150_1, 0.1, 1e-05);  convolution_30 = arg64_1 = arg65_1 = arg149_1 = arg150_1 = None
            getitem_20: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_19[0];  _native_batch_norm_legit_no_training_19 = None
            add_19: f32[1, 512, 16, 16] = torch.ops.aten.add.Tensor(getitem_20, relu_14);  getitem_20 = relu_14 = None
            relu_16: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(add_19);  add_19 = None
            convolution_31: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_16, arg66_1, arg67_1, [2, 2], [0, 0], [1, 1], True, [0, 0], 1);  relu_16 = arg66_1 = arg67_1 = None
            cat_1: f32[1, 512, 32, 32] = torch.ops.aten.cat.default([relu_12, convolution_31], 1);  relu_12 = convolution_31 = None
            convolution_32: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(cat_1, arg68_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  cat_1 = arg68_1 = None
            _native_batch_norm_legit_no_training_20 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_32, arg69_1, arg70_1, arg152_1, arg153_1, 0.1, 1e-05);  convolution_32 = arg69_1 = arg70_1 = arg152_1 = arg153_1 = None
            getitem_21: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_20[0];  _native_batch_norm_legit_no_training_20 = None
            leaky_relu: f32[1, 256, 32, 32] = torch.ops.aten.leaky_relu.default(getitem_21);  getitem_21 = None
            convolution_33: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(leaky_relu, arg71_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  leaky_relu = arg71_1 = None
            _native_batch_norm_legit_no_training_21 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_33, arg72_1, arg73_1, arg155_1, arg156_1, 0.1, 1e-05);  convolution_33 = arg72_1 = arg73_1 = arg155_1 = arg156_1 = None
            getitem_22: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_21[0];  _native_batch_norm_legit_no_training_21 = None
            leaky_relu_1: f32[1, 256, 32, 32] = torch.ops.aten.leaky_relu.default(getitem_22);  getitem_22 = None
            convolution_34: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(leaky_relu_1, arg74_1, arg75_1, [2, 2], [0, 0], [1, 1], True, [0, 0], 1);  leaky_relu_1 = arg74_1 = arg75_1 = None
            cat_2: f32[1, 256, 64, 64] = torch.ops.aten.cat.default([relu_8, convolution_34], 1);  relu_8 = convolution_34 = None
            convolution_35: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(cat_2, arg76_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  cat_2 = arg76_1 = None
            _native_batch_norm_legit_no_training_22 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_35, arg77_1, arg78_1, arg158_1, arg159_1, 0.1, 1e-05);  convolution_35 = arg77_1 = arg78_1 = arg158_1 = arg159_1 = None
            getitem_23: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_22[0];  _native_batch_norm_legit_no_training_22 = None
            leaky_relu_2: f32[1, 128, 64, 64] = torch.ops.aten.leaky_relu.default(getitem_23);  getitem_23 = None
            convolution_36: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(leaky_relu_2, arg79_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  leaky_relu_2 = arg79_1 = None
            _native_batch_norm_legit_no_training_23 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_36, arg80_1, arg81_1, arg161_1, arg162_1, 0.1, 1e-05);  convolution_36 = arg80_1 = arg81_1 = arg161_1 = arg162_1 = None
            getitem_24: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_23[0];  _native_batch_norm_legit_no_training_23 = None
            leaky_relu_3: f32[1, 128, 64, 64] = torch.ops.aten.leaky_relu.default(getitem_24);  getitem_24 = None
            convolution_37: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(leaky_relu_3, arg82_1, arg83_1, [2, 2], [0, 0], [1, 1], True, [0, 0], 1);  leaky_relu_3 = arg82_1 = arg83_1 = None
            cat_3: f32[1, 128, 128, 128] = torch.ops.aten.cat.default([relu_4, convolution_37], 1);  relu_4 = convolution_37 = None
            convolution_38: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(cat_3, arg84_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  cat_3 = arg84_1 = None
            _native_batch_norm_legit_no_training_24 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_38, arg85_1, arg86_1, arg164_1, arg165_1, 0.1, 1e-05);  convolution_38 = arg85_1 = arg86_1 = arg164_1 = arg165_1 = None
            getitem_25: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_24[0];  _native_batch_norm_legit_no_training_24 = None
            leaky_relu_4: f32[1, 64, 128, 128] = torch.ops.aten.leaky_relu.default(getitem_25);  getitem_25 = None
            convolution_39: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(leaky_relu_4, arg87_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  leaky_relu_4 = arg87_1 = None
            _native_batch_norm_legit_no_training_25 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_39, arg88_1, arg89_1, arg167_1, arg168_1, 0.1, 1e-05);  convolution_39 = arg88_1 = arg89_1 = arg167_1 = arg168_1 = None
            getitem_26: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_25[0];  _native_batch_norm_legit_no_training_25 = None
            leaky_relu_5: f32[1, 64, 128, 128] = torch.ops.aten.leaky_relu.default(getitem_26);  getitem_26 = None
            convolution_40: f32[1, 134, 128, 128] = torch.ops.aten.convolution.default(leaky_relu_5, arg90_1, arg91_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  leaky_relu_5 = arg90_1 = arg91_1 = None
            split_with_sizes = torch.ops.aten.split_with_sizes.default(convolution_40, [1, 33, 33, 33, 33, 1], 1);  convolution_40 = None
            getitem_27: f32[1, 1, 128, 128] = split_with_sizes[0]
            getitem_28: f32[1, 33, 128, 128] = split_with_sizes[1]
            getitem_29: f32[1, 33, 128, 128] = split_with_sizes[2]
            getitem_30: f32[1, 33, 128, 128] = split_with_sizes[3]
            getitem_31: f32[1, 33, 128, 128] = split_with_sizes[4]
            getitem_32: f32[1, 1, 128, 128] = split_with_sizes[5];  split_with_sizes = None
            sigmoid: f32[1, 1, 128, 128] = torch.ops.aten.sigmoid.default(getitem_27);  getitem_27 = None
            sigmoid_1: f32[1, 1, 128, 128] = torch.ops.aten.sigmoid.default(getitem_32);  getitem_32 = None
            alias: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sigmoid)
            mul_16: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sigmoid_1, alias);  sigmoid_1 = alias = None
            _tensor_constant0: i64[2] = self._tensor_constant0
            lift_fresh_copy: i64[2] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
            clone: i64[2] = torch.ops.aten.clone.default(lift_fresh_copy);  lift_fresh_copy = None
            _tensor_constant1: f64[] = self._tensor_constant1
            lift_fresh_copy_1: f64[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant1);  _tensor_constant1 = None
            mul_17: f64[2] = torch.ops.aten.mul.Tensor(lift_fresh_copy_1, clone);  lift_fresh_copy_1 = clone = None
            _tensor_constant2: f64[] = self._tensor_constant2
            lift_fresh_copy_2: f64[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant2);  _tensor_constant2 = None
            div_4: f64[2] = torch.ops.aten.div.Tensor(mul_17, lift_fresh_copy_2);  mul_17 = lift_fresh_copy_2 = None
            _softmax: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_28, 1, False)
            _softmax_1: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_29, 1, False)
            _softmax_2: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_30, 1, False)
            _softmax_3: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_31, 1, False)
            arange_2: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_2: b8[34] = torch.ops.aten.lt.Scalar(arange_2, 17.0)
            _to_copy_7: f32[34] = torch.ops.aten._to_copy.default(arange_2, dtype = torch.float32)
            mul_18: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_7, 0.06060606060606061);  _to_copy_7 = None
            add_20: f32[34] = torch.ops.aten.add.Tensor(mul_18, -1);  mul_18 = None
            sub_11: i64[34] = torch.ops.aten.sub.Tensor(33, arange_2);  arange_2 = None
            _to_copy_8: f32[34] = torch.ops.aten._to_copy.default(sub_11, dtype = torch.float32);  sub_11 = None
            mul_19: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_8, 0.06060606060606061);  _to_copy_8 = None
            sub_12: f32[34] = torch.ops.aten.sub.Tensor(1, mul_19);  mul_19 = None
            where_4: f32[34] = torch.ops.aten.where.self(lt_2, add_20, sub_12);  lt_2 = add_20 = sub_12 = None
            clamp_2: f32[34] = torch.ops.aten.clamp.default(where_4, -1, 1);  where_4 = None
            abs_2: f32[34] = torch.ops.aten.abs.default(clamp_2)
            sub_13: f32[34] = torch.ops.aten.sub.Tensor(2, abs_2);  abs_2 = None
            div_5: f32[34] = torch.ops.aten.div.Tensor(clamp_2, sub_13);  clamp_2 = sub_13 = None
            slice_1: f32[33] = torch.ops.aten.slice.Tensor(div_5, 0, 0, -1)
            slice_2: f32[33] = torch.ops.aten.slice.Tensor(div_5, 0, 1, 9223372036854775807);  div_5 = None
            add_21: f32[33] = torch.ops.aten.add.Tensor(slice_1, slice_2);  slice_1 = slice_2 = None
            div_6: f32[33] = torch.ops.aten.div.Tensor(add_21, 2);  add_21 = None
            view_2: f32[1, 33, 1, 1] = torch.ops.aten.view.default(div_6, [1, -1, 1, 1]);  div_6 = None
            select_3: f64[] = torch.ops.aten.select.int(div_4, 0, 1)
            mul_20: f32[1, 33, 1, 1] = torch.ops.aten.mul.Tensor(view_2, select_3);  view_2 = select_3 = None
            arange_3: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_3: b8[34] = torch.ops.aten.lt.Scalar(arange_3, 17.0)
            _to_copy_9: f32[34] = torch.ops.aten._to_copy.default(arange_3, dtype = torch.float32)
            mul_21: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_9, 0.06060606060606061);  _to_copy_9 = None
            add_22: f32[34] = torch.ops.aten.add.Tensor(mul_21, -1);  mul_21 = None
            sub_14: i64[34] = torch.ops.aten.sub.Tensor(33, arange_3);  arange_3 = None
            _to_copy_10: f32[34] = torch.ops.aten._to_copy.default(sub_14, dtype = torch.float32);  sub_14 = None
            mul_22: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_10, 0.06060606060606061);  _to_copy_10 = None
            sub_15: f32[34] = torch.ops.aten.sub.Tensor(1, mul_22);  mul_22 = None
            where_5: f32[34] = torch.ops.aten.where.self(lt_3, add_22, sub_15);  lt_3 = add_22 = sub_15 = None
            clamp_3: f32[34] = torch.ops.aten.clamp.default(where_5, -1, 1);  where_5 = None
            abs_3: f32[34] = torch.ops.aten.abs.default(clamp_3)
            sub_16: f32[34] = torch.ops.aten.sub.Tensor(2, abs_3);  abs_3 = None
            div_7: f32[34] = torch.ops.aten.div.Tensor(clamp_3, sub_16);  clamp_3 = sub_16 = None
            slice_3: f32[33] = torch.ops.aten.slice.Tensor(div_7, 0, 0, -1)
            slice_4: f32[33] = torch.ops.aten.slice.Tensor(div_7, 0, 1, 9223372036854775807);  div_7 = None
            add_23: f32[33] = torch.ops.aten.add.Tensor(slice_3, slice_4);  slice_3 = slice_4 = None
            div_8: f32[33] = torch.ops.aten.div.Tensor(add_23, 2);  add_23 = None
            view_3: f32[1, 33, 1, 1] = torch.ops.aten.view.default(div_8, [1, -1, 1, 1]);  div_8 = None
            select_4: f64[] = torch.ops.aten.select.int(div_4, 0, 0);  div_4 = None
            mul_23: f32[1, 33, 1, 1] = torch.ops.aten.mul.Tensor(view_3, select_4);  view_3 = select_4 = None
            mul_24: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax, mul_20)
            sum_1: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_24, [1], True);  mul_24 = None
            mul_25: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax_1, mul_23)
            sum_2: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_25, [1], True);  mul_25 = None
            mul_26: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax_2, mul_20);  _softmax_2 = None
            sum_3: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_26, [1], True);  mul_26 = None
            mul_27: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax_3, mul_23);  _softmax_3 = None
            sum_4: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_27, [1], True);  mul_27 = None
            alias_1: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_2)
            mul_28: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_3, alias_1);  alias_1 = None
            alias_2: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_1)
            mul_29: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_4, alias_2);  alias_2 = None
            sub_17: f32[1, 1, 128, 128] = torch.ops.aten.sub.Tensor(mul_28, mul_29);  mul_28 = mul_29 = None
            alias_3: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_1)
            mul_30: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_3, alias_3);  alias_3 = None
            alias_4: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_2)
            mul_31: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_4, alias_4);  alias_4 = None
            add_24: f32[1, 1, 128, 128] = torch.ops.aten.add.Tensor(mul_30, mul_31);  mul_30 = mul_31 = None
            mul_32: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sub_17, mul_16);  sub_17 = None
            mean_1: f32[1] = torch.ops.aten.mean.dim(mul_32, [1, 2, 3]);  mul_32 = None
            mul_33: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(add_24, mul_16);  add_24 = None
            mean_2: f32[1] = torch.ops.aten.mean.dim(mul_33, [1, 2, 3]);  mul_33 = None
            atan2: f32[1] = torch.ops.aten.atan2.default(mean_1, mean_2);  mean_1 = mean_2 = None
            mul_34: f32[1] = torch.ops.aten.mul.Tensor(atan2, 180);  atan2 = None
            div_9: f32[1] = torch.ops.aten.div.Tensor(mul_34, 3.141592653589793);  mul_34 = None
            arange_4: i64[128] = torch.ops.aten.arange.start_step(0, 128, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_4: b8[128] = torch.ops.aten.lt.Scalar(arange_4, 64.0)
            _to_copy_11: f32[128] = torch.ops.aten._to_copy.default(arange_4, dtype = torch.float32)
            mul_35: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_11, 0.015748031496062992);  _to_copy_11 = None
            add_25: f32[128] = torch.ops.aten.add.Tensor(mul_35, -1);  mul_35 = None
            sub_18: i64[128] = torch.ops.aten.sub.Tensor(127, arange_4);  arange_4 = None
            _to_copy_12: f32[128] = torch.ops.aten._to_copy.default(sub_18, dtype = torch.float32);  sub_18 = None
            mul_36: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_12, 0.015748031496062992);  _to_copy_12 = None
            sub_19: f32[128] = torch.ops.aten.sub.Tensor(1, mul_36);  mul_36 = None
            where_6: f32[128] = torch.ops.aten.where.self(lt_4, add_25, sub_19);  lt_4 = add_25 = sub_19 = None
            arange_5: i64[128] = torch.ops.aten.arange.start_step(0, 128, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_5: b8[128] = torch.ops.aten.lt.Scalar(arange_5, 64.0)
            _to_copy_13: f32[128] = torch.ops.aten._to_copy.default(arange_5, dtype = torch.float32)
            mul_37: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_13, 0.015748031496062992);  _to_copy_13 = None
            add_26: f32[128] = torch.ops.aten.add.Tensor(mul_37, -1);  mul_37 = None
            sub_20: i64[128] = torch.ops.aten.sub.Tensor(127, arange_5);  arange_5 = None
            _to_copy_14: f32[128] = torch.ops.aten._to_copy.default(sub_20, dtype = torch.float32);  sub_20 = None
            mul_38: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_14, 0.015748031496062992);  _to_copy_14 = None
            sub_21: f32[128] = torch.ops.aten.sub.Tensor(1, mul_38);  mul_38 = None
            where_7: f32[128] = torch.ops.aten.where.self(lt_5, add_26, sub_21);  lt_5 = add_26 = sub_21 = None
            view_4: f32[128, 1] = torch.ops.aten.view.default(where_6, [-1, 1]);  where_6 = None
            expand_2: f32[128, 128] = torch.ops.aten.expand.default(view_4, [128, 128]);  view_4 = None
            view_5: f32[1, 128] = torch.ops.aten.view.default(where_7, [1, -1]);  where_7 = None
            expand_3: f32[128, 128] = torch.ops.aten.expand.default(view_5, [128, 128]);  view_5 = None
            add_27: f32[128, 128] = torch.ops.aten.add.Tensor(expand_3, 1);  expand_3 = None
            div_10: f32[128, 128] = torch.ops.aten.div.Tensor(add_27, 2);  add_27 = None
            mul_39: f32[128, 128] = torch.ops.aten.mul.Tensor(div_10, 511);  div_10 = None
            add_28: f32[128, 128] = torch.ops.aten.add.Tensor(expand_2, 1);  expand_2 = None
            div_11: f32[128, 128] = torch.ops.aten.div.Tensor(add_28, 2);  add_28 = None
            mul_40: f32[128, 128] = torch.ops.aten.mul.Tensor(div_11, 511);  div_11 = None
            arange_6: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_6: b8[34] = torch.ops.aten.lt.Scalar(arange_6, 17.0)
            _to_copy_15: f32[34] = torch.ops.aten._to_copy.default(arange_6, dtype = torch.float32)
            mul_41: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_15, 0.06060606060606061);  _to_copy_15 = None
            add_29: f32[34] = torch.ops.aten.add.Tensor(mul_41, -1);  mul_41 = None
            sub_22: i64[34] = torch.ops.aten.sub.Tensor(33, arange_6);  arange_6 = None
            _to_copy_16: f32[34] = torch.ops.aten._to_copy.default(sub_22, dtype = torch.float32);  sub_22 = None
            mul_42: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_16, 0.06060606060606061);  _to_copy_16 = None
            sub_23: f32[34] = torch.ops.aten.sub.Tensor(1, mul_42);  mul_42 = None
            where_8: f32[34] = torch.ops.aten.where.self(lt_6, add_29, sub_23);  lt_6 = add_29 = sub_23 = None
            clamp_4: f32[34] = torch.ops.aten.clamp.default(where_8, -1, 1);  where_8 = None
            abs_4: f32[34] = torch.ops.aten.abs.default(clamp_4)
            sub_24: f32[34] = torch.ops.aten.sub.Tensor(2, abs_4);  abs_4 = None
            div_12: f32[34] = torch.ops.aten.div.Tensor(clamp_4, sub_24);  clamp_4 = sub_24 = None
            slice_5: f32[33] = torch.ops.aten.slice.Tensor(div_12, 0, 1, 9223372036854775807)
            slice_6: f32[33] = torch.ops.aten.slice.Tensor(div_12, 0, 0, -1);  div_12 = None
            sub_25: f32[33] = torch.ops.aten.sub.Tensor(slice_5, slice_6);  slice_5 = slice_6 = None
            view_6: f32[1, 33, 1, 1] = torch.ops.aten.view.default(sub_25, [1, -1, 1, 1]);  sub_25 = None
            arange_7: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_7: b8[34] = torch.ops.aten.lt.Scalar(arange_7, 17.0)
            _to_copy_17: f32[34] = torch.ops.aten._to_copy.default(arange_7, dtype = torch.float32)
            mul_43: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_17, 0.06060606060606061);  _to_copy_17 = None
            add_30: f32[34] = torch.ops.aten.add.Tensor(mul_43, -1);  mul_43 = None
            sub_26: i64[34] = torch.ops.aten.sub.Tensor(33, arange_7);  arange_7 = None
            _to_copy_18: f32[34] = torch.ops.aten._to_copy.default(sub_26, dtype = torch.float32);  sub_26 = None
            mul_44: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_18, 0.06060606060606061);  _to_copy_18 = None
            sub_27: f32[34] = torch.ops.aten.sub.Tensor(1, mul_44);  mul_44 = None
            where_9: f32[34] = torch.ops.aten.where.self(lt_7, add_30, sub_27);  lt_7 = add_30 = sub_27 = None
            clamp_5: f32[34] = torch.ops.aten.clamp.default(where_9, -1, 1);  where_9 = None
            abs_5: f32[34] = torch.ops.aten.abs.default(clamp_5)
            sub_28: f32[34] = torch.ops.aten.sub.Tensor(2, abs_5);  abs_5 = None
            div_13: f32[34] = torch.ops.aten.div.Tensor(clamp_5, sub_28);  clamp_5 = sub_28 = None
            slice_7: f32[33] = torch.ops.aten.slice.Tensor(div_13, 0, 1, 9223372036854775807)
            slice_8: f32[33] = torch.ops.aten.slice.Tensor(div_13, 0, 0, -1);  div_13 = None
            sub_29: f32[33] = torch.ops.aten.sub.Tensor(slice_7, slice_8);  slice_7 = slice_8 = None
            view_7: f32[1, 33, 1, 1] = torch.ops.aten.view.default(sub_29, [1, -1, 1, 1]);  sub_29 = None
            clamp_min_2: f32[1, 33, 1, 1] = torch.ops.aten.clamp_min.default(view_6, 1e-06);  view_6 = None
            div_14: f32[1, 33, 128, 128] = torch.ops.aten.div.Tensor(_softmax, clamp_min_2);  _softmax = clamp_min_2 = None
            clamp_min_3: f32[1, 33, 1, 1] = torch.ops.aten.clamp_min.default(view_7, 1e-06);  view_7 = None
            div_15: f32[1, 33, 128, 128] = torch.ops.aten.div.Tensor(_softmax_1, clamp_min_3);  _softmax_1 = clamp_min_3 = None
            unsqueeze_2: f32[1, 128, 128] = torch.ops.aten.unsqueeze.default(mul_39, 0);  mul_39 = None
            unsqueeze_3: f32[1, 1, 128, 128] = torch.ops.aten.unsqueeze.default(unsqueeze_2, 1);  unsqueeze_2 = None
            add_31: f32[1, 33, 128, 128] = torch.ops.aten.add.Tensor(mul_20, unsqueeze_3);  mul_20 = unsqueeze_3 = None
            unsqueeze_4: f32[1, 128, 128] = torch.ops.aten.unsqueeze.default(mul_40, 0);  mul_40 = None
            unsqueeze_5: f32[1, 1, 128, 128] = torch.ops.aten.unsqueeze.default(unsqueeze_4, 1);  unsqueeze_4 = None
            add_32: f32[1, 33, 128, 128] = torch.ops.aten.add.Tensor(mul_23, unsqueeze_5);  mul_23 = unsqueeze_5 = None
            mul_45: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(add_31, div_14);  add_31 = None
            mul_46: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(mul_45, sigmoid);  mul_45 = None
            sum_5: f32[1] = torch.ops.aten.sum.dim_IntList(mul_46, [1, 2, 3]);  mul_46 = None
            mul_47: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(div_14, sigmoid);  div_14 = None
            sum_6: f32[1] = torch.ops.aten.sum.dim_IntList(mul_47, [1, 2, 3]);  mul_47 = None
            clamp_min_4: f32[1] = torch.ops.aten.clamp_min.default(sum_6, 1e-06);  sum_6 = None
            div_16: f32[1] = torch.ops.aten.div.Tensor(sum_5, clamp_min_4);  sum_5 = clamp_min_4 = None
            mul_48: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(add_32, div_15);  add_32 = None
            mul_49: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(mul_48, sigmoid);  mul_48 = None
            sum_7: f32[1] = torch.ops.aten.sum.dim_IntList(mul_49, [1, 2, 3]);  mul_49 = None
            mul_50: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(div_15, sigmoid);  div_15 = None
            sum_8: f32[1] = torch.ops.aten.sum.dim_IntList(mul_50, [1, 2, 3]);  mul_50 = None
            clamp_min_5: f32[1] = torch.ops.aten.clamp_min.default(sum_8, 1e-06);  sum_8 = None
            div_17: f32[1] = torch.ops.aten.div.Tensor(sum_7, clamp_min_5);  sum_7 = clamp_min_5 = None
            slice_9: f32[1] = torch.ops.aten.slice.Tensor(div_16, 0, 0, 9223372036854775807);  div_16 = None
            unsqueeze_6: f32[1, 1] = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
            slice_10: f32[1] = torch.ops.aten.slice.Tensor(div_17, 0, 0, 9223372036854775807);  div_17 = None
            unsqueeze_7: f32[1, 1] = torch.ops.aten.unsqueeze.default(slice_10, 1);  slice_10 = None
            slice_11: f32[1] = torch.ops.aten.slice.Tensor(div_9, 0, 0, 9223372036854775807);  div_9 = None
            unsqueeze_8: f32[1, 1] = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
            cat_4: f32[1, 3] = torch.ops.aten.cat.default([unsqueeze_6, unsqueeze_7, unsqueeze_8], 1);  unsqueeze_6 = unsqueeze_7 = unsqueeze_8 = None
            return (getitem_28, getitem_29, getitem_30, getitem_31, cat_4, sigmoid, clamp_1, sum_3, sum_4, mul_16, sum_1, sum_2)
            
Graph Signature: **Removed this to spare some lines**
Symbol to range: {}

Unfortunately, this architecture results in the following error: torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.roll.default', 'aten.var.correction']}.

Is there any guideline on how to solve this problem, i.e. how to implement support for the aforementioned operations? Thank you, and sorry for the long post.

About this issue

  • Original URL
  • State: closed
  • Created 7 months ago
  • Reactions: 1
  • Comments: 23 (23 by maintainers)

Commits related to this issue

Most upvoted comments

So we are good?

Yeah, we're good. There might still be a problem with complex operations (mentioned in Edit4), but otherwise everything works. If I eventually need the model to handle arbitrary input shapes without preprocessing, I'll open a new issue. I'm closing this now.

Thank you once again @justinchuby.

Sure.

The export script:

import torch
import onnx
import yaml
from torch.onnx import export
from pathlib import Path
from loguru import logger
from argparse import ArgumentParser

from models.model_zoo import GRIDNET4, GRIDTIMMNET4

def main(args):
    """Export a trained GRIDNET4 / GRIDTIMMNET4 checkpoint to ONNX with TorchDynamo.

    Reads ``configs.yaml`` from ``args.model_dir`` to build the model, loads the
    checkpoint ``args.model_version`` from the same directory, exports it with
    ``torch.onnx.dynamo_export`` using a dummy ``[1, 1, 512, 512]`` float32
    input, and saves the result next to the checkpoint with an ``.onnx``
    suffix. The saved file is then reloaded, its graph printed, and the ONNX
    checker run on it.

    Args:
        args: Parsed CLI namespace with ``model_dir`` (Path to the run
            directory) and ``model_version`` (checkpoint file name inside it).
    """
    torch.set_default_device("cpu")

    root = args.model_dir
    config_file = root / "configs.yaml"
    checkpoint = root / args.model_version
    logger.info(f"Exporting model {root} to ONNX")

    # Same directory and stem as the checkpoint, with an ``.onnx`` suffix.
    onnx_path = checkpoint.with_suffix(".onnx")

    logger.info(f"Loading Model config file {config_file}")
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use
    # it on trusted config files (yaml.safe_load suffices for plain data).
    with open(config_file, "r") as f:
        config = yaml.load(f.read(), yaml.Loader)

    if config["exp_name"] == "gridnet4":
        model = GRIDNET4(
            num_pose_2d=config["num_pose_2d"],
            num_layers=config["num_layers"],
            img_ppi=config["img_ppi"],
            middle_shape=config["middle_shape"],
            with_tv=config["with_tv"],
            with_enh=config["with_enh"],
            bin_type=config["bin_type"],
            activate=config["activate"],
            pretrained=False
        )
    else:
        model = GRIDTIMMNET4(**config)

    logger.info(f"Loading model checkpoint {checkpoint}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))

    # Load the trained weights *before* tracing so that anything the exporter
    # captures from parameters/buffers reflects the checkpoint rather than the
    # random initialization. The same weights are also passed to ``save`` below.
    model.load_state_dict(state_dict["model"])
    model.eval()

    logger.info(f"Exporting model to {onnx_path}")

    # To export with dynamic input shapes, pass
    # export_options=torch.onnx.ExportOptions(dynamic_shapes=True).
    onnx_prog = torch.onnx.dynamo_export(
        model, torch.randn(1, 1, 512, 512, dtype=torch.float32),
    )
    onnx_prog.save(onnx_path.as_posix(), model_state_dict=state_dict["model"])

    # Reload the file from disk into its own variable (distinct from the path).
    onnx_proto = onnx.load(onnx_path)
    print(onnx.helper.printable_graph(onnx_proto.graph))
    onnx.checker.check_model(onnx_proto)

def parse_args():
    """Read the export script's command-line arguments.

    Returns:
        argparse.Namespace with ``model_dir`` (Path, positional) and
        ``model_version`` (str, defaults to ``"best.pth"``).
    """
    cli = ArgumentParser()
    cli.add_argument(
        "model_dir",
        type=Path,
        help="Parent directory from model",
    )
    cli.add_argument(
        "--model_version",
        type=str,
        default="best.pth",
        help="Version of exported checkpoint",
    )
    parsed = cli.parse_args()
    return parsed

if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then run the export.
    cli_args = parse_args()
    main(cli_args)

The comparison script (Edit: added a function to compare the parameters):

import cv2
import yaml
import torch
import onnx
import numpy as np
import onnxruntime as ort

from pathlib import Path
from loguru import logger
from argparse import ArgumentParser

from deploy_gridnet import process_img
from models.model_zoo import GRIDNET4, GRIDTIMMNET4

def get_onnx_tensor_dict(onnx_load):
    """Collect the graph initializers of a loaded ONNX model as numpy arrays.

    Args:
        onnx_load: A loaded ONNX model (``onnx.load`` result).

    Returns:
        dict mapping each initializer's name to its value converted with
        ``onnx.numpy_helper.to_array``.
    """
    tensors = {}
    for initializer in onnx_load.graph.initializer:
        tensors[initializer.name] = onnx.numpy_helper.to_array(initializer)
    return tensors

def compare_onnx_graph_and_state_dict(onnx_dict, state_dict):
    """Log differences between ONNX initializers and a PyTorch state dict.

    For every name present in both mappings, the values are compared
    element-wise with ``np.isclose`` (default tolerances). A warning is logged
    when the shapes differ or any element is not close. Names that exist in
    only one of the two mappings are logged at the end.

    Args:
        onnx_dict: Mapping of initializer name -> ``np.ndarray`` (e.g. the
            output of ``get_onnx_tensor_dict``).
        state_dict: Mapping of parameter/buffer name -> ``torch.Tensor``.
    """
    torch_keys = [k for k in state_dict if k not in onnx_dict]
    onnx_keys = [k for k in onnx_dict if k not in state_dict]
    for k, v in onnx_dict.items():
        # O(1) dict membership instead of scanning the onnx_keys list.
        if k not in state_dict:
            continue
        # detach/cpu so tensors that require grad or live off-CPU still convert.
        torch_value = state_dict[k].detach().cpu().numpy()
        if v.shape != torch_value.shape:
            # np.isclose would either raise or silently broadcast here.
            logger.warning(
                f"Parameter {k} has shape {v.shape} in ONNX "
                f"but {torch_value.shape} in PyTorch"
            )
            continue
        is_close = np.isclose(v, torch_value)
        if not is_close.all():
            num_bad = int(is_close.size - np.count_nonzero(is_close))
            max_diff = float(
                np.abs(
                    v.astype(np.float64) - torch_value.astype(np.float64)
                ).max()
            )
            logger.warning(
                f"Parameter {k} is diverging: {num_bad}/{is_close.size} "
                f"elements differ (max abs diff {max_diff:.3g})"
            )

    logger.warning(f"ONNX Keys not in PyTorch {onnx_keys}")
    logger.warning(f"PyTorch Keys not in ONNX {torch_keys}")

def main(args):
    """Sanity-check an exported ONNX model against its PyTorch checkpoint.

    Builds the model from ``<root>/configs.yaml`` and loads ``<root>/<name>.pth``
    into it. Opens ``<root>/<name>.onnx`` with onnx and onnxruntime, and
    compares the ONNX initializers with the checkpoint's ``"model"`` state dict.
    Finally, runs both models on the same random ``[1, 1, 512, 512]`` input
    and reports the outputs and their difference.

    Args:
        args: Parsed CLI namespace with ``root`` (Path to the run directory)
            and ``name`` (checkpoint stem, without extension).
    """
    config_file = args.root / "configs.yaml"
    logger.info(f"Reading config file {config_file}")
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only use
    # it on trusted config files.
    with open(config_file, "r") as f:
        config = yaml.load(f, yaml.Loader)

    # Configs with an ``architecture`` entry build GRIDTIMMNET4, others GRIDNET4.
    if config.get("architecture", None):
        model = GRIDTIMMNET4(**config)
    else:
        model = GRIDNET4(**config)

    checkpoint_file = args.root / f"{args.name}.pth"
    onnx_file = checkpoint_file.with_suffix(".onnx")

    logger.info(f"Loading Checkpoint {checkpoint_file}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])
    model.eval()

    logger.info(f"Loading ONNX Model {onnx_file}")
    onnx_load = onnx.load(onnx_file)
    # Pass a str: older onnxruntime releases only accept str/bytes, not Path.
    onnx_model = ort.InferenceSession(onnx_file.as_posix())

    logger.info("Comparing ONNX Graph and PyTorch State Dict")
    compare_onnx_graph_and_state_dict(
        get_onnx_tensor_dict(onnx_load),
        checkpoint["model"],
    )

    image_file = Path(__file__).parent / "image/1_1.tif"
    logger.info(f"Loading image {image_file}")
    # NOTE: a random input is currently used; re-enable the lines below to test
    # on the real image instead.
    # im = cv2.imread(
    #     image_file.as_posix(),
    #     cv2.IMREAD_GRAYSCALE,
    # ).astype(np.float32)
    im = np.random.randn(512, 512).astype(np.float32)
    # im, _, _ = process_img(im, 500, None)
    print(im.shape)
    im_tensor = torch.from_numpy(im)[None][None]
    print(im_tensor.shape)
    with torch.no_grad():
        torch_out = model(im_tensor)

    # Ask the session for its input name instead of hard-coding the
    # exporter-generated "l_input_" (presumably derived from forward()'s
    # argument name).
    input_name = onnx_model.get_inputs()[0].name
    onnx_out = onnx_model.run(["cat_4"], {input_name: im[None, None]})

    print(torch_out["pose_2d"])
    print(onnx_out)

    # Quantify the mismatch instead of relying on eyeballing the prints.
    # NOTE(review): assumes ONNX output "cat_4" corresponds to torch_out["pose_2d"].
    torch_pose = torch_out["pose_2d"].detach().cpu().numpy()
    onnx_pose = onnx_out[0]
    if torch_pose.shape != onnx_pose.shape:
        logger.warning(
            f"pose_2d shape mismatch: torch {torch_pose.shape} "
            f"vs onnx {onnx_pose.shape}"
        )
    else:
        max_diff = float(np.abs(torch_pose - onnx_pose).max())
        logger.info(f"pose_2d max abs difference (torch vs onnx): {max_diff:.6g}")
    

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with:
          * ``root`` -- positional directory argument, converted to ``Path``.
          * ``name`` -- optional ``--name``/``-n`` string, default ``"best"``.
    """
    cli = ArgumentParser()
    cli.add_argument("root", help="Path to output models", type=Path)
    cli.add_argument(
        "--name",
        "-n",
        default="best",
        help="Name of model file",
        type=str,
    )
    namespace = cli.parse_args()
    return namespace

if __name__ == "__main__":
    # Script entry point: parse the CLI first, then run the sanity check.
    cli_args = parse_args()
    main(cli_args)

Edit: After manually checking some parameters, they seem to have been converted correctly. I’ll write a script to verify that everything is in order.

Edit2: There seem to be no deviations in the parameters, which leads me to believe that the graph is incorrect. I will have to verify it further.

Edit3: I always get the same result from the ONNX model, regardless of the input. I also get a broadcast error if I set dynamic shapes and pass a shape different from the one used in the export function ([1, 1, 512, 512]). I’ll upload the model if you wish to check it out. Model

(384, 384)
torch.Size([1, 1, 384, 384])
2023-11-29 16:50:38.636816288 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_complex_token_108n19' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 384 by 512

Traceback (most recent call last):
  File "/home/griaule/Fingerprint-2DPose-Dense-Voting/onnx_sanity_check.py", line 104, in <module>
    main(parse_args())
  File "/home/griaule/Fingerprint-2DPose-Dense-Voting/onnx_sanity_check.py", line 78, in main
    onnx_out = onnx_model.run(
               ^^^^^^^^^^^^^^^
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_complex_token_108n19' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 384 by 512

Edit4: Comparing the ONNX Graph and the State Dict, the following result is seen: image

Edit5: It seems I broke some reference in aten_linspace when using CastLike. When I removed my changes from aten_linspace and forced torch.float32 in the linspace calls, I got a result very similar to PyTorch’s.

FYI, if the functions with many if branches bother you for performance reasons, you may consider https://github.com/microsoft/onnxscript/pull/1178

Great - we will need to fix linspace. If you like please feel free to add a cast like in the lines above and see if it gives you the correct model. I will create a fix this week.

Cool, the CastLike solves the Div problem. However, the Mul right after it (in the return) also has an inconsistent type, probably because of the type mismatch between start and range_tensor. If I also cast range_tensor like start, the following error happens later in the graph:

>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 463, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (_inline_models_model_zoo_DenseHoughVoter_voter_1aten_rsub_186) Op (aten_rsub) [ShapeInferenceError] (op_type:Sub, node name: n3): B has inconsistent type tensor(int64)

Edit: The problem is still linspace. When I passed the dtype keyword to the linspace calls, I got a correct model.

Thank you very much for all your attention and help @justinchuby. I’ll check what is failing on my PR and try to contribute with the aten::roll for complex and the var (dim and correction).

I changed the backbone from a manually built ResNet to the backbone from timm and retrained the model. The new model architecture seems to have a much more straightforward graph (checked in Netron), but the same problem still occurs in Div_148. Here is the new model.

Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import onnxruntime as ort
>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_7) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_148): B has inconsistent type tensor(int64)
>>> 

Also noticed the following UserWarning from torch.onnx.dynamo_export that might be giving a clue that something is wrong during graph build.

/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/fx/passes/readability.py:53: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  new_node = self.module.graph.get_attr(normalized_name)

I installed the torch nightly build and still get the same warnings regarding ShapeInferenceError and TypeInferenceError. The only difference now is that the node indexing changed.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_149) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_148): B has inconsistent type tensor(int64