text: Error on saving keras custom layer with tensorflow_text.BertTokenizer
Trying to save a Keras custom layer that contains a tokenizer fails.
Version info:
tensorflow==2.1.0 tensorflow-text==2.1.1
Code to reproduce:
import tensorflow_text
import tensorflow as tf

class TokenizationLayer(tf.keras.layers.Layer):
    def __init__(self, vocab_path, **kwargs):
        self.vocab_path = vocab_path
        self.tokenizer = tensorflow_text.BertTokenizer(vocab_path, token_out_type=tf.int64)
        super(TokenizationLayer, self).__init__(**kwargs)

    def get_config(self):
        config = super(TokenizationLayer, self).get_config()
        config.update({
            'vocab_path': self.vocab_path,
        })
        return config

    def call(self, inputs):
        return self.tokenizer.tokenize(inputs).to_tensor()

vocab_path = r"/home/resources/bert_en_uncased_L-12_H-768_A-12/1/assets/vocab.txt"
# tensorflow_text.BertTokenizer(vocab_lookup_table = vocab_path, token_out_type=tf.int64)
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
tokenization_layer = TokenizationLayer(vocab_path)
outputs = tokenization_layer(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.save("./test")
The same error occurs when call returns the tokenizer output directly, without to_tensor():
    def call(self, inputs):
        return self.tokenizer.tokenize(inputs)
Error:
AssertionError Traceback (most recent call last)
<ipython-input-55-e49dd5ac9a41> in <module>
----> 1 model.save("./test")
~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
1006 """
1007 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 1008 signatures, options)
1009
1010 def save_weights(self, filepath, overwrite=True, save_format=None):
~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
113 else:
114 saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 115 signatures, options)
116
117
~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
76 # we use the default replica context here.
77 with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
---> 78 save_lib.save(model, filepath, signatures, options)
79
80 if not include_optimizer:
~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
907 object_saver = util.TrackableSaver(checkpoint_graph_view)
908 asset_info, exported_graph = _fill_meta_graph_def(
--> 909 meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
910 saved_model.saved_model_schema_version = (
911 constants.SAVED_MODEL_SCHEMA_VERSION)
~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, namespace_whitelist)
585
586 with exported_graph.as_default():
--> 587 signatures = _generate_signatures(signature_functions, resource_map)
588 for concrete_function in saveable_view.concrete_functions:
589 concrete_function.add_to_graph()
~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _generate_signatures(signature_functions, resource_map)
456 argument_inputs, signature_key, function.name))
457 outputs = _call_function_with_mapped_captures(
--> 458 function, mapped_inputs, resource_map)
459 signatures[signature_key] = signature_def_utils.build_signature_def(
460 _tensor_dict_to_tensorinfo(exterior_argument_placeholders),
~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _call_function_with_mapped_captures(function, args, resource_map)
408 """Calls `function` in the exported graph, using mapped resource captures."""
409 export_captures = _map_captures_to_created_tensors(
--> 410 function.graph.captures, resource_map)
411 # Calls the function quite directly, since we have new captured resource
412 # tensors we need to feed in which weren't part of the original function
~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _map_captures_to_created_tensors(original_captures, resource_map)
330 "be tracked by assigning them to an attribute of a tracked object "
331 "or assigned to an attribute of the main object directly.")
--> 332 .format(interior))
333 export_captures.append(mapped_resource)
334 return export_captures
AssertionError: Tried to export a function which references untracked object Tensor("StatefulPartitionedCall/args_1:0", shape=(), dtype=resource).TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
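The assertion states the rule being violated: any resource captured by a saved function has to be reachable as an attribute of the exported object. Below is a minimal sketch of that rule, using a plain tf.lookup.StaticHashTable rather than the tokenizer itself. The VocabLookup class and its four-token vocabulary are hypothetical, chosen only for illustration, and this is not a claim about how tensorflow_text builds its table internally.

```python
import tempfile

import tensorflow as tf


class VocabLookup(tf.Module):
    """Hypothetical example: maps string tokens to integer ids."""

    def __init__(self, tokens):
        super().__init__()
        keys = tf.constant(tokens)
        values = tf.range(len(tokens), dtype=tf.int64)
        # Assigning the table to an attribute of the module makes its
        # resource tracked, so a saved function may capture it.
        self.table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=-1)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def lookup(self, words):
        return self.table.lookup(words)


module = VocabLookup(["[PAD]", "[UNK]", "hello", "world"])
export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

# Reload and call the exported function; unknown tokens map to -1.
restored = tf.saved_model.load(export_dir)
ids = restored.lookup(tf.constant(["hello", "missing"])).numpy().tolist()
```

Because the table lives on self.table, the exporter can map the captured resource handle back to a tracked object instead of raising the assertion above.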
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Reactions: 4
- Comments: 19 (7 by maintainers)
Commits related to this issue
- Have Tokenizers extend tf.Module This fixes #224 where tokenizers were unable to be saved from within custom Keras layers. PiperOrigin-RevId: 315784812 — committed to tensorflow/text by broken 4 years ago
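Extending tf.Module matters because tracking is transitive: a module assigned as an attribute of a tracked object brings its own attributes along with it. The sketch below illustrates this with a hypothetical ToyTokenizer and Pipeline. Neither class is part of tensorflow_text, the real tokenizers are more involved, and the example assumes only core TensorFlow APIs.

```python
import tempfile

import tensorflow as tf


class ToyTokenizer(tf.Module):
    """Hypothetical stand-in for a tokenizer that owns a lookup table."""

    def __init__(self, vocab):
        super().__init__()
        init = tf.lookup.KeyValueTensorInitializer(
            tf.constant(vocab), tf.range(len(vocab), dtype=tf.int64))
        self.table = tf.lookup.StaticHashTable(init, default_value=-1)

    def tokenize(self, words):
        # Map each already-split word to its vocabulary id.
        return self.table.lookup(words)


class Pipeline(tf.Module):
    """Hypothetical outer object, playing the role of the custom layer."""

    def __init__(self, vocab):
        super().__init__()
        # ToyTokenizer is a tf.Module, so this assignment tracks it and,
        # through it, the table it owns.
        self.tokenizer = ToyTokenizer(vocab)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def run(self, words):
        return self.tokenizer.tokenize(words)


pipeline = Pipeline(["[PAD]", "[UNK]", "hello", "world"])
save_dir = tempfile.mkdtemp()
tf.saved_model.save(pipeline, save_dir)

reloaded = tf.saved_model.load(save_dir)
restored_ids = reloaded.run(tf.constant(["world", "unseen"])).numpy().tolist()
```

Here pipeline.submodules contains the ToyTokenizer instance, which is how the nested table becomes reachable from the object being saved.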
Hey all,
Unfortunately we had to push this out to ensure compatibility with DistributionStrategy. I’m working on it now and will have a fix in the nightly as soon as possible.
@Mistobaan It's a bit late now and I'm not sure whether this is still helpful, but you can hack your way around it like this:
colab
2.3.0 is now released, so I’m going to close this bug. Feel free to reopen if a problem arises.
Hey - we are currently working on a BertTokenizer Keras layer, as well as a Wordpiece Keras layer. We expect these to be part of the TF.Text 2.3 release.
FYI: PR https://github.com/tensorflow/text/pull/328 will resolve this issue.
@markomernick Any updates regarding this?
I'm having the same issue, but with tensorflow_text.SentencepieceTokenizer. Is the fix for all types of tokenizers or only for BertTokenizer?