model-optimization: QAT model saving bug: KeyError: '__inference_depthwise_conv2d_layer_call_fn_126

Describe the bug

Please download the scripts to reproduce the issue from: https://drive.google.com/drive/folders/15cajAZ9sAZ2Uyix8sDVSYku6QCqDCec7?usp=sharing

Command to run: python sample_qat.py

I have a simple model with an input layer and a DepthwiseConv2D layer. I quantize this model by adding quantize_and_dequantize nodes at the input of the DepthwiseConv2D layer (see the commented section in the code). When I save the model and load it back, I get the following error:

  File "/home/dperi/Downloads/py3/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 544, in <lambda>
    "function": lambda: self._recreate_function(proto.function),
  File "/home/dperi/Downloads/py3/lib/python3.6/site-packages/tensorflow/python/saved_model/load.py", line 586, in _recreate_function
    proto, self._concrete_functions), setattr
  File "/home/dperi/Downloads/py3/lib/python3.6/site-packages/tensorflow/python/saved_model/function_deserialization.py", line 295, in recreate_function
    concrete_function_objects.append(concrete_functions[concrete_function_name])
KeyError: '__inference_depthwise_conv2d_layer_call_and_return_conditional_losses_117'
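For context, a quantize_and_dequantize ("fake quantization") node rounds each value onto an integer grid and immediately maps it back to float, so the graph stays in floating point while simulating quantization error. The sketch below is a minimal pure-Python illustration assuming symmetric, per-tensor, narrow-range quantization with the range taken from the input's largest magnitude. The function name fake_quantize is hypothetical, and this is a simplification, not the exact kernel behind TensorFlow's quantize_and_dequantize ops, which support more options.

```python
def fake_quantize(values, num_bits=8):
    """Hypothetical helper: simplified symmetric per-tensor quantize-then-dequantize."""
    # Assumption: the largest magnitude in the tensor defines the range.
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return list(values)
    # Symmetric narrow-range signed grid, e.g. [-127, 127] for 8 bits.
    q_max = 2 ** (num_bits - 1) - 1
    scale = max_abs / q_max
    out = []
    for v in values:
        # Quantize: snap to the nearest integer level and clamp to the grid.
        q = max(-q_max, min(q_max, round(v / scale)))
        # Dequantize: map the integer level back to a float value.
        out.append(q * scale)
    return out


print(fake_quantize([0.4, 0.6, -1.0], num_bits=2))  # [0.0, 1.0, -1.0]
```

With 2 bits the grid is just {-1, 0, 1} times the scale, which makes the rounding step easy to see; with 8 bits each output differs from its input by at most half a quantization step.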

System information

TensorFlow version (installed from source or binary): 2.5 (Tried with 2.6 as well)

TensorFlow Model Optimization version (installed from source or binary):

Loading the SavedModel fails specifically with depthwise convolution; the same setup works fine with a regular convolution.

About this issue

  • State: open
  • Created 3 years ago
  • Comments: 19 (3 by maintainers)

Most upvoted comments

The best way to avoid this issue is to disable layer tracing when creating the SavedModel (save_traces=False). You then have to define the serving_default signature manually (serving_default is the default signature name that TF2ONNX uses).

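import tensorflow as tf

# `model` is the (quantized) Keras model and `path` is the export directory.

@tf.function
def predict(*args, **kwargs):
  # Wrap the model call so it can be exported as an explicit signature.
  return model(*args, **kwargs)

# (args, kwargs) TensorSpecs of the model's call inputs; this is None if the
# input specs are not known yet (e.g. the model has never been called).
arg_spec, kwarg_spec = model.save_spec()
# Skip per-layer call traces and export only the serving_default signature.
model.save(path, save_traces=False, signatures={
  "serving_default": predict.get_concrete_function(*arg_spec, **kwarg_spec)
})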
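As a related sketch (not the commenter's code), the same idea of exporting only an explicitly defined serving_default signature can be shown with core TensorFlow's tf.saved_model.save on a tf.Module, which does not go through Keras layer-call tracing at all. The module name QuantizedDepthwise, the input shape, the all-ones filter, and the fixed [-1, 1] quantization range are hypothetical stand-ins for the issue's model, not taken from its scripts.

```python
import tempfile

import tensorflow as tf


class QuantizedDepthwise(tf.Module):
    """Hypothetical stand-in: fake-quantized input followed by a depthwise conv."""

    def __init__(self):
        super().__init__()
        # Depthwise filter layout: [height, width, in_channels, channel_multiplier].
        self.kernel = tf.Variable(tf.ones([3, 3, 3, 1]), name="depthwise_kernel")

    @tf.function(input_signature=[tf.TensorSpec([None, 8, 8, 3], tf.float32, name="x")])
    def __call__(self, x):
        # Assumption: a fixed quantization range of [-1, 1] for the input.
        x = tf.quantization.quantize_and_dequantize_v2(
            x, input_min=-1.0, input_max=1.0, range_given=True)
        y = tf.nn.depthwise_conv2d(
            x, self.kernel, strides=[1, 1, 1, 1], padding="SAME")
        # Returning a dict gives the signature a named output.
        return {"output": y}


module = QuantizedDepthwise()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(
    module, export_dir,
    signatures={"serving_default": module.__call__.get_concrete_function()})

# Reload and call the exported signature by its input name.
loaded = tf.saved_model.load(export_dir)
serving_fn = loaded.signatures["serving_default"]
result = serving_fn(x=tf.zeros([2, 8, 8, 3]))["output"]
print(tuple(result.shape))  # (2, 8, 8, 3)
```

Because the input_signature pins the input spec, get_concrete_function() needs no arguments, and callers such as TF2ONNX only see the serving_default signature rather than per-layer traced functions.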