SHARK: bf16 result mismatch for Conv2D op
- The following is the Conv2d PyTorch module:
import torch
import torch.nn as nn

class op_conv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        )

    def forward(self, x):
        return self.layers(x)

model = op_conv2d()
model_bf16 = model.to(torch.bfloat16)
test_input_bf16 = torch.randn(2, 8, 12, 16).to(torch.bfloat16)
test_output_bf16 = model_bf16(test_input_bf16)
print("Input:", test_input_bf16)
print("Output:", test_output_bf16)
- This is the linalg IR of the above pytorch module:
#map = affine_map<(d0, d1, d2, d3) -> (d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @main_graph(%arg0: tensor<2x8x12x16xbf16>) -> tensor<2x10x7x16xbf16> {
    %cst = arith.constant dense<"0x..."> : tensor<10x8x3x5xbf16>
    %cst_0 = arith.constant dense<[8.056640e-02, 6.738280e-02, -3.637700e-02, -2.111820e-02, 7.568360e-02, 7.519530e-02, -3.112790e-02, 4.663090e-02, -4.589840e-02, 5.908200e-02]> : tensor<10xbf16>
    %cst_1 = arith.constant 0.000000e+00 : bf16
    %padded = tensor.pad %arg0 low[0, 0, 4, 2] high[0, 0, 4, 2] {
    ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
      tensor.yield %cst_1 : bf16
    } : tensor<2x8x12x16xbf16> to tensor<2x8x20x20xbf16>
    %0 = tensor.empty() : tensor<2x10x7x16xbf16>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_0 : tensor<10xbf16>) outs(%0 : tensor<2x10x7x16xbf16>) {
    ^bb0(%in: bf16, %out: bf16):
      linalg.yield %in : bf16
    } -> tensor<2x10x7x16xbf16>
    %2 = linalg.conv_2d_nchw_fchw {dilations = dense<[3, 1]> : vector<2xi64>, strides = dense<[2, 1]> : vector<2xi64>} ins(%padded, %cst : tensor<2x8x20x20xbf16>, tensor<10x8x3x5xbf16>) outs(%1 : tensor<2x10x7x16xbf16>) -> tensor<2x10x7x16xbf16>
    return %2 : tensor<2x10x7x16xbf16>
  }
}
Running the above module through the IREE CPU backend produces results that do not match the PyTorch output.
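As a quick sanity check that the mismatch is numerical rather than a shape problem, the tensor shapes in the IR can be reproduced by hand. The sketch below applies the standard convolution output-size formula to the parameters of the `nn.Conv2d` layer above; the helper name `conv2d_output_size` is made up for illustration and is not part of PyTorch or IREE.

```python
def conv2d_output_size(size, kernel, stride, padding, dilation):
    """Output extent of one spatial dimension of a 2D convolution.

    Hypothetical helper: it mirrors the formula given in the nn.Conv2d
    documentation, floor((size + 2*padding - dilation*(kernel-1) - 1) / stride) + 1.
    """
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# Parameters taken from the module: kernel (3, 5), stride (2, 1),
# padding (4, 2), dilation (3, 1), input spatial size 12 x 16.
h_out = conv2d_output_size(12, kernel=3, stride=2, padding=4, dilation=3)
w_out = conv2d_output_size(16, kernel=5, stride=1, padding=2, dilation=1)

# Padded extents, as produced by tensor.pad in the IR.
h_padded = 12 + 2 * 4
w_padded = 16 + 2 * 2

print("padded:", (2, 8, h_padded, w_padded))
print("output:", (2, 10, h_out, w_out))
```

The computed extents agree with `tensor<2x8x20x20xbf16>` for the padded input and `tensor<2x10x7x16xbf16>` for the result, so the lowering is at least shape-consistent with the PyTorch module.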
About this issue
- State: open
- Created 4 months ago
- Comments: 30 (26 by maintainers)
It’s OK, I think I’ll have the patch ready soon.
https://github.com/llvm/llvm-project/pull/83180 is merged, so you’ll get it in the next integrate or can cherry-pick it locally until then to verify it fixed your issue.
I think we should not close this until we reach a conclusion on how to handle bf16 on the CPU, that is, how to verify that the model produces correct outputs through the ONNX pipeline.
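One possible way to approach the verification question is to compare bf16 outputs against a higher-precision reference using tolerances scaled to bf16 precision instead of requiring exact equality. The following is only a minimal sketch using the Python standard library: `to_bf16` and `bf16_allclose` are hypothetical helpers, and the `rtol` and `atol` defaults are assumptions chosen for illustration, not values used by IREE, SHARK, or PyTorch.

```python
import struct

# Spacing between 1.0 and the next representable bfloat16 value
# (bfloat16 stores 7 explicit mantissa bits).
BF16_EPS = 2.0 ** -7


def to_bf16(x: float) -> float:
    """Round a finite, float32-representable value to bfloat16 (ties to even).

    Hypothetical helper; NaN and out-of-range inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 encoding. Adding this
    # bias before masking rounds to nearest, with ties going to the value
    # whose lowest kept bit is zero.
    bias = 0x7FFF + ((bits >> 16) & 1)
    rounded = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", rounded))[0]


def bf16_allclose(actual, reference, rtol=4 * BF16_EPS, atol=1e-2):
    """Check |a - r| <= atol + rtol * |r| for each pair in two flat sequences.

    The default tolerances are assumptions; they should be tuned to the
    reduction depth of the op being tested.
    """
    return all(abs(a - r) <= atol + rtol * abs(r) for a, r in zip(actual, reference))


# Stand-in data: a higher-precision reference and the same values after
# a single rounding to bfloat16.
reference = [0.1234567, -3.1415926, 100.5]
actual = [to_bf16(v) for v in reference]
print(bf16_allclose(actual, reference))
```

A single rounding to bf16 introduces a relative error of at most 2^-8, but a convolution that accumulates many products in bf16 can drift further, so a tolerance that works for one op may need loosening for deeper reductions.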