burn: Continuous model inference may experience intermittent blocking

Describe the bug

During continuous model inference, a single forward pass intermittently blocks for tens of seconds (see the timings below) before inference resumes at normal speed.

To Reproduce

Cargo.toml:

burn = { version = "0.12.1", features = ["wgpu", "fusion"] }
resnet-burn = { path = "models/resnet-burn" }

src/main.rs:

use resnet_burn::model::resnet::ResNet;
use burn::backend::{Fusion, Wgpu};
use burn::module::Module;

type Backend = Fusion<Wgpu>;

fn main() {
    let device = Default::default();
    // ResNet-50 with 2 output classes.
    let model: ResNet<Backend, _> = ResNet::resnet50(2, &device);
    let model = model.no_grad();
    // Dummy input batch: 10 images, 3 channels, already sized to 224x224.
    let x = burn::tensor::Tensor::zeros([10, 3, 224, 224], &device).no_grad();
    for i in 0..5000 {
        let x = x.clone();
        let t = std::time::Instant::now();
        let y = model.forward(x);
        println!("{i}: cost time: {} ms {:?}", t.elapsed().as_nanos() as f64 / 1_000_000., y.shape());
    }
}
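
Note: the wgpu backend executes asynchronously, so the printed times mostly measure how long it takes to enqueue work, which is likely why most iterations look fast while one occasionally blocks for tens of seconds. Below is a minimal sketch of a timing loop that forces each iteration to complete before stopping the timer (my addition, assuming into_data() blocks until the result has been read back from the GPU):

for i in 0..5000 {
    let x = x.clone();
    let t = std::time::Instant::now();
    let y = model.forward(x);
    let shape = y.shape();
    // Assumption: reading the tensor back forces the async backend to finish
    // all work queued for this result, so the elapsed time covers the full pass.
    let _data = y.into_data();
    println!("{i}: cost time: {} ms {:?}", t.elapsed().as_secs_f64() * 1000., shape);
}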

Screenshots

0: cost time: 19.9434 ms Shape { dims: [10, 2] }
1: cost time: 8.8942 ms Shape { dims: [10, 2] }
2: cost time: 8.2037 ms Shape { dims: [10, 2] }
3: cost time: 7.5952 ms Shape { dims: [10, 2] }
4: cost time: 7.6082 ms Shape { dims: [10, 2] }
5: cost time: 7.3022 ms Shape { dims: [10, 2] }
6: cost time: 7.7949 ms Shape { dims: [10, 2] }
7: cost time: 8.026 ms Shape { dims: [10, 2] }
...
53: cost time: 8.663 ms Shape { dims: [10, 2] }
54: cost time: 7.8703 ms Shape { dims: [10, 2] }
55: cost time: 9.1731 ms Shape { dims: [10, 2] }
56: cost time: 9.6698 ms Shape { dims: [10, 2] }
57: cost time: 9.4025 ms Shape { dims: [10, 2] }
58: cost time: 29458.3501 ms Shape { dims: [10, 2] }  <<<=============
59: cost time: 6.3148 ms Shape { dims: [10, 2] }
60: cost time: 6.0733 ms Shape { dims: [10, 2] }
61: cost time: 6.483 ms Shape { dims: [10, 2] }
62: cost time: 6.4576 ms Shape { dims: [10, 2] }
63: cost time: 7.126 ms Shape { dims: [10, 2] }
64: cost time: 6.9491 ms Shape { dims: [10, 2] }
65: cost time: 6.4537 ms Shape { dims: [10, 2] }
66: cost time: 6.5734 ms Shape { dims: [10, 2] }
67: cost time: 7.6258 ms Shape { dims: [10, 2] }
68: cost time: 6.7114 ms Shape { dims: [10, 2] }
...
105: cost time: 7.8975 ms Shape { dims: [10, 2] }
106: cost time: 7.7744 ms Shape { dims: [10, 2] }
107: cost time: 7.9532 ms Shape { dims: [10, 2] }
108: cost time: 9.6862 ms Shape { dims: [10, 2] }
109: cost time: 10.329 ms Shape { dims: [10, 2] }
110: cost time: 35719.9928 ms Shape { dims: [10, 2] } <<<=============
111: cost time: 6.9334 ms Shape { dims: [10, 2] }
112: cost time: 7.1723 ms Shape { dims: [10, 2] }
113: cost time: 7.4832 ms Shape { dims: [10, 2] }
114: cost time: 7.4207 ms Shape { dims: [10, 2] }
115: cost time: 7.1357 ms Shape { dims: [10, 2] }
...

Desktop (please complete the following information):

  • OS: Windows 11

About this issue

  • Original URL
  • State: closed
  • Created 3 months ago
  • Comments: 23 (11 by maintainers)

Most upvoted comments

Thank you for the work you’ve done. I’m looking forward to the next version.

The problem has been solved. I compared the inference speed of burn, ONNX Runtime, and PyTorch on ResNet-50. Do you know what accounts for the speed difference between burn and the other two?

At this time, the wgpu backend is not yet fully optimized for vision ops. We prioritized portability first and laid the groundwork for performance. As @nathanielsimard mentioned, we’re going to focus on that in the following releases 😃

@antimora on the main branch, in burn/crates/burn-wgpu/src/lib.rs:

#[cfg(feature = "fusion")]
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
///
/// This backend can target multiple graphics APIs, including:
///   - [Vulkan] on Linux, Windows, and Android.
///   - [OpenGL](crate::OpenGl) on Linux, Windows, and Android.
///   - [DirectX 12](crate::Dx12) on Windows.
///   - [Metal] on Apple hardware.
///   - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
///
/// # Notes
///
/// This version of the [wgpu] backend uses [burn_fusion] to compile and optimize streams of tensor
/// operations for improved performance.
///
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
    burn_fusion::Fusion<JitBackend<WgpuRuntime<G, F, I>>>;
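
For completeness, the cfg-gated twin of that alias without fusion would point at the JIT backend directly. This is a sketch based on the snippet above, not a verbatim quote of main:

#[cfg(not(feature = "fusion"))]
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders,
/// without [burn_fusion] stream optimization.
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
    JitBackend<WgpuRuntime<G, F, I>>;

In a downstream crate, that just means dropping "fusion" from the burn feature list in Cargo.toml.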

The API didn’t exist yet, but I created a PR with the necessary changes: https://github.com/tracel-ai/burn/pull/1505/files
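
Until that lands, the idea is to synchronize the device periodically so queued work stays bounded. A hypothetical usage sketch, reusing Backend and device from the repro above; the name Backend::sync and its signature are assumptions, and the real API is whatever the linked PR introduces:

for i in 0..5000 {
    let y = model.forward(x.clone());
    if i % 100 == 99 {
        // Hypothetical call: block until all GPU work queued on `device`
        // has completed. See the PR above for the actual name/signature.
        Backend::sync(&device);
    }
    println!("{i}: {:?}", y.shape());
}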