tensorflow: model.trainable=False does nothing in tensorflow keras
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 16.04
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): source
- TensorFlow version (use command below): 1.13
- Python version: 2.7
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: 10.1.105
- GPU model and memory: M60, 8 GB
It seems that setting model.trainable=False in tensorflow keras does nothing except print a misleading model.summary(). Here is the code to reproduce the issue:
import tensorflow as tf
import numpy as np

IMG_SHAPE = (160, 160, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False
# for layer in base_model.layers:
#     layer.trainable = False

bc = []  # before compile
ac = []  # after compile

for layer in base_model.layers:
    bc.append(layer.trainable)
print(np.all(bc))  # True
base_model.summary()  # shows no trainable parameters, contradicting np.all(bc) above

base_model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),
                   loss='categorical_crossentropy',
                   metrics=['accuracy'])

for layer in base_model.layers:
    ac.append(layer.trainable)
print(np.all(ac))  # True
base_model.summary()  # still shows no trainable parameters, again contradicting np.all(ac)
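A way to check whether freezing actually took effect, independent of what model.summary() prints, is to count the model's trainable variables. A minimal sketch, assuming a recent TF 2.x tf.keras where the flag does propagate (weights=None is used here only to avoid the ImageNet download; the original repro uses weights='imagenet'):

```python
import tensorflow as tf

# Build the same base model; weights=None skips the pretrained download.
base_model = tf.keras.applications.MobileNetV2(input_shape=(160, 160, 3),
                                               include_top=False,
                                               weights=None)

# Count trainable variables before and after freezing instead of
# trusting the summary() printout.
n_before = len(base_model.trainable_weights)
base_model.trainable = False
n_after = len(base_model.trainable_weights)
print(n_before, n_after)  # in recent tf.keras, n_after is 0
```

If n_after is 0, the optimizer will receive no variables from the base model and freezing is effective regardless of what the per-layer .trainable flags report.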
About this issue
- State: closed
- Created 5 years ago
- Comments: 23 (11 by maintainers)
I think the problem is actually that the individual layers in the model stay trainable when the model itself is set to be non-trainable.
Yes. Confusingly, the behavior of model.trainable differs between keras and tf.keras: in keras it only affects the trainable flag of the model itself, without touching the flags of the sub-layers, contrary to what happens in tf.keras. (Is it really intended for keras and tf.keras to behave so differently here?) In the second summary, you effectively create a new model, which sets model.trainable = True by default, and that explains the different behavior.

I find setting the trainable flag on the whole model very confusing, so I now only set it per sub-layer. For example, in tf.keras, if you do model.trainable = False and then model.layers[-4].trainable = True, the latter instruction has no effect, because it does not change the trainable flag at the top of the model.
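That last point can be sketched with hypothetical toy classes (ToyLayer/ToyModel are illustrations of the gating semantics described above, not real Keras APIs): the model-level trainable flag gates every sub-layer, so flipping one layer back on after freezing the model has no effect.

```python
# Toy sketch of tf.keras-style gating: a layer is effectively trainable
# only if BOTH the model flag and its own flag are True.
class ToyLayer:
    def __init__(self, name):
        self.name = name
        self.trainable = True

class ToyModel:
    def __init__(self, layers):
        self.layers = layers
        self.trainable = True

    def layer_is_effectively_trainable(self, layer):
        # The top-level flag masks every sub-layer flag.
        return self.trainable and layer.trainable

model = ToyModel([ToyLayer("conv"), ToyLayer("dense")])
model.trainable = False
model.layers[-1].trainable = True  # masked by the top-level flag
flags = [model.layer_is_effectively_trainable(l) for l in model.layers]
print(flags)  # [False, False]
```

To selectively unfreeze under these semantics, the top-level flag must stay True and the per-layer flags must be set individually, which is why setting trainable per sub-layer is the less surprising approach.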