onnxruntime: [Mobile] NNAPI does not work even though the Model Usability Checker gives a positive result
Describe the issue
I converted the madlad 3b model (without kv-cache, split into encoder and decoder) to ONNX using the PyTorch conversion tool (torch.onnx.export, with the axes fixed at 128) and statically quantized the decoder to int8 (using HF Optimum, leaving the Add, Softmax, Mul and Unsqueeze operators in fp32). Both the encoder (dynamically quantized) and the statically quantized decoder work perfectly with onnxruntime on CPU. I also wanted to use NNAPI for the quantized decoder, so I ran the Model Usability Checker tool on the decoder and got full compatibility with NNAPI:
I also verified that the operators mentioned in the caveats satisfy the stated conditions (they do, since the quantization is static and the input axes are fixed).
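For context, the export of the decoder was roughly the sketch below. The DecoderWrapper class, the I/O names and the dummy shapes are illustrative placeholders, not my exact script (a real wrapper also has to reproduce the model's lm_head scaling / weight tying exactly):
import torch
from transformers import T5ForConditionalGeneration

class DecoderWrapper(torch.nn.Module):
    # Illustrative wrapper: the decoder stack plus the LM head, no kv-cache
    def __init__(self, model):
        super().__init__()
        self.decoder = model.decoder
        self.lm_head = model.lm_head

    def forward(self, input_ids, encoder_hidden_states):
        hidden = self.decoder(
            input_ids=input_ids,
            encoder_hidden_states=encoder_hidden_states,
            use_cache=False,
        ).last_hidden_state
        return self.lm_head(hidden)

model = T5ForConditionalGeneration.from_pretrained("jbochi/madlad400-3b-mt")
wrapper = DecoderWrapper(model).eval()

# Fixed axes: both sequence lengths pinned to 128 and no dynamic_axes argument
dummy_ids = torch.ones(1, 128, dtype=torch.long)
dummy_enc = torch.zeros(1, 128, model.config.d_model)

torch.onnx.export(
    wrapper,
    (dummy_ids, dummy_enc),
    "onnx/Madlad/Script/Madlad_decoder_complete.onnx",
    input_names=["input_ids", "encoder_hidden_state"],
    output_names=["logits"],
    opset_version=11,
)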
However, when I run the same decoder on Android with the NNAPI option:
// Decoder session options: NNAPI EP with the NNAPI CPU fallback device disabled
OrtSession.SessionOptions decoderOptions = new OrtSession.SessionOptions();
EnumSet<NNAPIFlags> flags = EnumSet.of(NNAPIFlags.CPU_DISABLED);
decoderOptions.addNnapi(flags);
// Verbose session logging to see the NNAPI partitioning reported by GetCapability
decoderOptions.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE);
decoderOptions.setSessionLogVerbosityLevel(0);
decoderSession = onnxEnv.createSession(decoderPath, decoderOptions);
I get much worse performance than with CPU (around 700 ms per decoder run on CPU versus about 4000 ms with NNAPI). With the session log set to verbose I get the following result:
[W:onnxruntime:ort-java, nnapi_execution_provider.cc:225 GetCapability] NnapiExecutionProvider::GetCapability,
number of partitions supported by NNAPI: 322
number of nodes in the graph: 5067
number of nodes supported by NNAPI: 3126
In practice, the Model Usability Checker says that the model has only one partition supported by NNAPI and that all nodes are supported, while the session log on Android reports 322 NNAPI partitions and only 3126 supported nodes out of 5067.
I tried using only the basic graph optimizations (as the Model Usability Checker does), updating the opset version from 11 to 20 with onnxruntime.tools.update_onnx_opset, and converting the decoder to ORT format, but nothing helped: the result is always the same.
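For completeness, these are roughly the tool invocations I mean (shown via subprocess so the snippet stays in Python; the paths are placeholders and the exact arguments should be double-checked against each tool's --help):
import subprocess
import sys

decoder = "onnx/Madlad/Script/StaticQuantization/model_quantized.onnx"  # placeholder path

# Re-run the Model Usability Checker on the quantized decoder
subprocess.run(
    [sys.executable, "-m", "onnxruntime.tools.check_onnx_model_mobile_usability", decoder],
    check=True,
)

# Bump the opset from 11 to 20 (argument spelling from memory, see the tool's --help)
subprocess.run(
    [sys.executable, "-m", "onnxruntime.tools.update_onnx_opset",
     "--opset", "20", decoder, decoder.replace(".onnx", "_opset20.onnx")],
    check=True,
)

# Convert the decoder to ORT format
subprocess.run(
    [sys.executable, "-m", "onnxruntime.tools.convert_onnx_models_to_ort", decoder],
    check=True,
)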
To reproduce
Code for the static quantization:
import functools
import os

import onnxruntime
from onnxruntime.quantization import QuantType
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, QuantizationConfig
from transformers import T5Tokenizer
# default_quantization_parameters and ORT_DEFAULT_OPS_STATIC_QUANTIZATION_QOPS come from
# Optimum's ONNX Runtime quantization helpers, imported as in the original script

def static_quantization_optimum():
    model_dir = "onnx/Madlad/Script"
    model_name = 'jbochi/madlad400-3b-mt'
    quantizer = ORTQuantizer.from_pretrained(model_dir, file_name="Madlad_decoder_complete.onnx")
    # copy the default static-quantization op list so the shared default is not mutated,
    # then keep Add, Softmax, Mul and Unsqueeze in fp32
    operators_to_quantize_in = list(ORT_DEFAULT_OPS_STATIC_QUANTIZATION_QOPS)
    operators_to_quantize_in.remove('Add')
    operators_to_quantize_in.remove('Softmax')
    operators_to_quantize_in.remove('Mul')
    operators_to_quantize_in.remove('Unsqueeze')
    format, mode, operators_to_quantize = default_quantization_parameters(True, operators_to_quantize=operators_to_quantize_in)
    qconfig = QuantizationConfig(
        is_static=True,
        format=format,
        mode=mode,
        activations_dtype=QuantType.QUInt8,
        activations_symmetric=True,
        weights_dtype=QuantType.QInt8,
        weights_symmetric=True,
        per_channel=False,
        reduce_range=False,
        nodes_to_quantize=[],
        nodes_to_exclude=[],
        operators_to_quantize=operators_to_quantize,
    )
    # loading of the tokenizer
    tokenizerEn = T5Tokenizer.from_pretrained(model_name)
    # loading encoder session (to get the input "encoder_hidden_state" for the decoder)
    providers = ['CPUExecutionProvider']
    encoder_session = onnxruntime.InferenceSession("onnx/Madlad/Optimum/encoder_model.onnx", providers=providers)
    # Create the calibration dataset (preprocess_fn_Madlad: see the sketch after this function)
    calibration_samples = 120  # 38 for preprocess_fn_test
    calibration_dataset = quantizer.get_calibration_dataset(
        "opus100",
        dataset_config_name="en-it",
        preprocess_function=functools.partial(preprocess_fn_Madlad, tokenizer=tokenizerEn, encoder_session=encoder_session),
        num_samples=calibration_samples,
        dataset_split="train",
        # preprocess_batch=False
    )
    calibration_config = AutoCalibrationConfig.entropy(calibration_dataset)
    # free the RAM held by resources that are no longer needed
    del encoder_session
    # Perform the calibration step: compute the activation quantization ranges (RAM optimized)
    shards = 4
    for i in range(shards):
        shard = calibration_dataset.shard(shards, i)
        quantizer.partial_fit(
            dataset=shard,
            calibration_config=calibration_config,
            operators_to_quantize=qconfig.operators_to_quantize,
            batch_size=1,  # calibration_samples//shards
            use_external_data_format=True,
        )
    ranges = quantizer.compute_ranges()
    # remove the temporary augmented model
    os.remove("augmented_model.onnx")
    model_quantized_path = quantizer.quantize(
        save_dir="onnx/Madlad/Script/StaticQuantization/",
        calibration_tensors_range=ranges,
        quantization_config=qconfig,
        use_external_data_format=True
    )
If you need the ONNX model of the quantized decoder, let me know; I can upload it to my GitHub and put the link in the comments.
Urgency
Not so urgent
Platform
Android
OS Version
14 (api 34)
ONNX Runtime Installation
Released Package
Compiler Version (if ‘Built from Source’)
No response
Package Name (if ‘Released Package’)
onnxruntime-android
ONNX Runtime Version or Commit ID
1.17
ONNX Runtime API
Java/Kotlin
Architecture
ARM64
Execution Provider
NNAPI
Execution Provider Library Version
No response
About this issue
- Original URL
- State: closed
- Created 4 months ago
- Comments: 16 (6 by maintainers)
Yes, I was using the CPU_DISABLED flag; without it I only get one available device (I imagine the CPU). At first I thought that most processors supported NNAPI (I could not verify this because there are no lists of compatible SoCs), but now I have some doubts, so I decided to just use onnxruntime with the CPU and to use the kv-cache; that way I was able to get good performance. Thanks so much anyway for the help!