tensorflow: Cannot convert lstm with "stateful=True" to tflite

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Windows 7
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (or github SHA if from source): 2.2.0-rc0

Command used to run the converter or code if you’re using the Python API. If possible, please share a link to Colab/Jupyter/any notebook.

import tensorflow as tf
model= tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(None, 32), batch_size=1,name='input'))
model.add(tf.keras.layers.LSTM(256, return_sequences=True, stateful=True))
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='binary_crossentropy', metrics=['acc'])
print(model.input)
print(model.summary())
print(tf.__version__)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,tf.lite.OpsSet.SELECT_TF_OPS]
converter.experimental_new_converter = True
tflite_model = converter.convert()

Copy and paste the output here.

Tensor("input_12:0", shape=(1, None, 32), dtype=float32)
Model: "sequential_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_15 (LSTM)               (1, None, 256)            295936    
=================================================================
Total params: 295,936
Trainable params: 295,936
Non-trainable params: 0
_________________________________________________________________
None
2.2.0-rc0

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
c:\program files\python37\lib\site-packages\tensorflow\python\framework\importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
    496         results = c_api.TF_GraphImportGraphDefWithResults(
--> 497             graph._c_graph, serialized, options)  # pylint: disable=protected-access
    498         results = c_api_util.ScopedTFImportGraphDefResults(results)

InvalidArgumentError: Input 0 of node sequential_13/lstm_16/AssignVariableOp was passed float from sequential_13/lstm_16/23648:0 incompatible with expected resource.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-26-c3e3cbbc4fe1> in <module>
     10 converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,                                    tf.lite.OpsSet.SELECT_TF_OPS]
     11 converter.experimental_new_converter = True
---> 12 tflite_model = converter.convert()

c:\program files\python37\lib\site-packages\tensorflow\lite\python\lite.py in convert(self)
    462     frozen_func, graph_def = (
    463         _convert_to_constants.convert_variables_to_constants_v2_as_graph(
--> 464             self._funcs[0], lower_control_flow=False))
    465     input_tensors = [
    466         tensor for tensor in frozen_func.inputs

c:\program files\python37\lib\site-packages\tensorflow\python\framework\convert_to_constants.py in convert_variables_to_constants_v2_as_graph(func, lower_control_flow, aggressive_inlining)
    705   graph_def, converted_inputs = _convert_variables_to_constants_v2_impl(
    706       func, lower_control_flow, aggressive_inlining)
--> 707   frozen_func = _construct_concrete_function(func, graph_def, converted_inputs)
    708   return frozen_func, graph_def

c:\program files\python37\lib\site-packages\tensorflow\python\framework\convert_to_constants.py in _construct_concrete_function(func, output_graph_def, converted_input_indices)
    404   new_func = wrap_function.function_from_graph_def(output_graph_def,
    405                                                    new_input_names,
--> 406                                                    new_output_names)
    407 
    408   # Manually propagate shape for input tensors where the shape is not correctly

c:\program files\python37\lib\site-packages\tensorflow\python\eager\wrap_function.py in function_from_graph_def(graph_def, inputs, outputs)
    631     importer.import_graph_def(graph_def, name="")
    632 
--> 633   wrapped_import = wrap_function(_imports_graph_def, [])
    634   import_graph = wrapped_import.graph
    635   return wrapped_import.prune(

c:\program files\python37\lib\site-packages\tensorflow\python\eager\wrap_function.py in wrap_function(fn, signature, name)
    609           signature=signature,
    610           add_control_dependencies=False,
--> 611           collections={}),
    612       variable_holder=holder,
    613       signature=signature)

c:\program files\python37\lib\site-packages\tensorflow\python\framework\func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

c:\program files\python37\lib\site-packages\tensorflow\python\eager\wrap_function.py in __call__(self, *args, **kwargs)
     84 
     85   def __call__(self, *args, **kwargs):
---> 86     return self.call_with_variable_creator_scope(self._fn)(*args, **kwargs)
     87 
     88   def call_with_variable_creator_scope(self, fn):

c:\program files\python37\lib\site-packages\tensorflow\python\eager\wrap_function.py in wrapped(*args, **kwargs)
     90     def wrapped(*args, **kwargs):
     91       with variable_scope.variable_creator_scope(self.variable_creator_scope):
---> 92         return fn(*args, **kwargs)
     93 
     94     return wrapped

c:\program files\python37\lib\site-packages\tensorflow\python\eager\wrap_function.py in _imports_graph_def()
    629 
    630   def _imports_graph_def():
--> 631     importer.import_graph_def(graph_def, name="")
    632 
    633   wrapped_import = wrap_function(_imports_graph_def, [])

c:\program files\python37\lib\site-packages\tensorflow\python\util\deprecation.py in new_func(*args, **kwargs)
    505                 'in a future version' if date is None else ('after %s' % date),
    506                 instructions)
--> 507       return func(*args, **kwargs)
    508 
    509     doc = _add_deprecated_arg_notice_to_docstring(

c:\program files\python37\lib\site-packages\tensorflow\python\framework\importer.py in import_graph_def(***failed resolving arguments***)
    403       return_elements=return_elements,
    404       name=name,
--> 405       producer_op_list=producer_op_list)
    406 
    407 

c:\program files\python37\lib\site-packages\tensorflow\python\framework\importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
    499       except errors.InvalidArgumentError as e:
    500         # Convert to ValueError for backwards compatibility.
--> 501         raise ValueError(str(e))
    502 
    503     # Create _DefinedFunctions for any imported functions.

ValueError: Input 0 of node sequential_13/lstm_16/AssignVariableOp was passed float from sequential_13/lstm_16/23648:0 incompatible with expected resource.


Also, please include a link to the saved model or GraphDef

# Put link here or attach to the issue.

Failure details

When I convert an LSTM model with “stateful=False”, the conversion succeeds. When I change stateful to True, the conversion fails with the error above. I need a stateful LSTM. How can I convert it?

About this issue

  • Original URL
  • State: closed
  • Created 4 years ago
  • Reactions: 4
  • Comments: 29 (16 by maintainers)

Most upvoted comments

@zrct0 Thanks for letting us know. We are looking into fixing the stateful conversion as well. In the meantime there is a workaround: you can still use a stateless Keras RNN layer and build the stateful behavior on top of it by managing the state in your Python program (not as part of TF). This should still be equivalent to a stateful Keras layer. As an aside, we are also looking into the accuracy differences between stateful and stateless.

@ashwinmurthy I’m also getting this bug. Can you please give a small example for this workaround? Thx

This is what I managed to come up with. I haven’t set up training yet, but I imagine it will be painful to keep track of all the states. Looking forward to an update on this.

import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import LSTM, Dense, Flatten, TimeDistributed
from tensorflow.keras.optimizers import Adam

# Stateless LSTM: the hidden and cell states are explicit model inputs and outputs.
obs_input = Input(batch_shape=(1, 1, 3, 3))
hidden_input = Input(batch_shape=(1, 8))
cell_input = Input(batch_shape=(1, 8))
x = TimeDistributed(Flatten())(obs_input)
x, new_hidden_state, new_cell_state = LSTM(units=8, return_state=True)(x, initial_state=[hidden_input, cell_input])
output = Dense(4)(x)
model = Model(inputs=[obs_input, hidden_input, cell_input], outputs=[output, new_hidden_state, new_cell_state])
model.compile(loss='mse', optimizer=Adam(learning_rate=0.01))

# The caller owns the state: pass it in, then feed the returned state into the next call.
ob = np.random.random((1, 1, 3, 3))
hidden_state = np.random.random((1, 8))
cell_state = np.random.random((1, 8))
output, new_hidden_state, new_cell_state = model.predict([ob, hidden_state, cell_state])
print(model.predict([ob, new_hidden_state, new_cell_state]))

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
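
The claim behind the workaround, that a stateless layer plus state carried by the caller behaves like a stateful layer, can be checked without TensorFlow. The following is only a NumPy sketch of a generic LSTM cell, not Keras' implementation: the helper names (`lstm_step`, `run_lstm`), the gate ordering, and the random weights are all assumptions made for illustration. It compares one pass over a whole sequence with several shorter calls that hand the final state of each call to the next one.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x_t, h, c, W, U, b):
    """One LSTM time step (hypothetical helper; gate order i, f, g, o is an arbitrary choice)."""
    z = x_t @ W + h @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new


def run_lstm(x, h, c, W, U, b):
    """Stateless call: x has shape (batch, time, features); returns outputs and the final state."""
    outputs = []
    for t in range(x.shape[1]):
        h, c = lstm_step(x[:, t, :], h, c, W, U, b)
        outputs.append(h)
    return np.stack(outputs, axis=1), h, c


rng = np.random.default_rng(0)
features, units, steps = 3, 8, 6
W = rng.normal(size=(features, 4 * units))  # input weights (random, illustration only)
U = rng.normal(size=(units, 4 * units))     # recurrent weights
b = np.zeros(4 * units)
x = rng.normal(size=(1, steps, features))
zeros = np.zeros((1, units))

# Reference: the whole sequence in a single call.
full_out, _, _ = run_lstm(x, zeros, zeros, W, U, b)

# Workaround pattern: the caller keeps (h, c) and feeds it into the next call.
h, c = zeros, zeros
carried = []
for start in range(0, steps, 2):
    out, h, c = run_lstm(x[:, start:start + 2, :], h, c, W, U, b)
    carried.append(out)
carried_out = np.concatenate(carried, axis=1)

# For contrast: the same chunks, but the state is reset to zeros on every call.
reset_out = np.concatenate(
    [run_lstm(x[:, s:s + 2, :], zeros, zeros, W, U, b)[0] for s in range(0, steps, 2)],
    axis=1,
)

print("carried state matches full sequence:", np.allclose(full_out, carried_out))
print("reset state matches full sequence:", np.allclose(full_out, reset_out))
```

In this sketch the only difference between the two chunked runs is whether the returned `(h, c)` pair is passed back in, which is the bookkeeping the workaround moves from the layer into the calling program.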

@zrct0 How large was the accuracy drop when you moved to stateless?

The decrease in accuracy shows up in actual use, not on the test set, so I can’t quantify it.