tensorflow: Tf lite model fails to provide inference - Output tensor at index 0 is expected to have 3 dimensions, found 2
1. System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colab
- TensorFlow installation (pip package or built from source): 2.6.0
- TensorFlow library (version, if pip package or github SHA, if built from source): all other libraries cloned from https://github.com/tensorflow/examples/tensorflow_examples/lite/model_maker/pip_package
2. Code
Trained a model using TF Lite Model Maker. The Colab is attached below, with output but without the input dataset: https://colab.research.google.com/drive/1RVa019JFQmODFQNPrfmpOTaSi_GMwtlE?usp=sharing
Converted to TFLite using Model Maker’s model.export()
3. Failure after conversion
The model evaluates successfully with Model Maker’s model.evaluate_tflite, but it fails to run inference both on Colab and with the TF object detection Android demo app: https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android
Python inference on colab:
#@title Load the trained TFLite model and define some visualization functions
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

model_path = 'model.tflite'

# Load the labels into a list
# (`model` is the Model Maker object trained earlier in the notebook)
classes = ['???'] * model.model_spec.config.num_classes
label_map = model.model_spec.config.label_map
for label_id, label_name in label_map.as_dict().items():
    classes[label_id-1] = label_name

# Define a list of colors for visualization
COLORS = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)

def preprocess_image(image_path, input_size):
    """Preprocess the input image to feed to the TFLite model"""
    img = tf.io.read_file(image_path)
    img = tf.io.decode_image(img, channels=3)
    img = tf.image.convert_image_dtype(img, tf.uint8)
    original_image = img
    resized_img = tf.image.resize(img, input_size)
    resized_img = resized_img[tf.newaxis, :]
    return resized_img, original_image

def set_input_tensor(interpreter, image):
    """Set the input tensor."""
    tensor_index = interpreter.get_input_details()[0]['index']
    input_tensor = interpreter.tensor(tensor_index)()[0]
    input_tensor[:, :] = image

def get_output_tensor(interpreter, index):
    """Return the output tensor at the given index."""
    output_details = interpreter.get_output_details()[index]
    tensor = np.squeeze(interpreter.get_tensor(output_details['index']))
    return tensor

def detect_objects(interpreter, image, threshold):
    """Returns a list of detection results, each a dictionary of object info."""
    # Feed the input image to the model
    set_input_tensor(interpreter, image)
    interpreter.invoke()

    # Get all outputs from the model
    boxes = get_output_tensor(interpreter, 0)
    classes = get_output_tensor(interpreter, 1)
    scores = get_output_tensor(interpreter, 2)
    count = int(get_output_tensor(interpreter, 3))

    results = []
    for i in range(count):
        if scores[i] >= threshold:
            result = {
                'bounding_box': boxes[i],
                'class_id': classes[i],
                'score': scores[i]
            }
            results.append(result)
    return results

def run_odt_and_draw_results(image_path, interpreter, threshold=0.5):
    """Run object detection on the input image and draw the detection results"""
    # Load the input shape required by the model
    _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

    # Load the input image and preprocess it
    preprocessed_image, original_image = preprocess_image(
        image_path,
        (input_height, input_width)
    )

    # Run object detection on the input image
    results = detect_objects(interpreter, preprocessed_image, threshold=threshold)

    # Plot the detection results on the input image
    original_image_np = original_image.numpy().astype(np.uint8)
    for obj in results:
        # Convert the object bounding box from relative coordinates to absolute
        # coordinates based on the original image resolution
        ymin, xmin, ymax, xmax = obj['bounding_box']
        xmin = int(xmin * original_image_np.shape[1])
        xmax = int(xmax * original_image_np.shape[1])
        ymin = int(ymin * original_image_np.shape[0])
        ymax = int(ymax * original_image_np.shape[0])

        # Find the class index of the current object
        class_id = int(obj['class_id'])

        # Draw the bounding box and label on the image
        color = [int(c) for c in COLORS[class_id]]
        cv2.rectangle(original_image_np, (xmin, ymin), (xmax, ymax), color, 2)

        # Make adjustments to make the label visible for all objects
        y = ymin - 15 if ymin - 15 > 15 else ymin + 15
        label = "{}: {:.0f}%".format(classes[class_id], obj['score'] * 100)
        cv2.putText(original_image_np, label, (xmin, y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Return the final image
    original_uint8 = original_image_np.astype(np.uint8)
    return original_uint8

#@title Run object detection and show the detection results
INPUT_IMAGE_URL = "https://storage.googleapis.com/cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg" #@param {type:"string"}
DETECTION_THRESHOLD = 0.3 #@param {type:"number"}
TEMP_FILE = '/tmp/image.png'

!wget -q -O $TEMP_FILE $INPUT_IMAGE_URL
im = Image.open(TEMP_FILE)
# Image.LANCZOS is the same filter as the older Image.ANTIALIAS alias
im.thumbnail((512, 512), Image.LANCZOS)
im.save(TEMP_FILE, 'PNG')

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Run inference and draw detection result on the local copy of the original file
detection_result_image = run_odt_and_draw_results(
    TEMP_FILE,
    interpreter,
    threshold=DETECTION_THRESHOLD
)

# Show the detection result
Image.fromarray(detection_result_image)
Error:
TypeError                                 Traceback (most recent call last)
<ipython-input-19-2bfddfa5022e> in <module>()
     19     TEMP_FILE,
     20     interpreter,
---> 21     threshold=DETECTION_THRESHOLD
     22 )
     23

1 frames
<ipython-input-18-568dde2e930c> in run_odt_and_draw_results(image_path, interpreter, threshold)
     77
     78   # Run object detection on the input image
---> 79   results = detect_objects(interpreter, preprocessed_image, threshold=threshold)
     80
     81   # Plot the detection results on the input image

<ipython-input-18-568dde2e930c> in detect_objects(interpreter, image, threshold)
     51   classes = get_output_tensor(interpreter, 1)
     52   scores = get_output_tensor(interpreter, 2)
---> 53   count = int(get_output_tensor(interpreter, 3))
     54
     55   results = []

TypeError: only size-1 arrays can be converted to Python scalars
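
The TypeError comes from calling int() on an array that holds more than one element, so the tensor at output position 3 is evidently not the single-element detection count the code expects. A minimal sketch of a more defensive conversion follows. `to_count` is a hypothetical helper (it is not part of TensorFlow or of the notebook), and the 25-element array merely stands in for any non-scalar output.

```python
import numpy as np


def to_count(tensor):
    """Convert a detection-count tensor to a Python int.

    Hypothetical helper: raises a descriptive ValueError when the tensor
    holds more than one element, instead of NumPy's generic TypeError.
    """
    array = np.asarray(tensor)
    squeezed = np.squeeze(array)
    if squeezed.size != 1:
        raise ValueError(
            f"expected a single-element count tensor, got shape {array.shape}"
        )
    return int(squeezed)


# A tensor of shape (1,) squeezes to a 0-d array, which int() accepts.
print(to_count(np.array([7.0], dtype=np.float32)))

# A tensor of shape (1, 25) does not, which is what the traceback shows.
try:
    int(np.squeeze(np.zeros((1, 25), dtype=np.float32)))
except TypeError as exc:
    print("TypeError:", exc)
```

With a check like this, a reordered output list fails with a message that names the offending shape rather than a conversion error deep inside the detection loop.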
Android studio fatal exception:
E/AndroidRuntime: FATAL EXCEPTION: main
    Process: org.tensorflow.lite.examples.detection, PID: 27414
    java.lang.AssertionError: Error occurred when initializing ObjectDetector: Output tensor at index 0 is expected to have 3 dimensions, found 2.
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.initJniWithByteBuffer(Native Method)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.access$100(ObjectDetector.java:86)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector$3.createHandle(ObjectDetector.java:211)
    at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromBufferAndOptions(ObjectDetector.java:207)
    at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.<init>(TFLiteObjectDetectionAPIModel.java:87)
    at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:81)
    at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:99)
    at org.tensorflow.lite.examples.detection.CameraActivity$7.onPreviewSizeChosen(CameraActivity.java:446)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.setUpCameraOutputs(CameraConnectionFragment.java:357)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.openCamera(CameraConnectionFragment.java:362)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.access$300(CameraConnectionFragment.java:66)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment$3.onSurfaceTextureAvailable(CameraConnectionFragment.java:171)
    at android.view.TextureView.getTextureLayer(TextureView.java:415)
    at android.view.TextureView.draw(TextureView.java:360)
    at android.view.View.updateDisplayListIfDirty(View.java:21389)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.updateDisplayListIfDirty(View.java:21380)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.updateDisplayListIfDirty(View.java:21380)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.draw(View.java:22538)
    at android.view.View.updateDisplayListIfDirty(View.java:21389)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at androidx.coordinatorlayout.widget.CoordinatorLayout.drawChild(CoordinatorLayout.java:1246)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.draw(View.java:22538)
    at android.view.View.updateDisplayListIfDirty(View.java:21389)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.updateDisplayListIfDirty(View.java:21380)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.updateDisplayListIfDirty(View.java:21380)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.updateDisplayListIfDirty(View.java:21380)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
    at android.view.View.updateDisplayListIfDirty(View.java:21380)
    at android.view.View.draw(View.java:22254)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
E/AndroidRuntime:     at android.view.View.draw(View.java:22538)
    at com.android.internal.policy.DecorView.draw(DecorView.java:848)
    at android.view.View.updateDisplayListIfDirty(View.java:21389)
    at android.view.ThreadedRenderer.updateViewTreeDisplayList(ThreadedRenderer.java:559)
    at android.view.ThreadedRenderer.updateRootDisplayList(ThreadedRenderer.java:565)
    at android.view.ThreadedRenderer.draw(ThreadedRenderer.java:647)
    at android.view.ViewRootImpl.draw(ViewRootImpl.java:4417)
    at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:4144)
    at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:3391)
    at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:2182)
    at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:8730)
    at android.view.Choreographer$CallbackRecord.run(Choreographer.java:1352)
    at android.view.Choreographer.doCallbacks(Choreographer.java:1149)
    at android.view.Choreographer.doFrame(Choreographer.java:1049)
    at android.view.Choreographer$FrameHandler.handleMessage(Choreographer.java:1275)
    at android.os.Handler.dispatchMessage(Handler.java:106)
    at android.os.Looper.loop(Looper.java:233)
    at android.app.ActivityThread.main(ActivityThread.java:8010)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:631)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:978)
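
The assertion message says the Task Library's ObjectDetector wants a 3-dimensional tensor at output index 0 (bounding boxes, conventionally shaped [1, N, 4]) but found a 2-dimensional one. The sketch below checks a list of output shapes against the conventional boxes, classes, scores, count layout. `EXPECTED_RANKS` and `layout_problems` are hypothetical names, and the example shapes (with N = 25) are illustrative rather than taken from the reporter's model.

```python
# Conventional detection output layout: boxes [1, N, 4], classes [1, N],
# scores [1, N], count [1]. Only the number of dimensions is checked here.
EXPECTED_RANKS = [("boxes", 3), ("classes", 2), ("scores", 2), ("count", 1)]


def layout_problems(output_shapes):
    """Return a list of messages describing where the shapes deviate."""
    if len(output_shapes) != len(EXPECTED_RANKS):
        return [f"expected 4 output tensors, found {len(output_shapes)}"]
    problems = []
    for index, (shape, (name, rank)) in enumerate(zip(output_shapes, EXPECTED_RANKS)):
        if len(shape) != rank:
            problems.append(
                f"Output tensor at index {index} ({name}) is expected to have "
                f"{rank} dimensions, found {len(shape)}"
            )
    return problems


# Illustrative shapes only: the first list follows the conventional order,
# the second has a 2-D tensor in first position, as in the Android error.
conventional = [(1, 25, 4), (1, 25), (1, 25), (1,)]
reordered = [(1, 25), (1, 25, 4), (1,), (1, 25)]

print(layout_problems(conventional))
for message in layout_problems(reordered):
    print(message)
```

In practice the shapes could be read from the Python interpreter with `[d['shape'] for d in interpreter.get_output_details()]` and passed to the same function.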
This is a new issue that has cropped up since upgrading TensorFlow to 2.6.0 and Model Maker to 0.3.3. I have older models trained the same way that work perfectly fine both on Colab and with the demo app. Is there something else that needs to be added while exporting the TFLite model, or is there some version conflict causing this? Does anybody have any clue about this error?
About this issue
- State: closed
- Created 3 years ago
- Comments: 26 (4 by maintainers)
Commits related to this issue
- Swap indices in detect_objects, reflecting reordered output tensors Fixes issue https://github.com/tensorflow/tensorflow/issues/51591 — committed to jawj/tensorflow by jawj 3 years ago
- Update detect.py Order of output of tensors seem to have changed per https://github.com/tensorflow/tensorflow/issues/51591#issuecomment-923051299 — committed to HerbertSu/TFODRPi by HerbertSu a year ago
By adding some print statements in detect_objects and eyeballing the data, it looks rather like the order of the output tensors has changed. I can make the thing run to completion by swapping around the indices in the middle section of detect_objects as follows:
~That said, I’m getting utterly useless detection results out of this from my own model, and I haven’t yet figured out why, so there may be deeper problems.~
EDIT: The useless detection results I was getting turned out to be down to a silly mistake elsewhere. So this fix seems to be working well for me.
This saved my thesis! ❤️
I’m having the same issue. I hope they can figure it out ASAP 😃
I had the same error, then I changed the TensorFlow version to 2.5.0. Within a few minutes, the Android app was running perfectly. Good luck!
How on earth can this problem be solved?
@jawj I would try the same and provide an update. Thanks for figuring out a resolution.