iree: JAX RFFT causes compiler crash due to failed isIntOrFloat() assertion

What happened?

Compiling jnp.fft.rfft with the IREE backend crashes iree-compile with the following assertion failure:

iree-compile: .../iree/third_party/llvm-project/mlir/lib/IR/Types.cpp:92: unsigned int mlir::Type::getIntOrFloatBitWidth() const: Assertion `isIntOrFloat() && "only integers and floats have a bitwidth"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
 #0 0x00007f84e1305a33 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) .../iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:569:13
 #1 0x00007f84e1303b10 llvm::sys::RunSignalHandlers() .../iree/third_party/llvm-project/llvm/lib/Support/Signals.cpp:104:18
 #2 0x00007f84e1305d9a SignalHandler(int) .../iree/third_party/llvm-project/llvm/lib/Support/Unix/Signals.inc:407:1
 #3 0x00007f84dcff9920 (/lib/x86_64-linux-gnu/libc.so.6+0x38920)
 #4 0x00007f84dcff98a1 raise ./signal/../sysdeps/unix/sysv/linux/raise.c:50:1
 #5 0x00007f84dcfe3546 abort ./stdlib/abort.c:81:7
 #6 0x00007f84dcfe342f get_sysdep_segment_value ./intl/loadmsgcat.c:509:8
 #7 0x00007f84dcfe342f _nl_load_domain ./intl/loadmsgcat.c:970:34
 #8 0x00007f84dcff2222 (/lib/x86_64-linux-gnu/libc.so.6+0x31222)
 #9 0x00007f84e0bf1a6f mlir::FloatType mlir::Type::cast<mlir::FloatType>() const .../iree/third_party/llvm-project/mlir/include/mlir/IR/Types.h:272:3
#10 0x00007f84e0bf1a6f mlir::Type::getIntOrFloatBitWidth() const .../iree/third_party/llvm-project/mlir/lib/IR/Types.cpp:95:10
#11 0x00007f84e1ed6cac mlir::iree_compiler::IREE::Util::getRoundedElementByteWidth(mlir::Type) .../iree/compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.h:155:3
#12 0x00007f84e1ed6cac mlir::iree_compiler::IREE::Stream::(anonymous namespace)::EncodeTensorSizeOfOp::matchAndRewrite(mlir::iree_compiler::IREE::Stream::TensorSizeOfOp, mlir::PatternRewriter&) const .../iree/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp:219:9
#13 0x00007f84e53dd462 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) .../iree/third_party/llvm-project/mlir/lib/Rewrite/PatternApplicator.cpp:201:25
#14 0x00007f84e53af8a2 (anonymous namespace)::GreedyPatternRewriteDriver::simplify(llvm::MutableArrayRef<mlir::Region>) .../iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:270:19
#15 0x00007f84e53af8a2 mlir::applyPatternsAndFoldGreedily(llvm::MutableArrayRef<mlir::Region>, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig) .../iree/third_party/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:404:27
#16 0x00007f84e1ed635c mlir::LogicalResult::succeeded() const .../iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:41:35
#17 0x00007f84e1ed635c mlir::LogicalResult::failed() const .../iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:44:33
#18 0x00007f84e1ed635c mlir::failed(mlir::LogicalResult) .../iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:72:58
#19 0x00007f84e1ed635c mlir::iree_compiler::IREE::Stream::(anonymous namespace)::EncodeHostTensorsPass::runOnOperation() .../iree/compiler/src/iree/compiler/Dialect/Stream/Transforms/EncodeTensors.cpp:604:9
#20 0x00007f84e1341b50 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) .../iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:0:11
#21 0x00007f84e13423b6 mlir::LogicalResult::succeeded() const .../iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:41:35
#22 0x00007f84e13423b6 mlir::LogicalResult::failed() const .../iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:44:33
#23 0x00007f84e13423b6 mlir::failed(mlir::LogicalResult) .../iree/third_party/llvm-project/mlir/include/mlir/Support/LogicalResult.h:72:58
#24 0x00007f84e13423b6 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) .../iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:534:9
#25 0x00007f84e1347b5f mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12::operator()(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo&) const .../iree/third_party/llvm-project/mlir/lib/Pass/Pass.cpp:759:5
#26 0x00007f84e1347b5f mlir::LogicalResult mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&)::'lambda'()::operator()() const .../iree/third_party/llvm-project/mlir/include/mlir/IR/Threading.h:62:18
#27 0x00007f84e1347b5f __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > > std::__invoke_impl<void, mlir::LogicalResult mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&)::'lambda'()&>(std::__invoke_other, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14
#28 0x00007f84e1347b5f std::enable_if<__and_<std::is_void<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > > >, std::__is_invocable<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&> >::value, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > > >::type std::__invoke_r<void, mlir::LogicalResult mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&)::'lambda'()&>(mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:154:7
#29 0x00007f84e1347b5f std::_Function_handler<void (), mlir::LogicalResult mlir::failableParallelForEach<__gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&>(mlir::MLIRContext*, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, __gnu_cxx::__normal_iterator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo*, std::vector<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo, std::allocator<mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::OpPMInfo> > >, mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool)::$_12&)::'lambda'()>::_M_invoke(std::_Any_data const&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:290:9
#30 0x00007f84e0bf5716 std::__shared_ptr<std::promise<void>, (__gnu_cxx::_Lock_policy)2>::get() const /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/shared_ptr_base.h:1296:16
#31 0x00007f84e0bf5716 std::__shared_ptr_access<std::promise<void>, (__gnu_cxx::_Lock_policy)2, false, false>::_M_get() const /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/shared_ptr_base.h:993:66
#32 0x00007f84e0bf5716 std::__shared_ptr_access<std::promise<void>, (__gnu_cxx::_Lock_policy)2, false, false>::operator->() const /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/shared_ptr_base.h:987:9
#33 0x00007f84e0bf5716 llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()::operator()() const .../iree/third_party/llvm-project/llvm/include/llvm/Support/ThreadPool.h:136:15
#34 0x00007f84e0bf5716 void std::__invoke_impl<void, llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()&>(std::__invoke_other, llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:61:14
#35 0x00007f84e0bf5716 std::enable_if<__and_<std::is_void<void>, std::__is_invocable<llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()&> >::value, void>::type std::__invoke_r<void, llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()&>(llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/invoke.h:154:7
#36 0x00007f84e0bf5716 std::_Function_handler<void (), llvm::ThreadPool::createTaskAndFuture(std::function<void ()>)::'lambda'()>::_M_invoke(std::_Any_data const&) /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/std_function.h:290:9
#37 0x00007f84e12a4690 llvm::ThreadPool::processTasks(llvm::ThreadPoolTaskGroup*) .../iree/third_party/llvm-project/llvm/lib/Support/ThreadPool.cpp:102:5
#38 0x00007f84e12a5453 std::default_delete<std::tuple<llvm::ThreadPool::grow(int)::$_0> >::operator()(std::tuple<llvm::ThreadPool::grow(int)::$_0>*) const /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/unique_ptr.h:85:2
#39 0x00007f84e12a5453 std::unique_ptr<std::tuple<llvm::ThreadPool::grow(int)::$_0>, std::default_delete<std::tuple<llvm::ThreadPool::grow(int)::$_0> > >::~unique_ptr() /usr/bin/../lib/gcc/x86_64-linux-gnu/11/../../../../include/c++/11/bits/unique_ptr.h:361:4
#40 0x00007f84e12a5453 void llvm::thread::GenericThreadProxy<std::tuple<llvm::ThreadPool::grow(int)::$_0> >(void*) .../iree/third_party/llvm-project/llvm/include/llvm/Support/thread.h:51:3
#41 0x00007f84e12a5453 void* llvm::thread::ThreadProxy<std::tuple<llvm::ThreadPool::grow(int)::$_0> >(void*) .../iree/third_party/llvm-project/llvm/include/llvm/Support/thread.h:60:5
#42 0x00007f84dcf9fd80 start_thread ./nptl/pthread_create.c:482:7
#43 0x00007f84dd0bb76f __clone ./misc/../sysdeps/unix/sysv/linux/x86_64/clone.S:97:0
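
Frames #9 through #12 show the assertion firing while EncodeTensorSizeOfOp asks getRoundedElementByteWidth for the byte width of an element type. The MLIR reproducer's FFT result is a tensor of complex<f32>, which is neither an integer nor a float, so that is the likely trigger (an inference from the trace and the reproducer, not something confirmed here). The sketch below is a hypothetical Python model of that situation, not IREE's actual code: all function names are made up for illustration, type names are handled as plain strings, and only iN/fN spellings are recognized. The complex-aware variant is one possible way such a helper could be extended, shown only as an assumption.

```python
import re


def int_or_float_bit_width(type_name: str) -> int:
    """Bit width of a type spelled like 'i32' or 'f32' (simplified model)."""
    match = re.fullmatch(r"[if](\d+)", type_name)
    if match is None:
        # Mirrors the message of the assertion seen in the crash output.
        raise AssertionError("only integers and floats have a bitwidth")
    return int(match.group(1))


def rounded_element_byte_width(type_name: str) -> int:
    """Round the element bit width up to whole bytes (hypothetical helper)."""
    return (int_or_float_bit_width(type_name) + 7) // 8


def rounded_element_byte_width_with_complex(type_name: str) -> int:
    """Assumed extension: a complex element is two elements of its inner type."""
    match = re.fullmatch(r"complex<(.+)>", type_name)
    if match:
        return 2 * rounded_element_byte_width(match.group(1))
    return rounded_element_byte_width(type_name)


if __name__ == "__main__":
    print(rounded_element_byte_width("f32"))  # 4
    try:
        rounded_element_byte_width("complex<f32>")
    except AssertionError as err:
        print("assertion:", err)
    print(rounded_element_byte_width_with_complex("complex<f32>"))  # 8
```

In this model the plain helper fails on complex<f32> in the same way the compiler does, while the complex-aware variant reports 8 bytes per element.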

Steps to reproduce your issue

Python Reproducer:

import jax
import jax.numpy as jnp

def rfft(x):
  return jnp.fft.rfft(x, 512, axis=2)

jax.jit(rfft, backend="iree")(jnp.zeros((3, 2190, 400, 1)))
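
As a cross-check between the Python call and the MLIR reproducer, the shapes can be derived by hand: a real FFT of length n yields n // 2 + 1 complex values along the transformed axis, and an input shorter than n is zero-padded up to n first. The following is a minimal sketch in plain Python (no JAX or IREE needed); the helper names are hypothetical and exist only for this illustration.

```python
def rfft_output_shape(shape, n, axis):
    """Shape of a real FFT of length n taken along `axis`."""
    out = list(shape)
    out[axis] = n // 2 + 1  # only the non-redundant half of the spectrum is kept
    return tuple(out)


def rfft_zero_padding(shape, n, axis):
    """Number of zeros appended along `axis` so its length reaches n."""
    return max(n - shape[axis], 0)


if __name__ == "__main__":
    input_shape = (3, 2190, 400, 1)
    print(rfft_output_shape(input_shape, 512, 2))  # (3, 2190, 257, 1)
    print(rfft_zero_padding(input_shape, 512, 2))  # 112
```

These values agree with the MLIR reproducer: the final mhlo.pad uses an edge_padding_high of 112 on the transformed dimension, and @main returns tensor<3x2190x257x1xcomplex<f32>>.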

MLIR Reproducer:

module @jit_rfft.1 {
  func.func public @main(%arg0: tensor<3x2190x400x1xf32>) -> tensor<3x2190x257x1xcomplex<f32>> {
    %0 = "mhlo.transpose"(%arg0) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<3x2190x400x1xf32>) -> tensor<3x2190x1x400xf32>
    %1 = mhlo.constant dense<0> : tensor<i32>
    %2 = call @_pad(%0, %1) : (tensor<3x2190x1x400xf32>, tensor<i32>) -> tensor<3x2190x1x512xf32>
    %3 = call @fft(%2) : (tensor<3x2190x1x512xf32>) -> tensor<3x2190x1x257xcomplex<f32>>
    %4 = "mhlo.transpose"(%3) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<3x2190x1x257xcomplex<f32>>) -> tensor<3x2190x257x1xcomplex<f32>>
    return %4 : tensor<3x2190x257x1xcomplex<f32>>
  }
  func.func private @_pad(%arg0: tensor<3x2190x1x400xf32>, %arg1: tensor<i32>) -> tensor<3x2190x1x512xf32> {
    %0 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<4x2xi32>
    %1 = mhlo.convert(%0) : (tensor<4x2xi32>) -> tensor<4x2xf32>
    %2 = mhlo.constant dense<0> : tensor<i32>
    %3 = mhlo.constant dense<0> : tensor<i32>
    %4 = "mhlo.compare"(%2, %3) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %5 = mhlo.constant dense<0> : tensor<i32>
    %6 = mhlo.constant dense<4> : tensor<i32>
    %7 = mhlo.add %5, %6 : tensor<i32>
    %8 = mhlo.constant dense<0> : tensor<i32>
    %9 = "mhlo.select"(%4, %7, %8) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %10 = mhlo.convert %9 : tensor<i32>
    %11 = mhlo.constant dense<0> : tensor<i32>
    %12 = mhlo.constant dense<0> : tensor<i32>
    %13 = "mhlo.compare"(%11, %12) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %14 = mhlo.constant dense<0> : tensor<i32>
    %15 = mhlo.constant dense<2> : tensor<i32>
    %16 = mhlo.add %14, %15 : tensor<i32>
    %17 = mhlo.constant dense<0> : tensor<i32>
    %18 = "mhlo.select"(%13, %16, %17) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %19 = mhlo.convert %18 : tensor<i32>
    %20 = "mhlo.broadcast_in_dim"(%10) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %21 = "mhlo.broadcast_in_dim"(%19) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %22 = "mhlo.concatenate"(%20, %21) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %23 = "mhlo.gather"(%1, %22) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %24 = "mhlo.pad"(%arg0, %23) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %25 = mhlo.constant dense<0> : tensor<i32>
    %26 = mhlo.constant dense<0> : tensor<i32>
    %27 = "mhlo.compare"(%25, %26) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %28 = mhlo.constant dense<0> : tensor<i32>
    %29 = mhlo.constant dense<4> : tensor<i32>
    %30 = mhlo.add %28, %29 : tensor<i32>
    %31 = mhlo.constant dense<0> : tensor<i32>
    %32 = "mhlo.select"(%27, %30, %31) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %33 = mhlo.convert %32 : tensor<i32>
    %34 = mhlo.constant dense<1> : tensor<i32>
    %35 = mhlo.constant dense<0> : tensor<i32>
    %36 = "mhlo.compare"(%34, %35) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %37 = mhlo.constant dense<1> : tensor<i32>
    %38 = mhlo.constant dense<2> : tensor<i32>
    %39 = mhlo.add %37, %38 : tensor<i32>
    %40 = mhlo.constant dense<1> : tensor<i32>
    %41 = "mhlo.select"(%36, %39, %40) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %42 = mhlo.convert %41 : tensor<i32>
    %43 = "mhlo.broadcast_in_dim"(%33) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %44 = "mhlo.broadcast_in_dim"(%42) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %45 = "mhlo.concatenate"(%43, %44) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %46 = "mhlo.gather"(%1, %45) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %47 = "mhlo.pad"(%24, %46) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %48 = mhlo.constant dense<1> : tensor<i32>
    %49 = mhlo.constant dense<0> : tensor<i32>
    %50 = "mhlo.compare"(%48, %49) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %51 = mhlo.constant dense<1> : tensor<i32>
    %52 = mhlo.constant dense<4> : tensor<i32>
    %53 = mhlo.add %51, %52 : tensor<i32>
    %54 = mhlo.constant dense<1> : tensor<i32>
    %55 = "mhlo.select"(%50, %53, %54) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %56 = mhlo.convert %55 : tensor<i32>
    %57 = mhlo.constant dense<0> : tensor<i32>
    %58 = mhlo.constant dense<0> : tensor<i32>
    %59 = "mhlo.compare"(%57, %58) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %60 = mhlo.constant dense<0> : tensor<i32>
    %61 = mhlo.constant dense<2> : tensor<i32>
    %62 = mhlo.add %60, %61 : tensor<i32>
    %63 = mhlo.constant dense<0> : tensor<i32>
    %64 = "mhlo.select"(%59, %62, %63) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %65 = mhlo.convert %64 : tensor<i32>
    %66 = "mhlo.broadcast_in_dim"(%56) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %67 = "mhlo.broadcast_in_dim"(%65) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %68 = "mhlo.concatenate"(%66, %67) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %69 = "mhlo.gather"(%1, %68) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %70 = "mhlo.pad"(%47, %69) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %71 = mhlo.constant dense<1> : tensor<i32>
    %72 = mhlo.constant dense<0> : tensor<i32>
    %73 = "mhlo.compare"(%71, %72) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %74 = mhlo.constant dense<1> : tensor<i32>
    %75 = mhlo.constant dense<4> : tensor<i32>
    %76 = mhlo.add %74, %75 : tensor<i32>
    %77 = mhlo.constant dense<1> : tensor<i32>
    %78 = "mhlo.select"(%73, %76, %77) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %79 = mhlo.convert %78 : tensor<i32>
    %80 = mhlo.constant dense<1> : tensor<i32>
    %81 = mhlo.constant dense<0> : tensor<i32>
    %82 = "mhlo.compare"(%80, %81) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %83 = mhlo.constant dense<1> : tensor<i32>
    %84 = mhlo.constant dense<2> : tensor<i32>
    %85 = mhlo.add %83, %84 : tensor<i32>
    %86 = mhlo.constant dense<1> : tensor<i32>
    %87 = "mhlo.select"(%82, %85, %86) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %88 = mhlo.convert %87 : tensor<i32>
    %89 = "mhlo.broadcast_in_dim"(%79) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %90 = "mhlo.broadcast_in_dim"(%88) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %91 = "mhlo.concatenate"(%89, %90) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %92 = "mhlo.gather"(%1, %91) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %93 = "mhlo.pad"(%70, %92) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %94 = mhlo.constant dense<2> : tensor<i32>
    %95 = mhlo.constant dense<0> : tensor<i32>
    %96 = "mhlo.compare"(%94, %95) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %97 = mhlo.constant dense<2> : tensor<i32>
    %98 = mhlo.constant dense<4> : tensor<i32>
    %99 = mhlo.add %97, %98 : tensor<i32>
    %100 = mhlo.constant dense<2> : tensor<i32>
    %101 = "mhlo.select"(%96, %99, %100) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %102 = mhlo.convert %101 : tensor<i32>
    %103 = mhlo.constant dense<0> : tensor<i32>
    %104 = mhlo.constant dense<0> : tensor<i32>
    %105 = "mhlo.compare"(%103, %104) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %106 = mhlo.constant dense<0> : tensor<i32>
    %107 = mhlo.constant dense<2> : tensor<i32>
    %108 = mhlo.add %106, %107 : tensor<i32>
    %109 = mhlo.constant dense<0> : tensor<i32>
    %110 = "mhlo.select"(%105, %108, %109) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %111 = mhlo.convert %110 : tensor<i32>
    %112 = "mhlo.broadcast_in_dim"(%102) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %113 = "mhlo.broadcast_in_dim"(%111) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %114 = "mhlo.concatenate"(%112, %113) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %115 = "mhlo.gather"(%1, %114) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %116 = "mhlo.pad"(%93, %115) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %117 = mhlo.constant dense<2> : tensor<i32>
    %118 = mhlo.constant dense<0> : tensor<i32>
    %119 = "mhlo.compare"(%117, %118) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %120 = mhlo.constant dense<2> : tensor<i32>
    %121 = mhlo.constant dense<4> : tensor<i32>
    %122 = mhlo.add %120, %121 : tensor<i32>
    %123 = mhlo.constant dense<2> : tensor<i32>
    %124 = "mhlo.select"(%119, %122, %123) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %125 = mhlo.convert %124 : tensor<i32>
    %126 = mhlo.constant dense<1> : tensor<i32>
    %127 = mhlo.constant dense<0> : tensor<i32>
    %128 = "mhlo.compare"(%126, %127) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %129 = mhlo.constant dense<1> : tensor<i32>
    %130 = mhlo.constant dense<2> : tensor<i32>
    %131 = mhlo.add %129, %130 : tensor<i32>
    %132 = mhlo.constant dense<1> : tensor<i32>
    %133 = "mhlo.select"(%128, %131, %132) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %134 = mhlo.convert %133 : tensor<i32>
    %135 = "mhlo.broadcast_in_dim"(%125) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %136 = "mhlo.broadcast_in_dim"(%134) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %137 = "mhlo.concatenate"(%135, %136) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %138 = "mhlo.gather"(%1, %137) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %139 = "mhlo.pad"(%116, %138) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %140 = mhlo.constant dense<3> : tensor<i32>
    %141 = mhlo.constant dense<0> : tensor<i32>
    %142 = "mhlo.compare"(%140, %141) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %143 = mhlo.constant dense<3> : tensor<i32>
    %144 = mhlo.constant dense<4> : tensor<i32>
    %145 = mhlo.add %143, %144 : tensor<i32>
    %146 = mhlo.constant dense<3> : tensor<i32>
    %147 = "mhlo.select"(%142, %145, %146) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %148 = mhlo.convert %147 : tensor<i32>
    %149 = mhlo.constant dense<0> : tensor<i32>
    %150 = mhlo.constant dense<0> : tensor<i32>
    %151 = "mhlo.compare"(%149, %150) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %152 = mhlo.constant dense<0> : tensor<i32>
    %153 = mhlo.constant dense<2> : tensor<i32>
    %154 = mhlo.add %152, %153 : tensor<i32>
    %155 = mhlo.constant dense<0> : tensor<i32>
    %156 = "mhlo.select"(%151, %154, %155) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %157 = mhlo.convert %156 : tensor<i32>
    %158 = "mhlo.broadcast_in_dim"(%148) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %159 = "mhlo.broadcast_in_dim"(%157) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %160 = "mhlo.concatenate"(%158, %159) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %161 = "mhlo.gather"(%1, %160) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %162 = "mhlo.pad"(%139, %161) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x400xf32>
    %163 = mhlo.constant dense<3> : tensor<i32>
    %164 = mhlo.constant dense<0> : tensor<i32>
    %165 = "mhlo.compare"(%163, %164) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %166 = mhlo.constant dense<3> : tensor<i32>
    %167 = mhlo.constant dense<4> : tensor<i32>
    %168 = mhlo.add %166, %167 : tensor<i32>
    %169 = mhlo.constant dense<3> : tensor<i32>
    %170 = "mhlo.select"(%165, %168, %169) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %171 = mhlo.convert %170 : tensor<i32>
    %172 = mhlo.constant dense<1> : tensor<i32>
    %173 = mhlo.constant dense<0> : tensor<i32>
    %174 = "mhlo.compare"(%172, %173) {compare_type = #mhlo<comparison_type SIGNED>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    %175 = mhlo.constant dense<1> : tensor<i32>
    %176 = mhlo.constant dense<2> : tensor<i32>
    %177 = mhlo.add %175, %176 : tensor<i32>
    %178 = mhlo.constant dense<1> : tensor<i32>
    %179 = "mhlo.select"(%174, %177, %178) : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
    %180 = mhlo.convert %179 : tensor<i32>
    %181 = "mhlo.broadcast_in_dim"(%171) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %182 = "mhlo.broadcast_in_dim"(%180) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %183 = "mhlo.concatenate"(%181, %182) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %184 = "mhlo.gather"(%1, %183) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %185 = "mhlo.pad"(%162, %184) {edge_padding_high = dense<[0, 0, 0, 112]> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<3x2190x1x400xf32>, tensor<f32>) -> tensor<3x2190x1x512xf32>
    return %185 : tensor<3x2190x1x512xf32>
  }
  func.func private @fft(%arg0: tensor<3x2190x1x512xf32>) -> tensor<3x2190x1x257xcomplex<f32>> {
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<512> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<3x2190x1x512xf32>) -> tensor<3x2190x1x257xcomplex<f32>>
    return %0 : tensor<3x2190x1x257xcomplex<f32>>
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "func.func(mhlo-legalize-control-flow,iree-top-level-scf-to-cfg,iree-mhlo-to-mhlo-preprocessing{order-conv-features=true},canonicalize{  max-iterations=10 region-simplify=true top-down=true},shape-to-shape-lowering), convert-shape-to-std, func.func(canonicalize{  max-iterations=10 region-simplify=true top-down=true}), inline{default-pipeline= max-iterations=4 }, iree-util-demote-i64-to-i32, iree-util-demote-f64-to-f32, func.func(canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse,hlo-legalize-shape-computations,iree-mhlo-to-linalg-ext,iree-mhlo-to-linalg-on-tensors), reconcile-unrealized-casts, func.func(canonicalize{  max-iterations=10 region-simplify=true top-down=true}), iree-mhlo-verify-compiler-input-legality, iree-import-public, iree-import-ml-program, iree-sanitize-module-names, iree-abi-wrap-entry-points, inline{default-pipeline= max-iterations=4 }, func.func(canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse), symbol-dce, iree-util-demote-f64-to-f32, func.func(iree-flow-convert-conv2d-1x1-to-matmul,iree-flow-detach-elementwise-from-named-ops,iree-verify-input-legality),util.initializer(iree-flow-convert-conv2d-1x1-to-matmul,iree-flow-detach-elementwise-from-named-ops,iree-verify-input-legality), linalg-named-op-conversion, iree-flow-expand-tensor-shapes, iree-util-fixed-point-iterator{max-iterations=10 pipeline=func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses),iree-util-apply-patterns,iree-util-fold-globals,func.func(canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse),util.initializer(canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse)}, func.func(iree-flow-pad-tensor-to-tensor-insert-slice,convert-elementwise-to-linalg,linalg-fold-unit-extent-dims{fold-one-trip-loops-only=false},iree-flow-interchange-generic-ops,resolve-shaped-type-result-dims,canonicalize{  max-iterations=10 region-simplify=true 
top-down=true},cse,iree-flow-fusion-of-tensor-ops,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse,iree-flow-split-reduction-ops,iree-flow-interchange-generic-ops,iree-flow-dispatch-linalg-on-tensors-pass,iree-flow-capture-dispatch-dynamic-dims,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse),util.initializer(iree-flow-pad-tensor-to-tensor-insert-slice,convert-elementwise-to-linalg,linalg-fold-unit-extent-dims{fold-one-trip-loops-only=false},iree-flow-interchange-generic-ops,resolve-shaped-type-result-dims,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse,iree-flow-fusion-of-tensor-ops,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse,iree-flow-split-reduction-ops,iree-flow-interchange-generic-ops,iree-flow-dispatch-linalg-on-tensors-pass,iree-flow-capture-dispatch-dynamic-dims,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse), iree-flow-initialize-empty-tensors, iree-flow-outline-dispatch-regions, flow.executable(iree-util-strip-debug-ops),func.func(canonicalize{  max-iterations=10 region-simplify=true top-down=true}),util.initializer(canonicalize{  max-iterations=10 region-simplify=true top-down=true}), iree-flow-deduplicate-executables, flow.executable(canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse),func.func(iree-flow-cleanup-tensor-shapes,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse),util.initializer(iree-flow-cleanup-tensor-shapes,canonicalize{  max-iterations=10 region-simplify=true top-down=true},cse), symbol-dce, iree-stream-verify-input, iree-stream-outline-constants, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-stream-conversion, iree-stream-verify-lowering-to-tensors, 
canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-util-combine-initializers, func.func(iree-stream-encode-host-tensors),stream.executable(iree-stream-encode-device-tensors),util.initializer(iree-stream-encode-host-tensors), iree-stream-materialize-builtins, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, func.func(iree-stream-materialize-copy-on-write),util.initializer(iree-stream-materialize-copy-on-write), iree-stream-elide-async-copies, iree-stream-refine-usage, func.func(iree-stream-schedule-execution,iree-stream-schedule-concurrency),util.initializer(iree-stream-schedule-execution,iree-stream-schedule-concurrency), iree-stream-propagate-timepoints, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-stream-verify-lowering-to-async, func.func(iree-stream-schedule-allocation,iree-stream-pack-constants,iree-stream-pack-allocations,iree-stream-layout-slices),util.initializer(iree-stream-schedule-allocation,iree-stream-pack-constants,iree-stream-pack-allocations,iree-stream-layout-slices), iree-util-propagate-subranges, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-stream-verify-lowering-to-cmd, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, 
cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-stream-fuse-dispatch-bindings{alias-mutable-bindings=false}, iree-stream-pack-dispatch-operands, cse, iree-stream-fold-uniform-operands, iree-stream-annotate-dispatch-arguments, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, symbol-dce, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-hal-assign-target-devices{targets=dylib}, iree-hal-verify-target-environment, iree-hal-materialize-interfaces, hal.executable(iree-hal-translate-executables), iree-hal-conversion, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-hal-link-executables, iree-hal-resolve-export-ordinals, iree-hal-materialize-resource-caches, func.func(iree-hal-inline-device-switches),util.initializer(iree-hal-inline-device-switches), iree-hal-memoize-device-queries, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, func.func(iree-hal-elide-redundant-commands),util.initializer(iree-hal-elide-redundant-commands), lower-affine, iree-util-combine-initializers, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, 
func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, hal.executable(iree-hal-serialize-executables), symbol-dce, func.func(affine-loop-coalescing,scf-for-loop-canonicalization,loop-invariant-code-motion,convert-scf-to-cf),util.initializer(loop-invariant-code-motion,convert-scf-to-cf), iree-util-propagate-subranges, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, func.func(iree-util-simplify-global-accesses),util.initializer(iree-util-simplify-global-accesses), iree-util-apply-patterns, iree-util-fold-globals, iree-util-fuse-globals, iree-vm-conversion, vm.module(iree-vm-hoist-inlined-rodata,iree-vm-deduplicate-rodata,iree-vm-sink-global-buffer-loads,iree-vm-global-initialization), inline{default-pipeline= max-iterations=4 }, canonicalize{  max-iterations=10 region-simplify=true top-down=true}, cse, vm.module(iree-vm-drop-empty-module-initializers), symbol-dce, vm.module(iree-vm-sink-defining-ops), iree-util-drop-compiler-hints",
      disable_threading: false,
      verify_each: true
    }
  }
#-}

What component(s) does this issue relate to?

No response

Version information

Failing at 59602cce17b87beea9e8353b6f4e765b1e968c8b (and a version from last week as well).

Additional context

This call works with a batch_size <= 2 in the ASR model, but doesn’t work with batch_size = 1 outside it, which is strange. The example above is for batch_size = 3.

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 22 (12 by maintainers)

Most upvoted comments

#9906 fixes the stack allocation issue. Now it fails during lowering to LLVM due to:

    %377 = complex.create %364, %376 : complex<f32>
    %378 = builtin.unrealized_conversion_cast %377 : complex<f32> to !llvm.struct<(f32, f32)>

@rsuderman or @jpienaar anything that needs to be added to the LLVM lowering to handle this?

I worked around the issue and fixed the size computation that caused the assert in the initial IR, but codegen then fails due to a large stack allocation:

D:\Dev\iree/../iree-tmp/rfft_min.mlir:9:10: error: 'builtin.module' op expected total size of stack allocation is not greater than 32768 bytes, but got 49152 bytes

So there may need to be some tuning here to get things sized right /cc @MaheshRavishankar.
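
To illustrate the kind of size computation involved: the assert fires because a complex element type has no integer-or-float bit width of its own. A minimal sketch, assuming a complex element is simply stored as two values of its underlying type. The function names and the string-based type parsing are hypothetical and are not IREE's actual code or fix:

```python
import math
import re


def rounded_element_byte_width(type_str: str) -> int:
    """Byte width of an MLIR-style element type such as 'i1', 'f32' or 'complex<f32>'.

    Hypothetical helper for illustration only.
    """
    complex_match = re.fullmatch(r"complex<(.+)>", type_str)
    if complex_match:
        # Assumption: a complex value is a (real, imaginary) pair of the same type.
        return 2 * rounded_element_byte_width(complex_match.group(1))
    scalar_match = re.fullmatch(r"(?:[su]?i|f|bf)(\d+)", type_str)
    if not scalar_match:
        raise ValueError(f"no bit width defined for element type {type_str!r}")
    bit_width = int(scalar_match.group(1))
    # Round sub-byte types (e.g. i1) up to a whole number of bytes.
    return (bit_width + 7) // 8


def tensor_byte_size(shape, type_str: str) -> int:
    """Total storage in bytes for a statically shaped tensor."""
    return math.prod(shape) * rounded_element_byte_width(type_str)
```

With this convention, `rounded_element_byte_width("complex<f32>")` is twice the width of `f32`, so a complex tensor takes twice the bytes of a real tensor of the same shape.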

Will look at that one once I am past the issue at the Flow level. Thanks!

(If we can decompose things, which I think there’s a lower_complex pass in TF for, that’s always going to produce better code. But for the times they remain, we can at least carry them around without dying 😃)
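
As a rough illustration of what decomposing complex values means for an RFFT, here is a minimal sketch that computes the one-sided DFT with real arithmetic only, keeping the real and imaginary parts in separate lists. It is a naive O(n²) DFT written for clarity and is not what any lower_complex pass actually emits; `rfft_real_only` and `rfft_complex_reference` are hypothetical names:

```python
import cmath
import math


def rfft_real_only(signal):
    """One-sided DFT of a real signal using only real arithmetic.

    Returns (real_parts, imag_parts), each of length n // 2 + 1.
    """
    n = len(signal)
    real_parts, imag_parts = [], []
    for k in range(n // 2 + 1):
        re_acc = 0.0
        im_acc = 0.0
        for t, x in enumerate(signal):
            angle = 2.0 * math.pi * k * t / n
            # exp(-i * angle) = cos(angle) - i * sin(angle)
            re_acc += x * math.cos(angle)
            im_acc -= x * math.sin(angle)
        real_parts.append(re_acc)
        imag_parts.append(im_acc)
    return real_parts, imag_parts


def rfft_complex_reference(signal):
    """The same transform written with complex numbers, for comparison."""
    n = len(signal)
    return [
        sum(x * cmath.exp(-2j * math.pi * k * t / n) for t, x in enumerate(signal))
        for k in range(n // 2 + 1)
    ]
```

Both functions should agree up to floating-point rounding; the first never materializes a complex value, so nothing downstream would need to know the byte width or LLVM layout of `complex<f32>`.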