tensorflow: Cannot load tflite model
TF Information
tf-nightly 1.15.0.dev20190730
Model Information
I use BERT's run_classifier.py
to export the classifier model:
name_to_features = {
    "input_ids": tf.VarLenFeature(tf.int64),
    "input_mask": tf.VarLenFeature(tf.int64),
    "segment_ids": tf.VarLenFeature(tf.int64),
    "label_ids": tf.FixedLenFeature([], tf.int64),
}
serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(name_to_features)
estimator.export_savedmodel(FLAGS.output_dir, serving_input_receiver_fn)
Then the estimator exports the SavedModel.
Code to generate tflite model
import sys

import tensorflow as tf

saved_model_dir = sys.argv[1]
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
tflite_quant_model = converter.convert()
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_quant_model)
Code to load tflite model
import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensor details.
input_details = interpreter.get_input_details()
print(input_details)
output_details = interpreter.get_output_details()
print(output_details)

# Test the model on random input data.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

# `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` to get a pointer to the tensor instead.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
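As a side note, a BERT classifier exported this way takes int64 feature tensors, so random float32 data would not match the model's expected dtype even once loading succeeds. A minimal sketch of building a dtype-correct dummy input from an entry of `get_input_details()` (the `detail` dict below is hypothetical, shaped like a typical BERT `input_ids` tensor):

```python
import numpy as np

def make_dummy_input(detail):
    """Build a zero-filled array matching one entry of get_input_details()."""
    return np.zeros(detail['shape'], dtype=detail['dtype'])

# Hypothetical detail entry, mirroring what get_input_details() returns:
detail = {'index': 0, 'shape': np.array([1, 128]), 'dtype': np.int64}
dummy = make_dummy_input(detail)
print(dummy.shape, dummy.dtype)  # (1, 128) int64
```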
When I load the tflite model, I get this error message:
$python test_quant.py
WARNING: Logging before flag parsing goes to stderr.
W0801 15:01:28.343195 140634983794496 __init__.py:328] Limited tf.compat.v2.summary API due to missing TensorBoard installation.
INFO: Initialized TensorFlow Lite runtime.
Traceback (most recent call last):
File "test_quant.py", line 5, in <module>
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
File "/home/guohuawu.wgh/miniconda3/envs/quant/lib/python3.6/site-packages/tensorflow_core/lite/python/interpreter.py", line 206, in __init__
model_path))
ValueError: Input array not provided for operation 'reshape'.
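The ValueError is raised while the interpreter parses the model, not because the file is unreadable, but a quick sanity check can at least rule out a truncated or mis-written file: TFLite flatbuffers carry the file identifier "TFL3" at byte offset 4. A minimal sketch:

```python
def looks_like_tflite(path):
    """Check the flatbuffers file identifier that TFLite models carry."""
    with open(path, "rb") as f:
        header = f.read(8)
    # Bytes 0-3 are the flatbuffers root offset; bytes 4-8 hold b"TFL3".
    return len(header) >= 8 and header[4:8] == b"TFL3"
```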
About this issue
- Original URL
- State: closed
- Created 5 years ago
- Comments: 20 (5 by maintainers)
@wugh Hello, have you solved the problem? I'm hitting the same issue.
@zldrobit - Thanks; your instructions worked to fix this issue for me.
@suharshs I can readily reproduce this issue on TF2. Feel free to contact me if you need any info. LDAP is jbetker.
Hi all, I tried debugging the TensorFlow source code to find the cause. It seems that while quantizing the model, toco (tflite_convert) omits the new_shape attribute from the flatbuffers ReshapeOptions. The only difference in ReshapeOptions between the unquantized and the quantized model is whether the new_shape attribute is present; interestingly, the unquantized model merely carries an empty list for new_shape. By inserting an empty list for new_shape, I can successfully run my quantized model. I also made a script fix_reshape.py as a temporary workaround; please refer to the Quantization section in onnx_tflite_yolov3.
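To illustrate why the missing attribute trips the interpreter: the Reshape op can learn its target shape either from the new_shape option or from a second "shape" input tensor, and when neither is available the load fails with the error above. A rough numpy mirror of that resolution order (purely illustrative, not the actual TFLite kernel):

```python
import numpy as np

def resolve_reshape(tensor, new_shape_option=None, shape_input=None):
    """Illustrative mirror of how a Reshape op could pick its target shape."""
    if new_shape_option:                     # option present and non-empty
        return tensor.reshape(new_shape_option)
    if shape_input is not None:              # fall back to the shape input tensor
        return tensor.reshape(tuple(shape_input))
    raise ValueError("Input array not provided for operation 'reshape'.")

x = np.arange(6)
print(resolve_reshape(x, new_shape_option=[2, 3]).shape)       # (2, 3)
print(resolve_reshape(x, shape_input=np.array([3, 2])).shape)  # (3, 2)
```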
I have the same problem. Does TF 2.0 support the reshape operation in tflite?