onnxruntime: [WebGPU] DETR Object Detection Model extremely slow due to multiple copies between CPU and GPU
Describe the issue
As described in the title, using the WebGPU execution provider is extremely slow for this DETR object detection model: https://huggingface.co/Xenova/detr-resnet-50. The unquantized version is around 160MB (quantized version is 40MB, but I am not testing that one here).
Compared with the WASM EP, it's almost 2x slower (WebGPU: ~14 seconds, WASM: ~8 seconds).
Enabling logging and profiling shows that this is due to numerous copies between the GPU and CPU during model execution.
Log file: webgpu_detr.log
In fact, most of the time (~12 seconds) can be traced back to a single copy (line 38689 of the log file):
log.ts:17 [V,2023-11-26T17:52:10.647Z][WebGPU] GpuDataManager.release(id=815), gpuDataId=815
>>> log.ts:17 [V,2023-11-26T17:52:10.648Z][WebGPU] jsepCopyGpuToCpu: gpuDataId=816, dataOffset=273302656, size=625
log.ts:17 [V,2023-11-26T17:52:22.976Z][WebGPU] GpuDataManager.release(id=816), gpuDataId=816
log.ts:17 [V,2023-11-26T17:52:22.976Z][WebGPU] GpuDataManager.create(size=625) => id=817
log.ts:17 [V,2023-11-26T17:52:22.976Z][WebGPU] jsepCopyCpuToGpu: dataOffset=345775616, gpuDataId=817, size=625
log.ts:17 [V,2023-11-26T17:52:22.977Z][WebGPU] GpuDataManager.upload(id=817)
log.ts:17 [V,2023-11-26T17:52:22.977Z][WebGPU] jsepRun: kernel=343388760, contextDataOffset=343725312
To reproduce
// Load the WebGPU build of onnxruntime-web
import 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.16.3/dist/ort.webgpu.min.js'

ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.16.3/dist/';
ort.env.logLevel = 'verbose';
ort.env.debug = true;
ort.env.webgpu.profilingMode = 'default';

const options = {
    executionProviders: ['webgpu'],
    graphOptimizationLevel: 'disabled',
};

// Fetch the unquantized model (~160MB) and create the session
const response = await fetch('https://huggingface.co/Xenova/detr-resnet-50/resolve/main/onnx/model.onnx');
const buffer = new Uint8Array(await response.arrayBuffer());
const session = await ort.InferenceSession.create(buffer, options);
console.log(session);

// Dummy inputs (correct shapes, zero-/one-filled data)
const inputs = {
    pixel_values: new ort.Tensor('float32', new Float32Array(1 * 3 * 800 * 800), [1, 3, 800, 800]),
    pixel_mask: new ort.Tensor('int64', new BigInt64Array(1 * 64 * 64).fill(1n), [1, 64, 64]),
};

// Time a single inference (note: this includes first-run shader compilation)
const start = performance.now();
const output = await session.run(inputs);
const end = performance.now();
console.log(end - start, output);
Urgency
High - blocks WebGPU usage in Transformers.js
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.16.3
Execution Provider
‘webgpu’ (WebGPU)
About this issue
- State: closed
- Created 7 months ago
- Comments: 15 (10 by maintainers)
Commits related to this issue
- [js/webgpu] add bool type for Expand (#18584) — committed to qjia7/onnxruntime by qjia7 7 months ago
the cross-device copies qjia7 mentioned - yeah, silly, we need to fix those (in the annotated graph: green = GPU, amber = CPU, red = cross-device copy):
the latest builds should be much faster for the 1st inference (time in ms):
and if you run this only once - we need to compile the shaders on the first inference, which is pretty costly. There have been lots of improvements on this recently, but those are not in 1.16.3. Going to capture some more numbers.
Three issues in ORT-webgpu:
1) CumSum with float type is not registered in JSEP.
2) Gather with bool type is not registered in JSEP.
3) Expand with bool type is running on the CPU EP, not JSEP. Maybe we should also support it on JSEP.
That's why there are some copies between the GPU and CPU during model execution. This should be fixed.
I used a modified version of Netron (https://github.com/guschmue/netron) that takes a JSON file to annotate nodes. The JSON file is generated from the ORT trace with this tool: https://github.com/guschmue/ort-web-perf/blob/master/ort-trace-color.py
on main as of 12/05 (time in ms):
I can confirm that recent builds (1.17.0-dev.20231128-a6d8726407) are significantly faster on the first run: 882ms now (and then 108ms for 2nd run).
I’ve been testing a few more things to improve speed, like keeping the input and output tensors on the GPU, but when trying to create a tensor on the GPU (following this tutorial), I get numerous errors, and the output is NaN. I also haven’t been able to find documentation on this. Could you perhaps provide an explanation and/or example code? Thanks!
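For reference, a minimal sketch of the GPU IO-binding flow, assuming the newer (1.17+) onnxruntime-web API surface (ort.Tensor.fromGpuBuffer, the preferredOutputLocation session option, ort.env.webgpu.device) and the DETR output name logits — treat it as illustrative under those assumptions, not as the confirmed 1.16.3 behavior:

// Sketch only: assumes onnxruntime-web 1.17+ IO binding
// (ort.Tensor.fromGpuBuffer, preferredOutputLocation); not available in 1.16.3.
const session = await ort.InferenceSession.create(buffer, {
    executionProviders: ['webgpu'],
    preferredOutputLocation: 'gpu-buffer', // keep outputs on the GPU after run()
});

// ort.env.webgpu.device is populated once a WebGPU session exists
const device = ort.env.webgpu.device;
const gpuBuffer = device.createBuffer({
    size: Float32Array.BYTES_PER_ELEMENT * 1 * 3 * 800 * 800,
    usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST,
});
// ... write the preprocessed image into gpuBuffer on the GPU ...

// Wrap the WebGPU buffer in a tensor without copying through the CPU
const pixel_values = ort.Tensor.fromGpuBuffer(gpuBuffer, {
    dataType: 'float32',
    dims: [1, 3, 800, 800],
});

// Reuse the pixel_mask tensor from the repro above
const output = await session.run({ pixel_values, pixel_mask: inputs.pixel_mask });

// Outputs stay on the GPU; download explicitly only when needed
// ('logits' is the assumed DETR output name):
const logits = await output.logits.getData();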
going to test the model with fp16 as well
latest dev builds should have those improvements on the 1st run. We are adding support for CumSum, so the right side of the cross-device copies will go away, which should help some more. Left side - looking at that one.
Currently the easiest way is to feed the model with dummy data (but correct shape) and run once to warm it up.
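A minimal warm-up sketch along those lines, reusing the session and dummy inputs from the repro above:

// Warm-up run: the first inference triggers WebGPU shader compilation,
// so exclude it from any timing measurements.
await session.run(inputs);

// Timed run: uses the pipelines cached during the warm-up
const warmStart = performance.now();
const warmOutput = await session.run(inputs);
console.log(`warm inference: ${performance.now() - warmStart} ms`, warmOutput);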