tensorflow: tf.case giving unexpected result in TF 1.0.1

What related GitHub issues or StackOverflow threads have you found by searching the web for your problem?

I posted this SO question on 2017-03-10, which was never answered: http://stackoverflow.com/questions/42728235/tensorflow-why-is-tf-case-giving-me-the-wrong-result

Environment info

Operating System: Linux 312e492cd9df 4.4.0-66-generic #87-Ubuntu SMP Fri Mar 3 15:29:05 UTC 2017 x86_64 x86_64 x86_64 GNU/Linux

Installed version of CUDA and cuDNN: none

Installed from: I’m running this on the official tensorflow-devel Docker image for 1.0.1 (gcr.io/tensorflow/tensorflow:1.0.1-devel)

If possible, provide a minimal reproducible example (We usually don’t have time to read hundreds of lines of your code)

import tensorflow as tf

global_step = tf.Variable(0, dtype=tf.int64)
train_op = tf.assign(global_step, global_step + 1)

learning_rate = tf.Variable(0.1, dtype=tf.float32, name='learning_rate')

# Update the learning_rate tensor conditionally
# When global_step == 2, update to 0.01
# When global_step == 4, update to 0.001
case_tensors = [
    (tf.equal(global_step, 2), tf.constant(0.01, dtype=tf.float32)),
    (tf.equal(global_step, 4), tf.constant(0.001, dtype=tf.float32)),
]
cases = [(pred, lambda: fn_tensor) for pred, fn_tensor in case_tensors]
update = tf.case(cases, default=lambda: learning_rate)
updated_learning_rate = tf.assign(learning_rate, update)

print(tf.__version__)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(6):
        print(sess.run([global_step, case_tensors, learning_rate, update, updated_learning_rate]))
        sess.run(train_op)

What other attempted solutions have you tried?

None

Logs or other output that would be helpful

(If logs are large, please upload as attachment or provide link).

The above code prints the following output:

W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
1.0.1
[0, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1, 0.1]
[1, [(False, 0.0099999998), (False, 0.001)], 0.1, 0.1, 0.1]
[2, [(True, 0.0099999998), (False, 0.001)], 0.001, 0.001, 0.001]
[3, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001, 0.001]
[4, [(False, 0.0099999998), (True, 0.001)], 0.001, 0.001, 0.001]
[5, [(False, 0.0099999998), (False, 0.001)], 0.001, 0.001, 0.001]

I expect the learning rate to be set to 0.0099999998 (the float32 value of 0.01) when the global step reaches 2. However, even though the predicate for global_step == 2 evaluates to True, the learning rate is not set to 0.0099999998; it is set to 0.001 instead.

About this issue

  • Original URL
  • State: closed
  • Created 7 years ago
  • Comments: 15 (4 by maintainers)

Most upvoted comments

This is not a TF issue.

Closures in Python capture names, not values (https://stackoverflow.com/a/13355291). Here, every lambda in `cases` refers to the same name `fn_tensor`, so once the list comprehension finishes they all return the value it last held: the 0.001 constant. A simple fix is a factory: a lambda that takes the iterated value as its first and only argument and returns the lambda you actually want. Each returned closure is then defined over that argument, which is a per-call copy of the iterated value, rather than over the shared loop variable.
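The late-binding behavior, and the factory fix, can be seen in plain Python without any TensorFlow (a minimal sketch of the same pattern):

```python
# Each lambda closes over the *name* v, not its value at creation time,
# so all of them see v's final value once the comprehension finishes.
late_bound = [lambda: v for v in range(3)]
print([f() for f in late_bound])  # [2, 2, 2]

# Factory fix: calling make(v) copies the current value into x, and the
# inner lambda closes over that per-call x instead of the shared v.
make = lambda x: lambda: x
bound_early = [make(v) for v in range(3)]
print([f() for f in bound_early])  # [0, 1, 2]
```

This is exactly what happens with `lambda: fn_tensor` in the original code: both case branches end up returning the last `fn_tensor`, the 0.001 constant.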

So code that yields the results you are looking for looks like this:

import tensorflow as tf

orig_label = tf.constant(0.046026)
label_bounds = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08]
preds = [tf.less(orig_label, bound) for bound in label_bounds]
# Pair each predicate with a function returning the index of the predicate.
z = lambda x: lambda: tf.constant(x)
pred_fn_pairs = [(pred, z(i))
                 for i, pred in enumerate(preds)]
# If no predicate evaluates to true, default to returning the index after
# the index of the last predicate.
default = lambda: tf.constant(len(pred_fn_pairs))
case = tf.case(pred_fn_pairs, default=default, exclusive=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([preds, case]))
print(list(enumerate(label_bounds)))
print(len(preds))

@knub and I agree with @kopekC. I incorrectly believed this was a TF issue; rather, the problem code was constructed incorrectly, and this thread should remain closed.