java: Error saving SavedModel Using the Exporter API
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64): Windows 10
- TensorFlow Python installation: pip install tensorflow==2.4.1
- TensorFlow Java version : 2.3.1
- Java version (i.e., the output of java -version): Java 8
- Python version (if transferring a model trained in Python): Python 3.8
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: No Cuda
- GPU model and memory: No GPU
Describe the current behavior
I keep getting the error No Operation named [StatefulPartitionedCall_2:0] in the Graph when using SavedModelBundle.exporter to save the model:
ConcreteFunction serveFunction = savedModel.function("serve_model");
SavedModelBundle.exporter(exportDir)
.withFunction(serveFunction)
.export();
When I access and inspect the graph operations, I can see StatefulPartitionedCall_2, but without the :0 suffix at the end of the operation name.
Iterator<Operation> operationIterator = serveFunction.graph().operations();
while (operationIterator.hasNext()) {
    System.out.println(operationIterator.next().name());
}
Code snippet output:
Adam/iter
Adam/iter/Read/ReadVariableOp
Adam/beta_1
Adam/beta_1/Read/ReadVariableOp
Adam/beta_2
...
...
...
train_model_labels
StatefulPartitionedCall_1
saver_filename
StatefulPartitionedCall_2
StatefulPartitionedCall_3
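Side note: the :0 suffix denotes an operation's first output tensor, while graph().operations() yields bare operation names. A minimal sketch of resolving that output explicitly through the org.tensorflow Graph API, assuming the op exists under that name:

// Graph, Operation and Output come from org.tensorflow
Graph graph = serveFunction.graph();
Operation op = graph.operation("StatefulPartitionedCall_2"); // bare op name, no ":0"
if (op != null) {
    Output<?> firstOutput = op.output(0); // the tensor referred to as "StatefulPartitionedCall_2:0"
    System.out.println(op.name() + " has " + op.numOutputs() + " output(s)");
}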
It works fine when invoking the op directly from session.runner():
String checkpointPath = "...";
session.runner()
    .feed("saver_filename:0", checkpointPath)
    .fetch("StatefulPartitionedCall_2:0")
    .run();
The model was saved using a regular tf.saved_model.save call:
def save_module(module, model_dir):
    tf.saved_model.save(module, model_dir,
        signatures={
            'serve_model':
                module.__call__.get_concrete_function(signature_dict),
            'train_model':
                module.train.get_concrete_function(
                    signature_dict,
                    tf.TensorSpec(shape=[None, None], dtype=tf.int64, name="labels"))})
How to reproduce the error:
The error can be reproduced with the following script, which defines and then saves the model (credits to Thierry Herrmann):
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers

def make_model():

    class CustomLayer(keras.layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            l2_reg = keras.regularizers.l2(0.1)
            self.dense = layers.Dense(1, kernel_regularizer=l2_reg,
                                      name='my_layer_dense')

        def call(self, data):
            return self.dense(data)

    inputs = keras.Input(shape=(8,))
    x1 = layers.Dense(30, activation="relu", name='my_dense')(inputs)
    outputs = CustomLayer()(x1)
    return keras.Model(inputs=inputs, outputs=outputs)


class CustomModule(tf.Module):

    def __init__(self):
        super(CustomModule, self).__init__()
        self.model = make_model()
        self.opt = keras.optimizers.Adam(learning_rate=0.001)

    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
    def __call__(self, X):
        return self.model(X)

    # the my_train function processes one batch (one step): computes the loss and applies the
    # loss gradient to update the model weights
    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32), tf.TensorSpec([None], tf.float32)])
    def my_train(self, X, y):
        with tf.GradientTape() as tape:
            logits = self.model(X, training=True)
            main_loss = tf.reduce_mean(keras.losses.mean_squared_error(y, logits))
            # self.model.losses contains the regularization loss (see l2_reg above)
            loss_value = tf.add_n([main_loss] + self.model.losses)
        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.model.trainable_weights))
        return loss_value


# instantiate the module
module = CustomModule()

def save_module(module, model_dir):
    tf.saved_model.save(module, model_dir,
        signatures={
            'serve_model':
                module.__call__.get_concrete_function(tf.TensorSpec([None, 8], tf.float32)),
            'train_model':
                module.my_train.get_concrete_function(tf.TensorSpec([None, 8], tf.float32),
                                                      tf.TensorSpec([None], tf.float32))})

MODEL_OUTPUT_DIR = "..."
save_module(module, MODEL_OUTPUT_DIR)
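To reproduce from the Java side, the snippets above can be combined into a small driver. This is a sketch assuming the org.tensorflow API used in this issue; the directory strings are placeholders for the MODEL_OUTPUT_DIR used by the Python script and for the export target:

import org.tensorflow.ConcreteFunction;
import org.tensorflow.SavedModelBundle;

public class ReproduceExportError {
    public static void main(String[] args) throws Exception {
        String modelDir = "...";   // the MODEL_OUTPUT_DIR used by the Python script above
        String exportDir = "...";  // where the re-exported model should be written

        try (SavedModelBundle savedModel = SavedModelBundle.load(modelDir, "serve")) {
            ConcreteFunction serveFunction = savedModel.function("serve_model");
            // This call fails with:
            // No Operation named [StatefulPartitionedCall_2:0] in the Graph
            SavedModelBundle.exporter(exportDir)
                .withFunction(serveFunction)
                .export();
        }
    }
}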
@karllessard could you please take a quick look? Thank you
About this issue
- State: closed
- Created 3 years ago
- Comments: 16 (2 by maintainers)
@rnett even if it is a bit unexpected to have a placeholder added to the init scope, it looks like it can happen using basic Keras, as @HichemMaiza did above. Can we simply remove this validation, or maybe comment it out, to unblock any users who face this problem?
I can take care of doing it if you want, thanks.
Thanks @maziyarpanahi, this seems to be a different issue. I also suspect that the change did not happen in TF Java but in TF 2.7.0; I'll start investigating this. Maybe we should track this new bug in a separate issue then?
UPDATE: I've created the issue: #434
@karllessard @rnett Thank you for your help, it works with the latest update. For future Googlers: pay attention to the saved signatures; if you are interested in multiple signatures, remember to export them all for the future SavedModelBundle.load (two signatures in my case; see the sketch below). Thank you @karllessard.
@rnett you just have to run the script above. You only need to specify MODEL_OUTPUT_DIR = "" and hit run 😉
@rnett, see Hichem's previous comment above; just running this script will create the bundle for you.
Will do. My guess is that somehow there is a placeholder in the Python model's init scope, which gets exported in the saved model. I would think this is a bug on their end, since I can't think of a way that would work, but we'll see. @HichemMaiza would it be possible for you to provide the saved model bundle?
Thanks @HichemMaiza, this init op exception is coming from a new feature that was added recently to the current snapshot. @rnett can you please take a look at this issue?
Hi @HichemMaiza, I've just created PR #388 to solve this issue; it was a very small fix. Thanks a lot for documenting the issue so well, that was very useful. I was able to reproduce the problem and now it's working with the following code:
@karllessard Great news, thank you! Meanwhile, I will update the checkpoint using a direct-access approach.