tensorflow: tf.case giving unexpected result in TF 1.0.1
What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?
I posted this SO question on 2017-03-10; it was never answered: http://stackoverflow.com/questions/42728235/tensorflow-why-is-tf-case-giving-me-the-wrong-result
Environment info
Operating System: Linux 312e492cd9df 4.4.0-66-generic #87-Ubuntu SMP Fri Mar 3 15:29:05 UTC 2017 x86_64 x86_64 x86_64 GNU/Linux
Installed version of CUDA and cuDNN: none
Installed from: I’m running this on the official tensorflow-devel Docker image for 1.0.1 (gcr.io/tensorflow/tensorflow:1.0.1-devel)
If possible, provide a minimal reproducible example (We usually don’t have time to read hundreds of lines of your code)
import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)
learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
case_tensors = [
    (tf.equal(global_step, 2), tf.constant(0.01, dtype=tf.float32)),
    (tf.equal(global_step, 4), tf.constant(0.001, dtype=tf.float32)),
]
cases = [(pred, lambda: fn_tensor) for pred, fn_tensor in case_tensors]
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

print tf.__version__
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in xrange(6):
        print sess.run([global_step, case_tensors, learning_rate, update, updated_learning_rate])
        sess.run(train_op)
What other attempted solutions have you tried?
None
Logs or other output that would be helpful
(If logs are large, please upload as attachment or provide link).
The above code prints the following output:
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
1.0.1
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001, 0.001]
I expect the learning rate to be set to 0.0099999998 when the global step reaches 2. However, even though the predicate for global_step == 2 evaluates to True, the learning rate is set to 0.001 instead.
About this issue
- State: closed
- Created 7 years ago
- Comments: 15 (4 by maintainers)
This is not a TF issue.
Closures are defined over names and not over values (https://stackoverflow.com/a/13355291). In this case, every lambda in the list comprehension refers to the same name, fn_tensor, so all of them ended up returning the value that name had at the end of the loop, which is the 0.001 constant. A simple solution is to have an outer lambda generate each of the lambdas you are iterating over, and to pass the iterated value as its first and only argument. Each closure is then defined over that argument, which is bound to a single value when the outer lambda is called, rather than over the shared loop name.
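The late-binding behaviour can be reproduced without TensorFlow. The following is a minimal sketch in which plain strings and floats stand in for the tf.equal and tf.constant objects from the report; the names pairs, late_bound and results are illustrative and do not appear in the original code.

```python
# Plain-Python stand-ins for the (predicate, tensor) pairs in the report.
pairs = [("global_step == 2", 0.01), ("global_step == 4", 0.001)]

# Same shape as the reported `cases = [...]` line: each lambda refers to the
# loop name `value`, which is only looked up when the lambda is called.
late_bound = [(pred, lambda: value) for pred, value in pairs]

# By the time any lambda runs, the loop has finished and `value` is 0.001.
results = [fn() for _, fn in late_bound]
print(results)  # [0.001, 0.001]
```

Both branches return 0.001, which matches the learning rate observed at global_step == 2 in the output above.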
So the code that yields the results you are looking for looks like:
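The snippet originally attached to this comment is not preserved here. The following is a minimal sketch of the fix as described, under the assumption that only the `cases = ...` line needs to change; plain strings and floats stand in for the tf.equal and tf.constant objects so that it runs without TensorFlow, and branch_values is an illustrative name.

```python
# Stand-ins for the (predicate, tensor) pairs; in the report these are
# tf.equal(...) and tf.constant(...) objects rather than plain values.
case_tensors = [("global_step == 2", 0.01), ("global_step == 4", 0.001)]

# The outer lambda is called immediately with the current fn_tensor, so the
# inner lambda closes over its own parameter `t` instead of the loop name.
cases = [(pred, (lambda t: lambda: t)(fn_tensor)) for pred, fn_tensor in case_tensors]

# Each branch now returns the value it was created with.
branch_values = [fn() for _, fn in cases]
print(branch_values)  # [0.01, 0.001]
```

In the reported script, the same comprehension would replace the existing `cases = ...` line, and the resulting list would be passed to tf.case unchanged. Binding the value through a default argument, `lambda t=fn_tensor: t`, is a common alternative with the same effect.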
@knub and I agree with @kopekC. I incorrectly believed that the issue described was a TF issue. Rather, the problem code was constructed poorly, and this thread should remain closed.