tensorflow: Can't change model attribute inside @tf.function with custom training loop.

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): No
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colab
  • Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.2.0
  • Python version: 3.7
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version:
  • GPU model and memory:

Describe the current behavior

Inside a @tf.function, a custom training loop can't change the model's attributes. For example, after 10k steps I want to freeze a layer with model.layers[0].trainable = False, or, when training a GAN, start training the discriminator after 10k steps.

Describe the expected behavior

Changing these attributes should take effect inside the @tf.function. Without @tf.function everything works as expected, but training is slower than in graph mode.

Standalone code to reproduce the issue

Here is a simple example that I took from the custom_training_walkthrough Colab and modified slightly (https://colab.research.google.com/drive/1Rv0YTBq6qISvt-EFebYr4U_RUdjdSWDe?usp=sharing).

The main code I changed is here:

    if steps == 1:
      print("trainable=True")
      model.layer0.trainable = True
      model.layer1.trainable = True
      model.layer2.trainable = True

    if steps == 100:
      print("trainable=False")
      model.layer0.trainable = False
      model.layer1.trainable = False
      model.layer2.trainable = False  

Other info / logs

ValueError: tf.function-decorated function tried to create variables on non-first call.

I also have a GAN training example where I enable discriminator training after N steps, but nothing happens; it is as if the code inside the if condition never runs. See (https://github.com/TensorSpeech/TensorflowTTS/blob/master/examples/multiband_melgan/train_multiband_melgan.py#L157-L166). The only workaround is to train the generator alone for N steps, then resume training with both G and D.

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 15 (5 by maintainers)

Most upvoted comments

All variables should be created once and reused every time you call the grad function. Is there any hidden variable creation?

Right. For any particular model and optimizer, the variables will only be created once. But a separate trace is created for each new combination of Python-object arguments. The problem is that for a stand-alone function, the “no variables created after the first call” logic is global, not per trace.

So the way to fix this is to move this to a method on the Model class. For methods the “no variables created after the first call” logic is tracked per object instead of per function.

(It looks like you’re writing a “train_step” method)

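import tensorflow as tf

class MyModel(tf.keras.Model):
    ...  # layers and call() as in your model

    @tf.function
    def train_step(self, inputs, targets):
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            # after compile(), self.loss is the loss passed to compile
            loss_value = self.loss(targets, predictions)
        grads = tape.gradient(loss_value, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return loss_value


# Each instance gets its own tf.function, so each may create variables on its own first call.
model = MyModel(...)
model.compile(loss=loss, optimizer=optimizer)
model.train_step(...)

model2 = MyModel(...)
model2.compile(loss=loss, optimizer=optimizer)
model2.train_step(...)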
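The per-object bookkeeping can be seen in isolation with a minimal sketch outside Keras. `Scaler` is a hypothetical tf.Module that creates its variable lazily on its first call (the same lazy-creation pattern the tf.function guide uses). Because a decorated method gets a separate tf.function per instance, the second instance can still create its own variable:

```python
import tensorflow as tf

class Scaler(tf.Module):
    """Hypothetical module that multiplies its input by a lazily created weight."""

    def __init__(self):
        super().__init__()
        self.w = None  # created on the first call, inside the tf.function

    @tf.function
    def __call__(self, x):
        if self.w is None:  # Python check, resolved while tracing
            self.w = tf.Variable(2.0)
        return self.w * x

a = Scaler()
b = Scaler()
print(a(tf.constant(3.0)).numpy())  # 6.0
print(b(tf.constant(3.0)).numpy())  # 6.0, no "non-first call" error for b
```

With a stand-alone function taking the module as an argument instead, the call on `b` would be a non-first call of the same tf.function and would hit the error from this issue.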

You can also create a separate function for each, but this is messier:

import tensorflow as tf

def get_step_function():
    # Returns a fresh tf.function, so each model gets its own "first call"
    # in which it may create variables.
    @tf.function
    def grad(model, inputs, targets, optimizer):
        with tf.GradientTape() as tape:
            # `loss` is the walkthrough's loss(model, x, y, training) helper
            loss_value = loss(model, inputs, targets, training=True)
        grads = tape.gradient(loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return grad

model1 = ...
model2 = ...

f1 = get_step_function()
f1(model1, ...)

f2 = get_step_function()
f2(model2, ...)


https://github.com/TensorSpeech/TensorFlowTTS/blob/a280b619ec399c37609d8d9db335ffdb0816dd2e/examples/multiband_melgan/train_multiband_melgan.py#L157

tf.function records every operation it encounters to build a graph. It only turns Python control flow into graph control flow when the condition is a TensorFlow object (a Tensor or Variable), so my bet is that on that line, if self.steps >= self.config["discriminator_train_start_steps"]:, both arguments to >= are Python ints, not tensors. If you change a Python value after tracing, TensorFlow can't see the new value.
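A minimal sketch of that difference; `steps`, `step_var`, `python_condition` and `tensor_condition` are hypothetical names, not taken from the TensorFlowTTS script. The plain Python int is read once while tracing, while the tf.Variable is read every time the graph runs:

```python
import tensorflow as tf

steps = 0  # plain Python int: its value is baked in when the function is traced

@tf.function
def python_condition():
    # `steps >= 5` is a Python bool, so this `if` runs once, during tracing.
    if steps >= 5:
        return tf.constant(1)
    return tf.constant(0)

step_var = tf.Variable(0)  # TensorFlow state: read every time the graph runs

@tf.function
def tensor_condition():
    # The condition is a tensor, so AutoGraph converts this `if` into tf.cond.
    if step_var >= 5:
        return tf.constant(1)
    return tf.constant(0)

print(python_condition().numpy(), tensor_condition().numpy())  # 0 0
steps = 10
step_var.assign(10)
print(python_condition().numpy(), tensor_condition().numpy())  # 0 1
```

The first function keeps returning 0 because its call signature did not change, so the existing trace is reused; only the Variable-based condition sees the update.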

From your notebook, this is the problem:

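@tf.function
def grad(model, ...):
  ...

grad(model, ...)  # first call: traces and creates the first model's variables

model = SmallModel()

grad(model, ...)  # same tf.function, new model: its variables are created on a non-first call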

For now, if you need to call a tf.function on multiple objects that create their variables on the first call, you need to create a separate @tf.function for each instance, or define it as a method.
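A minimal sketch of the per-instance workaround, using two Dense layers; `make_forward` is a hypothetical helper. Each layer gets its own tf.function, so each one may build its variables on that function's first call:

```python
import tensorflow as tf

def make_forward():
    # A fresh tf.function on every call of make_forward().
    @tf.function
    def forward(dense, x):
        return dense(x)  # builds the layer (creates variables) on first use
    return forward

dense1 = tf.keras.layers.Dense(3)
dense2 = tf.keras.layers.Dense(3)
x = tf.ones((2, 4))

forward1 = make_forward()
forward2 = make_forward()
y1 = forward1(dense1, x)  # first call of forward1: dense1 builds here
y2 = forward2(dense2, x)  # first call of forward2: dense2 builds here, no ValueError
print(y1.shape, y2.shape)
```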

See: https://github.com/keras-team/keras-io/pull/334

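import tensorflow as tf

dense1 = tf.keras.layers.Dense(3)
dense2 = tf.keras.layers.Dense(3)
x = tf.ones((2, 4))

@tf.function
def fun(dense, x):
  return dense(x)

fun(dense1, x)  # first call: creates dense1's variables
fun(dense2, x)  # dense2 would create its variables on a non-first call
ValueError: tf.function-decorated function tried to create variables on non-first call.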

That explains the “non-first call” error. But everything else still holds: you shouldn't change an object's state and expect an old trace to pick up the difference.

As far as I can tell, this has nothing to do with Keras specifically (and it doesn't even use Keras's own tf.function logic in fit/evaluate).

tf.functions only capture Python state at the time they are traced, so changing any Python attribute after the first call to the tf.function-decorated grad won't do anything. @mdanatg I suppose the tf.function guide may be insufficient here: it only seems to discuss Python side effects within the tf.function ( https://www.tensorflow.org/guide/function#python_side_effects ), but doesn't mention that changes to Python state made outside the function are only reflected in new traces.
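A minimal sketch of this stale-trace effect, assuming a single Dense layer whose kernel starts at ones and a plain SGD optimizer (illustrative choices, not the reporter's model). Setting `trainable = False` after the first call does not stop the second call from updating the kernel, because the old trace is reused:

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(1, kernel_initializer="ones")
layer.build((None, 1))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def step(x):
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(layer(x))
    # layer.trainable_variables is evaluated once, while tracing.
    grads = tape.gradient(loss, layer.trainable_variables)
    optimizer.apply_gradients(zip(grads, layer.trainable_variables))

x = tf.ones((1, 1))
step(x)                  # first call: traced while the layer is trainable
layer.trainable = False  # Python state change after tracing
before = layer.kernel.numpy().copy()
step(x)                  # same input signature, so the old trace runs again
after = layer.kernel.numpy().copy()
print(before, after)     # the "frozen" kernel still changed
```

Here the gradient of the summed output with respect to the kernel is 1, so each step subtracts 0.1: the kernel goes from 1.0 to 0.9 on the first call and to about 0.8 on the second, even though the layer is no longer trainable.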

If you want the trainable state to be reflected within your tf.function, you need to make sure different traces are used when model.trainable should be True and when it should be False. You could do this by giving your grad function a trainable Python boolean argument, then setting the model's state at the start of grad, inside the tf.function trace. (Just make sure to reset trainable at the end of your trace.)
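A minimal sketch of that suggestion, assuming a toy one-layer model, an SGD optimizer and a hypothetical `train_step(x, y, trainable)` signature (none of these come from the original notebook). Because `trainable` is a Python bool, each value gets its own trace. The first call uses `trainable=True`, so any optimizer variables are created on the function's first call, in line with the "no variables after the first call" rule discussed above:

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 2))
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
loss_fn = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(x, y, trainable):
    # `trainable` is a Python bool: True and False produce separate traces.
    previous = model.trainable
    model.trainable = trainable  # set the state this trace should record
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    variables = model.trainable_variables  # empty when trainable is False
    if variables:  # Python check, resolved at trace time
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))
    model.trainable = previous  # reset so the Python object is left unchanged
    return loss

x = tf.ones((4, 2))
y = tf.fill((4, 1), 5.0)
before = [w.numpy().copy() for w in model.weights]
train_step(x, y, True)   # trace 1: weights are updated
after_true = [w.numpy().copy() for w in model.weights]
train_step(x, y, False)  # trace 2: no trainable variables, so no update
after_false = [w.numpy().copy() for w in model.weights]
```

Resetting `model.trainable` at the end keeps the Python object in the state it had before the call, so later traces start from a known state.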

I had a look, and I think there are two issues highlighted in this bug:

  1. Calling model(x, training=False) and then model(x, training=True) leads to an error. This looks like a bug, but it can be worked around by decorating call with @tf.function. Ideally, Keras should rebuild the graph when the training argument changes.
  2. It's unclear whether changing a layer's trainable attribute has the expected effect. @tomerk, will Keras rebuild the graph when the trainable attribute of various layers changes? I think that would be the expected behavior.

I think both cases should be handled without the need for extra workarounds, although they may require additional smarts inside Keras to handle that correctly. I do agree that users should not need to resort to building multiple versions of the same model.

@amahendrakar I agree that the native solution is to create a separate forward function for each condition. But don't you think this solution is poor in terms of coding style, especially when you want to apply very complicated training procedures such as GANs? Is there a good alternative to this native solution?