text: Error on saving keras custom layer with tensorflow_text.BertTokenizer

Trying to save a Keras custom layer with a tokenizer in it fails. Version info:

tensorflow==2.1.0 tensorflow-text==2.1.1

Code to reproduce:


import tensorflow_text
import tensorflow as tf


class TokenizationLayer(tf.keras.layers.Layer):
    def __init__(self, vocab_path, **kwargs):
        self.vocab_path = vocab_path
        self.tokenizer = tensorflow_text.BertTokenizer(vocab_path, token_out_type=tf.int64)
        super(TokenizationLayer, self).__init__(**kwargs)
        
    def get_config(self):
        config = super(TokenizationLayer, self).get_config()
        config.update({
            'vocab_path': self.vocab_path,
        })
        return config

    def call(self, inputs):
        return self.tokenizer.tokenize(inputs).to_tensor()


vocab_path = r"/home/resources/bert_en_uncased_L-12_H-768_A-12/1/assets/vocab.txt"
# tensorflow_text.BertTokenizer(vocab_lookup_table = vocab_path, token_out_type=tf.int64)
inputs = tf.keras.layers.Input(shape=(), dtype=tf.string)
tokenization_layer = TokenizationLayer(vocab_path)
outputs = tokenization_layer(inputs)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)

model.save("./test")

It also gives an error with:

    def call(self, inputs):
        return self.tokenizer.tokenize(inputs)

Error:

AssertionError                            Traceback (most recent call last)
<ipython-input-55-e49dd5ac9a41> in <module>
----> 1 model.save("./test")

~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
   1006     """
   1007     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 1008                     signatures, options)
   1009 
   1010   def save_weights(self, filepath, overwrite=True, save_format=None):

~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,
--> 115                           signatures, options)
    116 
    117 

~/.local/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options)
     76     # we use the default replica context here.
     77     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
---> 78       save_lib.save(model, filepath, signatures, options)
     79 
     80   if not include_optimizer:

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in save(obj, export_dir, signatures, options)
    907   object_saver = util.TrackableSaver(checkpoint_graph_view)
    908   asset_info, exported_graph = _fill_meta_graph_def(
--> 909       meta_graph_def, saveable_view, signatures, options.namespace_whitelist)
    910   saved_model.saved_model_schema_version = (
    911       constants.SAVED_MODEL_SCHEMA_VERSION)

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, namespace_whitelist)
    585 
    586   with exported_graph.as_default():
--> 587     signatures = _generate_signatures(signature_functions, resource_map)
    588     for concrete_function in saveable_view.concrete_functions:
    589       concrete_function.add_to_graph()

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _generate_signatures(signature_functions, resource_map)
    456             argument_inputs, signature_key, function.name))
    457     outputs = _call_function_with_mapped_captures(
--> 458         function, mapped_inputs, resource_map)
    459     signatures[signature_key] = signature_def_utils.build_signature_def(
    460         _tensor_dict_to_tensorinfo(exterior_argument_placeholders),

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _call_function_with_mapped_captures(function, args, resource_map)
    408   """Calls `function` in the exported graph, using mapped resource captures."""
    409   export_captures = _map_captures_to_created_tensors(
--> 410       function.graph.captures, resource_map)
    411   # Calls the function quite directly, since we have new captured resource
    412   # tensors we need to feed in which weren't part of the original function

~/.local/lib/python3.6/site-packages/tensorflow_core/python/saved_model/save.py in _map_captures_to_created_tensors(original_captures, resource_map)
    330            "be tracked by assigning them to an attribute of a tracked object "
    331            "or assigned to an attribute of the main object directly.")
--> 332           .format(interior))
    333     export_captures.append(mapped_resource)
    334   return export_captures

AssertionError: Tried to export a function which references untracked object Tensor("StatefulPartitionedCall/args_1:0", shape=(), dtype=resource). TensorFlow objects (e.g. tf.Variable) captured by functions must be tracked by assigning them to an attribute of a tracked object or assigned to an attribute of the main object directly.
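What the traceback is saying: BertTokenizer builds an internal lookup-table resource from vocab_path, and on these versions that resource is not tracked by the surrounding Keras layer, so the SavedModel exporter cannot map the captured resource tensor. One way around this (the same idea as the tf.Module hack in the comments below) is to build the table yourself and keep it as an attribute of the layer so it gets tracked. A minimal sketch, not verified on these exact versions:

import tensorflow as tf
import tensorflow_text


class TokenizationLayer(tf.keras.layers.Layer):
    def __init__(self, vocab_path, **kwargs):
        super(TokenizationLayer, self).__init__(**kwargs)
        self.vocab_path = vocab_path
        # Build the lookup table explicitly and keep it as an attribute, so
        # the resource it owns is tracked by the layer instead of hidden
        # inside BertTokenizer.
        self.vocab_table = tf.lookup.StaticVocabularyTable(
            tf.lookup.TextFileInitializer(
                vocab_path,
                key_dtype=tf.string,
                key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
                value_dtype=tf.int64,
                value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
            num_oov_buckets=1)
        self.tokenizer = tensorflow_text.BertTokenizer(
            self.vocab_table, token_out_type=tf.int64)

    def call(self, inputs):
        return self.tokenizer.tokenize(inputs).to_tensor()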

Most upvoted comments

Hey all,

Unfortunately we had to push this out to ensure compatibility with DistributionStrategy. I’m working on it now and will have a fix in the nightly as soon as possible.

@Mistobaan It’s a bit late now and I’m not sure if this is still helpful, but you can hack your way around it like this:

colab

import tensorflow as tf
import tensorflow_text as tf_text

BERT_MAX_SEQ_LEN = 512
class BertTokenizer(tf.Module):
    def __init__(self, vocab_file_path, sequence_length=BERT_MAX_SEQ_LEN, lower_case=True):
        self.CLS_ID = tf.constant(101, dtype=tf.int64)
        self.SEP_ID = tf.constant(102, dtype=tf.int64)
        self.PAD_ID = tf.constant(0, dtype=tf.int64)

        self.sequence_length = tf.constant(sequence_length)

        vocab = self.load_vocab(vocab_file_path)

        # These two lines are basically what makes it work:
        # assigning the vocab table to a tf.Module and then later assigning
        # the instantiated Module to e.g. a Keras Model
        self.create_vocab_table(vocab)
        self.bert_tokenizer = tf_text.BertTokenizer(
            vocab_lookup_table=self.vocab_table,
            token_out_type=tf.int64,
            lower_case=lower_case,
        )

    def load_vocab(self, vocab_file):
        """Loads a vocabulary file into a list."""
        vocab = []
        with tf.io.gfile.GFile(vocab_file, "r") as reader:
            while True:
                token = reader.readline()
                if not token:
                    break
                token = token.strip()
                vocab.append(token)
        return vocab

    def create_vocab_table(self, vocab, num_oov=1):
        vocab_values = tf.range(tf.size(vocab, out_type=tf.int64), dtype=tf.int64)
        self.init = tf.lookup.KeyValueTensorInitializer(
            keys=vocab, values=vocab_values, key_dtype=tf.string, value_dtype=tf.int64
        )
        self.vocab_table = tf.lookup.StaticVocabularyTable(
            self.init, num_oov, lookup_key_dtype=tf.string
        )

    @tf.function
    def __call__(self, text: tf.Tensor) -> tf.Tensor:
        """
        Perform the BERT preprocessing from text -> input token ids
        """
        # Convert text into token ids
        tokens = self.bert_tokenizer.tokenize(text)

        # Flatten the ragged tensors
        tokens = tokens.merge_dims(1, 2)

        # Add start and end token ids to the id sequence
        start_tokens = tf.fill([tf.shape(text)[0], 1], self.CLS_ID)
        end_tokens = tf.fill([tf.shape(text)[0], 1], self.SEP_ID)
        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)

        # Truncate to sequence length
        tokens = tokens[:, : self.sequence_length]

        # Convert ragged tensor to tensor and pad with PAD_ID
        tokens = tokens.to_tensor(default_value=self.PAD_ID)

        # Pad to sequence length
        pad = self.sequence_length - tf.shape(tokens)[1]
        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=self.PAD_ID)

        return tf.reshape(tokens, [-1, self.sequence_length])


# Dummy model to show that serialization works
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,), dtype=tf.float32),
    tf.keras.layers.Dense(1)
])

model.tokenizer = BertTokenizer(vocab_file_path='./test_file')
model.save('./saved_model', signatures=model.tokenizer.__call__.get_concrete_function(tf.TensorSpec(None, tf.string)))
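For completeness, here is roughly how the exported signature can be invoked after loading (a sketch; the paths are the ones from the snippet above, and the keyword name follows the __call__ parameter):

import tensorflow as tf

loaded = tf.saved_model.load('./saved_model')
# The concrete function passed to `signatures=` becomes 'serving_default'.
tokenize = loaded.signatures['serving_default']
# Signature functions are called with keyword args named after the parameters.
out = tokenize(text=tf.constant(['hello world']))
print(out)  # dict of output tensors with shape (1, BERT_MAX_SEQ_LEN)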

2.3.0 is now released, so I’m going to close this bug. Feel free to reopen if a problem arises.
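For anyone landing here later: with tensorflow>=2.3.0 and tensorflow-text>=2.3.0, the original TokenizationLayer at the top of this issue should save without the workaround, since the tokenizer's table resource is now tracked. A quick round-trip check (a sketch of the expected behavior, not verified here):

import tensorflow as tf

# Assumes the model built from TokenizationLayer in the original repro above.
model.save("./test")

# Reload and run end to end to confirm the vocab table was exported.
reloaded = tf.keras.models.load_model("./test")
print(reloaded(tf.constant(["hello world"])))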

Hey - we are currently working on a BertTokenizer Keras layer, as well as a Wordpiece Keras layer. We expect these to be part of the TF.Text 2.3 release.

FYI: PR https://github.com/tensorflow/text/pull/328 will resolve this issue.

@markomernick Any updates regarding this?

I’m having the same issue, but with tensorflow_text.SentencepieceTokenizer. Is the fix for all types of tokenizers, or only BERT?