iree: Error out on inputs with f64 values and tensors.

Describe the bug

Miscompile of the following Torch code when it is invoked with a tensor<4xf64> argument (the same happens for +):

    def forward(self, a):
        return a * a

Note: In my original test case, the results come out wrong, rather than hanging, with the following tensor statistics:

    FAIL - "ElementwiseMulTensorFloatModule_basic"
        @ trace item #0 - call to "forward"
        @ output of call to "forward"
        ERROR: value (Tensor with shape=[4] min=+0.0, max=+136.0, mean=+34.0) is not close to golden value (Tensor with shape=[4] min=+0.0, max=+9.0, mean=+3.5)

(The full result computed by IREE is [0., 136., 0., 0.] instead of [0., 1., 4., 9.].)

To Reproduce

    $ iree-run-mlir /tmp/repro.mlir -function-input="4xf64=0 1 2 3" -iree-hal-target-backends=dylib
    EXEC @forward
    result[0]: hal.buffer_view
    <hang>

Expected result: no hang, and the output “4xf64=0 1 4 9”.

Contents of /tmp/repro.mlir:

    #map = affine_map<(d0) -> (d0)>
    module attributes {torch.debug_module_name = "ElementwiseMulTensorFloatModule"} {
      func @forward(%arg0: tensor<4xf64>) -> tensor<4xf64> {
        %0 = linalg.init_tensor [4] : tensor<4xf64>
        %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<4xf64>) outs(%0 : tensor<4xf64>) {
        ^bb0(%arg1: f64, %arg2: f64):
          %2 = arith.mulf %arg1, %arg1 : f64
          linalg.yield %2 : f64
        } -> tensor<4xf64>
        return %1 : tensor<4xf64>
      }
    }

About this issue

  • State: open
  • Created 2 years ago
  • Comments: 32 (12 by maintainers)

Most upvoted comments

I do think that putting a Gandalf-style “you shall not pass” pass somewhere in the compiler would be useful given the present state.
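A crude stand-in for such a gatekeeper, usable as a pre-flight check before invoking the compiler, is to scan the textual IR for f64 types and refuse to continue. The Python sketch below is a hypothetical text heuristic (the function names and the regex are illustrative, not IREE's API); a real compiler pass would inspect the IR's types directly and would not be fooled by, for example, an SSA value named %f64.

```python
import re

# Heuristic: an f64 type token preceded by start-of-line, a non-word character,
# or a shape dimension separator such as "4x" / "?x" (as in tensor<4xf64>).
F64_TYPE = re.compile(r"(?:^|\W|[\d?]x)f64\b")


def find_f64_lines(mlir_text: str) -> list[int]:
    """Return 1-based line numbers of textual IR that mention an f64 type."""
    return [
        lineno
        for lineno, line in enumerate(mlir_text.splitlines(), start=1)
        if F64_TYPE.search(line)
    ]


def reject_f64(mlir_text: str) -> None:
    """Raise instead of compiling when the input uses f64 anywhere."""
    lines = find_f64_lines(mlir_text)
    if lines:
        raise ValueError(
            "f64 is not supported here; found on line(s) "
            + ", ".join(str(n) for n in lines)
        )
```

Running `reject_f64` on the repro above would raise, listing every line that carries a tensor<4xf64> or f64 type.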


Propose to add to backlog?

I thought of a different approach: always allow f64 inputs and perform the f64<->f32 conversion by reinterpret-casting to i64 and shuffling the bits in linalg ops. Then we compile as normal, and the apps still see f64 at the boundary. This wouldn’t work if your tests also check precision. Do they?

(this would mean no ABI changes or anything user visible)
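The bit shuffling proposed above can be sketched in Python, using `struct` to reinterpret floats as integers. This is a minimal sketch, assuming round-to-nearest-even for the f64-to-f32 direction (the comment does not specify a rounding mode); the function names are hypothetical, and in the compiler the same shifts and masks would presumably be emitted as integer ops inside linalg ops rather than written in Python.

```python
import struct

# Hypothetical helpers illustrating f64<->f32 conversion done purely with
# integer shifts and masks on reinterpreted bit patterns.


def f64_bits(x: float) -> int:
    """Reinterpret a binary64 value as a 64-bit unsigned integer."""
    return struct.unpack("<Q", struct.pack("<d", x))[0]


def f32_from_bits(b: int) -> float:
    """Reinterpret a 32-bit unsigned integer as binary32 (widened to a Python float)."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def _shift_right_rne(m: int, s: int) -> int:
    """Return m / 2**s rounded to nearest, ties to even (requires s >= 1)."""
    q = m >> s
    rem = m & ((1 << s) - 1)
    half = 1 << (s - 1)
    if rem > half or (rem == half and q & 1):
        q += 1
    return q


def f64_bits_to_f32_bits(b: int) -> int:
    """Narrow binary64 bits to binary32 bits with round-to-nearest-even."""
    sign = (b >> 63) << 31
    exp = (b >> 52) & 0x7FF
    frac = b & ((1 << 52) - 1)
    if exp == 0x7FF:
        if frac == 0:
            return sign | 0x7F800000  # infinity
        return sign | 0x7FC00000 | (frac >> 29)  # quiet NaN, top payload bits kept
    if exp == 0:
        return sign  # zero or f64 subnormal: far below f32 range, rounds to zero
    e = exp - 1023
    m = (1 << 52) | frac  # 53-bit significand including the implicit leading 1
    e32 = e + 127
    if e32 >= 255:
        return sign | 0x7F800000  # overflow
    if e32 >= 1:
        q = _shift_right_rne(m, 29)  # keep 24 significant bits
        if q == 1 << 24:  # rounding carried out of the significand
            q >>= 1
            e32 += 1
            if e32 == 255:
                return sign | 0x7F800000
        return sign | (e32 << 23) | (q & 0x7FFFFF)
    # f32 subnormal range: the result counts units of 2**-149, i.e. m * 2**(e + 97).
    q = _shift_right_rne(m, -e - 97)
    return sign | q  # q == 1 << 23 correctly encodes the smallest normal


def f32_bits_to_f64_bits(b: int) -> int:
    """Widen binary32 bits to binary64 bits (always exact)."""
    sign = (b >> 31) << 63
    exp = (b >> 23) & 0xFF
    frac = b & 0x7FFFFF
    if exp == 0xFF:
        return sign | (0x7FF << 52) | (frac << 29)  # inf / NaN, payload shifted up
    if exp == 0:
        if frac == 0:
            return sign  # signed zero
        e = -126  # f32 subnormal: shift until the implicit bit appears
        while not frac & 0x800000:
            frac <<= 1
            e -= 1
        return sign | ((e + 1023) << 52) | ((frac & 0x7FFFFF) << 29)
    return sign | ((exp - 127 + 1023) << 52) | (frac << 29)
```

Widening is exact, so only the narrowing direction has to pick a rounding mode. In this sketch, f64 values too small for even an f32 subnormal round to signed zero, and values too large become infinity.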

Our e2e tests all use a general tolerance (1e-5 or some such). They don’t do fine-grained checking that we really have 53 vs 24 bits of significand precision.
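For scale: a 24-bit significand gives f32 a unit roundoff of 2^-24 ≈ 6.0e-8, versus 2^-53 ≈ 1.1e-16 for f64's 53 bits, so a relative tolerance near 1e-5 is far looser than the f32 rounding error of a short computation like a * a. The Python sketch below checks this by using `struct` to round values to binary32; `RTOL` is an assumed stand-in for the test suite's tolerance, and the helper names are hypothetical.

```python
import math
import struct

F32_UNIT_ROUNDOFF = 2.0 ** -24  # binary32: 24-bit significand
F64_UNIT_ROUNDOFF = 2.0 ** -53  # binary64: 53-bit significand
RTOL = 1e-5  # assumed stand-in for the e2e tests' "1e-5 or some such" tolerance


def round_to_f32(x: float) -> float:
    """Round a binary64 value to the nearest binary32 value and widen it back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def max_rel_error_after_f32_roundtrip(values) -> float:
    """Largest relative error introduced by rounding each nonzero value to f32."""
    return max(
        (abs(round_to_f32(x) - x) / abs(x) for x in values if x != 0.0),
        default=0.0,
    )


def square_in_f32(values):
    """Emulate `a * a` evaluated in f32: round the inputs and the product."""
    return [round_to_f32(round_to_f32(x) * round_to_f32(x)) for x in values]


if __name__ == "__main__":
    a = [0.0, 1.0, 2.0, 3.0]
    got = square_in_f32(a)
    print(got)
    print(all(math.isclose(g, x * x, rel_tol=RTOL) for g, x in zip(got, a)))
```

A check with this kind of tolerance would accept an f32 evaluation of the repro's a * a, which is what makes the bit-shuffling approach viable for these tests.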