tensorflow: Can't change model attribute inside @tf.function with custom training loop.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colab
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.2.0
- Python version: 3.7
Describe the current behavior
With a custom training loop, changing a model attribute has no effect inside a @tf.function-decorated training step. For example, after 10k steps I want to make a layer non-trainable with model.layers[0].trainable = False, or, when training a GAN, start training the discriminator after 10k steps.
Describe the expected behavior
Without @tf.function everything works as expected, but training is slower than in graph mode.
Standalone code to reproduce the issue
Here is a simple example that I took from the custom_training_walkthrough Colab and modified a bit (https://colab.research.google.com/drive/1Rv0YTBq6qISvt-EFebYr4U_RUdjdSWDe?usp=sharing).
The main code I changed is here:
    if steps == 1:
        print("trainable=True")
        model.layer0.trainable = True
        model.layer1.trainable = True
        model.layer2.trainable = True
    if steps == 100:
        print("trainable=False")
        model.layer0.trainable = False
        model.layer1.trainable = False
        model.layer2.trainable = False
Other info / logs
ValueError: tf.function-decorated function tried to create variables on non-first call.
I also have a GAN training example in which discriminator training is enabled after N steps, but nothing happens: it behaves as if the code inside the if condition never runs. See https://github.com/TensorSpeech/TensorflowTTS/blob/master/examples/multiband_melgan/train_multiband_melgan.py#L157-L166. The only workaround I have is to train the generator alone for N steps and then resume training with both G and D.
About this issue
- State: closed
- Created 4 years ago
- Comments: 15 (5 by maintainers)
Right. For any particular model and optimizer, the variables will only be created once. But a separate trace is created for each new combination of Python object arguments. The problem is that for a standalone function the “no variables created after the first call” logic is global; it's not per trace.
So the way to fix this is to move this to a method on the Model class. For methods, the “no variables created after the first call” logic is tracked per object instead of per function.
(It looks like you're writing a “train_step” method.)
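A minimal sketch of the method-based fix, assuming TensorFlow 2.x. To stay independent of Keras internals it uses a plain tf.Module instead of a Keras Model; the class LazyLinear, its train_step method and the toy data are hypothetical. The weight is created on the first call to mimic how Keras layers build lazily. Because train_step is a method, each instance gets its own tf.function, so the second model may create its variable on its own first call.

```python
import tensorflow as tf


class LazyLinear(tf.Module):
    """Hypothetical toy model: its weight is created on the first call,
    the way Keras layers build their variables lazily."""

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.w = None  # created lazily on the first (traced) call

    def __call__(self, x):
        if self.w is None:
            # Runs only while tracing the first call for this instance.
            self.w = tf.Variable(tf.zeros([int(x.shape[-1]), self.units]))
        return tf.matmul(x, self.w)

    @tf.function
    def train_step(self, x, y):
        # As a method, tf.function keeps one copy per instance, so the
        # "variables only on the first call" rule is tracked per object.
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(self(x) - y))
        grad = tape.gradient(loss, self.w)
        self.w.assign_sub(0.01 * grad)  # plain gradient-descent update
        return loss


x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.constant([[1.0], [2.0]])

model_a, model_b = LazyLinear(1), LazyLinear(1)
first_a = model_a.train_step(x, y)   # creates model_a.w
first_b = model_b.train_step(x, y)   # creates model_b.w on its own first call
second_a = model_a.train_step(x, y)  # reuses model_a's existing trace
```

Per the explanation above, a single standalone @tf.function called first with model_a and then with model_b would create model_b's variable in a second trace, which is what raised the non-first-call ValueError in TF 2.2.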
You can also create a separate function for each model, but this is messier.
https://github.com/TensorSpeech/TensorFlowTTS/blob/a280b619ec399c37609d8d9db335ffdb0816dd2e/examples/multiband_melgan/train_multiband_melgan.py#L157
tf.function writes down all the operations that it encounters to create a graph. It only picks up control flow ops if they act on a TensorFlow object (a Tensor or Variable), so my bet is that on that line, both arguments to >= in
    if self.steps >= self.config["discriminator_train_start_steps"]:
are Python ints, not tensors. If you change a Python value after tracing, TensorFlow can't see the new value.
From your notebook, this is the problem:
For now, if you need to call a tf.function on multiple objects (that initialize their variables in the first call), you need to create a separate @tf.function for each instance, or define it as a method. See: https://github.com/keras-team/keras-io/pull/334
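A sketch of the other option, one separate @tf.function per instance, assuming TensorFlow 2.x. LazyScale and make_train_step are hypothetical names; the factory builds a fresh tf.function object every time it is called, so each model's variable is created on the first call of its own function.

```python
import tensorflow as tf


class LazyScale(tf.Module):
    """Hypothetical toy model whose single weight is created on first call."""

    def __init__(self):
        super().__init__()
        self.w = None

    def __call__(self, x):
        if self.w is None:
            # Only happens while the wrapping tf.function traces its first call.
            self.w = tf.Variable(2.0)
        return self.w * x


def make_train_step(model, lr=0.1):
    """Hypothetical helper: wraps `model` in its own fresh tf.function."""

    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            loss = tf.square(model(x) - y)
        grad = tape.gradient(loss, model.w)
        model.w.assign_sub(lr * grad)
        return loss

    return train_step


m1, m2 = LazyScale(), LazyScale()
step1, step2 = make_train_step(m1), make_train_step(m2)

x, y = tf.constant(1.0), tf.constant(1.0)
loss1 = step1(x, y)  # first call of step1: creates m1.w
loss2 = step2(x, y)  # first call of step2: creates m2.w
```

The trade-off is bookkeeping: every model needs its own step function kept alongside it.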
That explains the “non-first call” error. But everything else is still true: you shouldn't change an object's state and expect an old trace to pick up the difference. As far as I can tell this doesn't have anything to do with Keras specifically (and it doesn't even use Keras's own tf.function logic in fit/evaluate).
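The same trace-time capture explains the GAN case. This sketch, assuming TensorFlow 2.x, compares a Python int step counter with a tf.Variable counter in an otherwise identical condition; Trainer, START_STEP and the counter names are hypothetical stand-ins for self.steps and self.config["discriminator_train_start_steps"]. The Python comparison is evaluated once while tracing, so its branch never makes it into the graph; the Variable comparison is converted by AutoGraph into a graph conditional that is re-evaluated on every call.

```python
import tensorflow as tf

START_STEP = 3  # hypothetical stand-in for discriminator_train_start_steps


class Trainer:
    """Hypothetical trainer comparing two ways of storing the step count."""

    def __init__(self):
        self.py_steps = 0                                # plain Python int
        self.var_steps = tf.Variable(0, dtype=tf.int64)  # TensorFlow state
        self.d_updates_py = tf.Variable(0)   # counts "discriminator" updates
        self.d_updates_var = tf.Variable(0)

    @tf.function
    def step_python_int(self):
        # Python comparison: evaluated once, while tracing (py_steps == 0),
        # so the branch below is never recorded into the graph.
        if self.py_steps >= START_STEP:
            self.d_updates_py.assign_add(1)

    @tf.function
    def step_variable(self):
        # Tensor comparison: AutoGraph turns this `if` into a graph
        # conditional that reads var_steps again on every call.
        if self.var_steps >= START_STEP:
            self.d_updates_var.assign_add(1)


trainer = Trainer()
for _ in range(5):  # steps 0..4
    trainer.step_python_int()
    trainer.step_variable()
    trainer.py_steps += 1            # invisible to the existing trace
    trainer.var_steps.assign_add(1)  # seen by the next call

print(int(trainer.d_updates_py.numpy()), int(trainer.d_updates_var.numpy()))
# prints: 0 2
```

Under the same assumptions, keeping the GAN script's step counter in a tf.Variable (or passing it in as a tensor) should let the discriminator branch switch on without retracing; this is a suggestion, not something confirmed in the thread.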
tf.functions only capture Python state at the time they are first traced, so changing any Python attribute after the first time you call the tf.function'd grad won't do anything. @mdanatg I suppose the tf.function guide might be insufficient, because it seems to only talk about Python side effects within the tf.function (https://www.tensorflow.org/guide/function#python_side_effects) but doesn't mention that changes to Python state made outside the function are only reflected in new traces.
If you want the trainable state to be reflected within your tf.function, you need to make sure to use different traces for when you set model.trainable to True and when you set it to False. You could do this by making your grad method take a trainable Python boolean argument, then setting the model's state at the start of grad, inside the tf.function tracing. (Just make sure to reset trainable at the end of your trace.)
I had a look; I think there are two issues highlighted in this bug:
- Calling model(x, training=False) and then model(x, training=True) leads to an error. This looks like a bug, but it can be worked around by decorating call with @tf.function. Ideally, Keras should rebuild the graph when the training argument changes.
- … trainable attribute of various layers is being changed? I think that would be the expected behavior.
I think both cases should be handled without the need for extra workarounds, although they may require additional smarts inside Keras to handle them correctly. I do agree that users should not need to resort to building multiple versions of the same model.
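For reference, a sketch of the per-trace workaround suggested earlier in the thread (take a Python boolean argument, set the trainable state at the start of the trace and reset it at the end), assuming TensorFlow 2.x. TwoPart, d_trainable and trainable_vars are hypothetical; a plain Python flag stands in for a Keras layer's trainable attribute, and the update is a hand-written gradient step so that no optimizer slot variables are created on the second trace.

```python
import tensorflow as tf


class TwoPart(tf.Module):
    """Hypothetical model: `g` plays the generator, `d` the discriminator.
    The Python flag `d_trainable` stands in for a Keras layer's `trainable`."""

    def __init__(self):
        super().__init__()
        self.g = tf.Variable(1.0)
        self.d = tf.Variable(1.0)
        self.d_trainable = False  # Python state: only read at trace time

    def trainable_vars(self):
        return [self.g, self.d] if self.d_trainable else [self.g]

    @tf.function
    def train_step(self, x, d_trainable):
        # `d_trainable` is a Python bool, so True and False get separate traces.
        previous = self.d_trainable
        self.d_trainable = d_trainable  # set the state for this trace
        variables = self.trainable_vars()
        with tf.GradientTape() as tape:
            loss = tf.square(self.g * self.d * x)
        grads = tape.gradient(loss, variables)
        for var, grad in zip(variables, grads):
            var.assign_sub(0.1 * grad)
        self.d_trainable = previous  # reset the state at the end of the trace
        return loss


model = TwoPart()
x = tf.constant(1.0)

model.train_step(x, False)  # trace 1: only g is updated
g_after_first = float(model.g.numpy())
d_after_first = float(model.d.numpy())

model.train_step(x, True)   # trace 2: g and d are both updated
```

Calling train_step(x, False) again reuses the first trace, because Python arguments are part of tf.function's trace key.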
@amahendrakar I agree that the native solution is to create multiple forward functions, one for each condition. But don't you think this solution is not good in terms of coding style, especially when you want to apply very complicated training procedures such as GANs? Is there a good solution to replace this native solution?