tensorflow: TFLite model fails to provide inference - Output tensor at index 0 is expected to have 3 dimensions, found 2

1. System information

2. Code

Trained a model using TFLite Model Maker. The Colab notebook with output (but without the input dataset) is attached below: https://colab.research.google.com/drive/1RVa019JFQmODFQNPrfmpOTaSi_GMwtlE?usp=sharing

Converted to TFLite using Model Maker's model.export().
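
For reference, the export call was essentially the following (a minimal sketch; the export directory is illustrative, and the default output filename model.tflite is an assumption):

# Export the trained Model Maker model to a TFLite flatbuffer
# (export_dir is illustrative; the output filename is assumed to default to 'model.tflite')
model.export(export_dir='.')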

3. Failure after conversion

The model evaluates successfully with Model Maker's model.evaluate_tflite, but fails to run inference on Colab as well as with the TF object detection Android demo app: https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android
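
A minimal sketch of the evaluation step that succeeds (test_data here stands for the test split loaded in the notebook):

# Evaluate the exported TFLite model directly; this step completes without errors
# ('model.tflite' and test_data are the artifacts from the notebook above)
model.evaluate_tflite('model.tflite', test_data)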

Python inference on colab:

#@title Load the trained TFLite model and define some visualization functions

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image

model_path = 'model.tflite'

# Load the labels into a list
classes = ['???'] * model.model_spec.config.num_classes
label_map = model.model_spec.config.label_map
for label_id, label_name in label_map.as_dict().items():
  classes[label_id - 1] = label_name

# Define a list of colors for visualization
COLORS = np.random.randint(0, 255, size=(len(classes), 3), dtype=np.uint8)

def preprocess_image(image_path, input_size):
  """Preprocess the input image to feed to the TFLite model"""
  img = tf.io.read_file(image_path)
  img = tf.io.decode_image(img, channels=3)
  img = tf.image.convert_image_dtype(img, tf.uint8)
  original_image = img
  resized_img = tf.image.resize(img, input_size)
  resized_img = resized_img[tf.newaxis, :]
  return resized_img, original_image


def set_input_tensor(interpreter, image):
  """Set the input tensor."""
  tensor_index = interpreter.get_input_details()[0]['index']
  input_tensor = interpreter.tensor(tensor_index)()[0]
  input_tensor[:, :] = image


def get_output_tensor(interpreter, index):
  """Retur the output tensor at the given index."""
  output_details = interpreter.get_output_details()[index]
  tensor = np.squeeze(interpreter.get_tensor(output_details['index']))
  return tensor


def detect_objects(interpreter, image, threshold):
  """Returns a list of detection results, each a dictionary of object info."""
  # Feed the input image to the model
  set_input_tensor(interpreter, image)
  interpreter.invoke()

  # Get all outputs from the model
  boxes = get_output_tensor(interpreter, 0)
  classes = get_output_tensor(interpreter, 1)
  scores = get_output_tensor(interpreter, 2)
  count = int(get_output_tensor(interpreter, 3))

  results = []
  for i in range(count):
    if scores[i] >= threshold:
      result = {
        'bounding_box': boxes[i],
        'class_id': classes[i],
        'score': scores[i]
      }
      results.append(result)
  return results


def run_odt_and_draw_results(image_path, interpreter, threshold=0.5):
  """Run object detection on the input image and draw the detection results"""
  # Load the input shape required by the model
  _, input_height, input_width, _ = interpreter.get_input_details()[0]['shape']

  # Load the input image and preprocess it
  preprocessed_image, original_image = preprocess_image(
      image_path,
      (input_height, input_width)
    )

  # Run object detection on the input image
  results = detect_objects(interpreter, preprocessed_image, threshold=threshold)

  # Plot the detection results on the input image
  original_image_np = original_image.numpy().astype(np.uint8)
  for obj in results:
    # Convert the object bounding box from relative coordinates to absolute
    # coordinates based on the original image resolution
    ymin, xmin, ymax, xmax = obj['bounding_box']
    xmin = int(xmin * original_image_np.shape[1])
    xmax = int(xmax * original_image_np.shape[1])
    ymin = int(ymin * original_image_np.shape[0])
    ymax = int(ymax * original_image_np.shape[0])

    # Find the class index of the current object
    class_id = int(obj['class_id'])

    # Draw the bounding box and label on the image
    color = [int(c) for c in COLORS[class_id]]
    cv2.rectangle(original_image_np, (xmin, ymin), (xmax, ymax), color, 2)
    # Make adjustments to make the label visible for all objects
    y = ymin - 15 if ymin - 15 > 15 else ymin + 15
    label = "{}: {:.0f}%".format(classes[class_id], obj['score'] * 100)
    cv2.putText(original_image_np, label, (xmin, y),
        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

  # Return the final image
  original_uint8 = original_image_np.astype(np.uint8)
  return original_uint8

#@title Run object detection and show the detection results

INPUT_IMAGE_URL = "https://storage.googleapis.com/cloud-ml-data/img/openimage/3/2520/3916261642_0a504acd60_o.jpg" #@param {type:"string"}
DETECTION_THRESHOLD = 0.3 #@param {type:"number"}

TEMP_FILE = '/tmp/image.png'

!wget -q -O $TEMP_FILE $INPUT_IMAGE_URL
im = Image.open(TEMP_FILE)
im.thumbnail((512, 512), Image.ANTIALIAS)
im.save(TEMP_FILE, 'PNG')

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()

# Run inference and draw detection result on the local copy of the original file
detection_result_image = run_odt_and_draw_results(
    TEMP_FILE,
    interpreter,
    threshold=DETECTION_THRESHOLD
)

# Show the detection result
Image.fromarray(detection_result_image)

Error:


TypeError                                 Traceback (most recent call last)
<ipython-input-19-2bfddfa5022e> in <module>()
     19     TEMP_FILE,
     20     interpreter,
---> 21     threshold=DETECTION_THRESHOLD
     22 )
     23

1 frames
<ipython-input-18-568dde2e930c> in run_odt_and_draw_results(image_path, interpreter, threshold)
     77
     78   # Run object detection on the input image
---> 79   results = detect_objects(interpreter, preprocessed_image, threshold=threshold)
     80
     81   # Plot the detection results on the input image

<ipython-input-18-568dde2e930c> in detect_objects(interpreter, image, threshold)
     51   classes = get_output_tensor(interpreter, 1)
     52   scores = get_output_tensor(interpreter, 2)
---> 53   count = int(get_output_tensor(interpreter, 3))
     54
     55   results = []

TypeError: only size-1 arrays can be converted to Python scalars
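
The TypeError suggests that output index 3 no longer holds the scalar detection count. A quick hypothetical check (not part of the original post) that prints what each output tensor actually looks like:

# Inspect the model's outputs to see which index holds boxes, classes, scores, and count
for i, detail in enumerate(interpreter.get_output_details()):
  print(i, detail['name'], detail['shape'], detail['dtype'])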

Android studio fatal exception:

E/AndroidRuntime: FATAL EXCEPTION: main
    Process: org.tensorflow.lite.examples.detection, PID: 27414
    java.lang.AssertionError: Error occurred when initializing ObjectDetector: Output tensor at index 0 is expected to have 3 dimensions, found 2.
        at org.tensorflow.lite.task.vision.detector.ObjectDetector.initJniWithByteBuffer(Native Method)
        at org.tensorflow.lite.task.vision.detector.ObjectDetector.access$100(ObjectDetector.java:86)
        at org.tensorflow.lite.task.vision.detector.ObjectDetector$3.createHandle(ObjectDetector.java:211)
        at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
        at org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromBufferAndOptions(ObjectDetector.java:207)
        at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.<init>(TFLiteObjectDetectionAPIModel.java:87)
        at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:81)
        at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:99)
        at org.tensorflow.lite.examples.detection.CameraActivity$7.onPreviewSizeChosen(CameraActivity.java:446)
        at org.tensorflow.lite.examples.detection.CameraConnectionFragment.setUpCameraOutputs(CameraConnectionFragment.java:357)
        at org.tensorflow.lite.examples.detection.CameraConnectionFragment.openCamera(CameraConnectionFragment.java:362)
        at org.tensorflow.lite.examples.detection.CameraConnectionFragment.access$300(CameraConnectionFragment.java:66)
        at org.tensorflow.lite.examples.detection.CameraConnectionFragment$3.onSurfaceTextureAvailable(CameraConnectionFragment.java:171)
        at android.view.TextureView.getTextureLayer(TextureView.java:415)
        at android.view.TextureView.draw(TextureView.java:360)
        at android.view.View.updateDisplayListIfDirty(View.java:21389)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.updateDisplayListIfDirty(View.java:21380)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.updateDisplayListIfDirty(View.java:21380)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.draw(View.java:22538)
        at android.view.View.updateDisplayListIfDirty(View.java:21389)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at androidx.coordinatorlayout.widget.CoordinatorLayout.drawChild(CoordinatorLayout.java:1246)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.draw(View.java:22538)
        at android.view.View.updateDisplayListIfDirty(View.java:21389)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.updateDisplayListIfDirty(View.java:21380)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.updateDisplayListIfDirty(View.java:21380)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.updateDisplayListIfDirty(View.java:21380)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
        at android.view.View.updateDisplayListIfDirty(View.java:21380)
        at android.view.View.draw(View.java:22254)
        at android.view.ViewGroup.drawChild(ViewGroup.java:4541)
        at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4302)
E/AndroidRuntime:     at android.view.View.draw(View.java:22538)
        at com.android.internal.policy.DecorView.draw(DecorView.java:848)
        at android.view.View.updateDisplayListIfDirty(View.java:21389)
        at android.view.ThreadedRenderer.updateViewTreeDisplayList(ThreadedRenderer.java:559)
        at android.view.ThreadedRenderer.updateRootDisplayList(ThreadedRenderer.java:565)
        at android.view.ThreadedRenderer.draw(ThreadedRenderer.java:647)
        at android.view.ViewRootImpl.draw(ViewRootImpl.java:4417)
        at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:4144)
        at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:3391)
        at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:2182)
        at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:8730)
        at android.view.Choreographer$CallbackRecord.run(Choreographer.java:1352)
        at android.view.Choreographer.doCallbacks(Choreographer.java:1149)
        at android.view.Choreographer.doFrame(Choreographer.java:1049)
        at android.view.Choreographer$FrameHandler.handleMessage(Choreographer.java:1275)
        at android.os.Handler.dispatchMessage(Handler.java:106)
        at android.os.Looper.loop(Looper.java:233)
        at android.app.ActivityThread.main(ActivityThread.java:8010)
        at java.lang.reflect.Method.invoke(Native Method)
        at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:631)
        at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:978)

This is a new issue that has cropped up since the TF version upgrade to 2.6.0 and Model Maker to 0.3.3. I have older models trained the same way that work perfectly fine both on Colab and with the demo app. Is there something else that needs to be added while exporting the TFLite model, or is there some version conflict causing this? Does anybody have any clue about this error?

About this issue

  • Original URL
  • State: closed
  • Created 3 years ago
  • Comments: 26 (4 by maintainers)

Most upvoted comments

By adding some print statements in detect_objects and eyeballing the data, it looks rather like the order of the output tensors has changed.

I can make the thing run to completion by swapping around the indices in the middle section of detect_objects as follows:

  # Get all outputs from the model
  boxes = get_output_tensor(interpreter, 1)
  classes = get_output_tensor(interpreter, 3)
  scores = get_output_tensor(interpreter, 0)
  count = int(get_output_tensor(interpreter, 2))

~That said, I’m getting utterly useless detection results out of this from my own model, and I haven’t yet figured out why, so there may be deeper problems.~

EDIT: The useless detection results I was getting turned out to be down to a silly mistake elsewhere. So this fix seems to be working well for me.
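
To keep the remapping in one place rather than scattering magic indices, the four fetches could be wrapped behind an explicit mapping (a small refactor sketch, not from the original comment; the index values follow the swapped order above):

# Output order observed for models exported with Model Maker 0.3.3 / TF 2.6
# (older exports used boxes=0, classes=1, scores=2, count=3)
OUTPUT_ORDER = {'scores': 0, 'boxes': 1, 'count': 2, 'classes': 3}

def get_detection_outputs(interpreter, order=OUTPUT_ORDER):
  """Fetch the four detection outputs using an explicit index mapping."""
  boxes = get_output_tensor(interpreter, order['boxes'])
  classes = get_output_tensor(interpreter, order['classes'])
  scores = get_output_tensor(interpreter, order['scores'])
  count = int(get_output_tensor(interpreter, order['count']))
  return boxes, classes, scores, count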

This saved my thesis! ❤️

I’m having the same issue. I hope they can figure it out ASAP 😃

I had the same error, then I changed the TensorFlow version to 2.5.0. In a few minutes, Android was running perfectly. Good luck!

How on earth can this problem be solved?

@jawj I will try the same and provide an update. Thanks for figuring out a resolution.