tensorflow: [C API] while loop: unable to access operations defined outside of the loop from within the loop
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): macOS Mojave 10.14.4
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device: N/A
- TensorFlow installed from (source or binary): source
- TensorFlow version (use command below): master
- Python version: Python 3.7.3
- Bazel version (if compiling from source): 0.24.1
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: N/A
- GPU model and memory: N/A
Describe the current behavior
I am unable to access operations defined outside of the while loop from within the loop.
The C API while loop creates separate conditional and body graphs, so an error is thrown when we try to use operations defined in the outer graph within the body graph. See earlier discussion with @skye here.
Describe the expected behavior
I would expect the behavior of the C API while loop to match the Python API while loop, where accessing operations defined outside the while loop works.
I’ve included a minimal working example in Python that demonstrates the expected behavior below. In this example, we are able to use “increment” in the loop body even though it’s defined outside the loop.
import tensorflow as tf

# Defined outside the loop, used inside the loop body.
increment = tf.constant(1, name='one')

def loop_cond(loop_var):
    return tf.math.less(loop_var, 10)

def loop_body(loop_var):
    return loop_var + increment

loop_input = tf.Variable(0, name='loop_input')
loop_output = tf.while_loop(loop_cond, loop_body, [loop_input])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(loop_input))   # should be 0
    print(sess.run(loop_output))  # should be 10
Code to reproduce the issue
I replicated the example above as a unit test in while_loop_test.cc. Once again, we try to use “increment” within the loop body, but here we get an error since it’s not part of the body graph. The error message is “Node ‘add’: Unknown input node ‘scalar’”.
TEST_F(CApiWhileLoopTest, AccessOuterOp) {
  Init(1);
  // increment = 1
  // while (i < 10) {
  //   i = i + increment
  // }

  // Create increment *in the outer graph*
  TF_Operation* increment = ScalarConst(1, graph_, s_);

  // Create cond graph: i < 10
  TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_);
  TF_Operation* less_than =
      LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_);
  DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->cond_output = {less_than, 0};

  // Create body graph: i = i + increment
  TF_Operation* add =
      Add(params_->body_inputs[0], {increment, 0}, params_->body_graph, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  params_->body_outputs[0] = {add, 0};

  ExpectOK();
}
This test can be copy-pasted into while_loop_test.cc and run with the following command: bazel run //tensorflow/c:while_loop_test.
Other info / logs
The above examples are intended to be as minimal as possible, so they’re not practical. However, accessing outside operations would be important when updating external variables, for example when training within a loop (we would need to update external weights and biases).
We discovered this issue while using TF Java, after exposing the C API while loop to Java in this commit.
About this issue
- State: open
- Created 5 years ago
- Comments: 21 (21 by maintainers)
Sorry this got pushed to the back-burner since we have been focusing on fixing bugs for 2.0. I am planning to find some time over the next few weeks to design what FuncGraphs in the C API might look like. I think this will be important to pursue soon since we are starting to see issues in other areas (e.g. C++ shape inference, colocations, etc.) because of missing nested graph information. Maybe I can send out an RFC so that the SIG JVM community can join in.
Got it, thanks for clarifying! I will give that a try and report back.
Hi Melissa, sorry for the delay. I should have looked into this earlier, but I just noticed that the “Unknown input node” message comes from importing a GraphDef: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/graph/graph_constructor.cc#L561
I’m guessing this means that the body graph can somehow already reference nodes from the outer graph (maybe because we usually check this in Python?), but then it complains when we try to import the body graph into the outer graph: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.cc#L2430 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.cc#L2282
A quick fix might be to just add all the nodes in the outer graph to the input_map of that import, not just the loop inputs (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/c/c_api.cc#L2266). Melissa, do you wanna give this a try and see if it works, or if we get more errors? There might be performance issues down the line creating such a large input_map, but we can start with this just to see if this is the only problem. Please let me know if you need more help or context.
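To make the suggestion above concrete, here is a toy model in plain Python of how an import with an input_map might resolve a node's inputs. It is not TensorFlow code and does not reproduce the logic in graph_constructor.cc or c_api.cc; the function `import_body`, the `UnknownInputError` class, and the node names `loop_input` and `body_input` are all hypothetical, chosen only to mirror the error message reported in this issue.

```python
class UnknownInputError(ValueError):
    """Raised when a node names an input that cannot be resolved."""


def import_body(node_defs, input_map):
    """Toy stand-in for importing a body graph into the outer graph.

    node_defs: list of (name, op, input_names) tuples for the body graph.
    input_map: dict from a name used inside the body to an outer-graph name.
    Returns the imported nodes with their inputs resolved.
    """
    defined_in_body = {name for name, _, _ in node_defs}
    imported = []
    for name, op, input_names in node_defs:
        resolved = []
        for inp in input_names:
            if inp in input_map:
                resolved.append(input_map[inp])  # remapped to an outer node
            elif inp in defined_in_body:
                resolved.append(inp)             # ordinary body-internal edge
            else:
                raise UnknownInputError(
                    f"Node '{name}': Unknown input node '{inp}'")
        imported.append((name, op, resolved))
    return imported


# Hypothetical names: the outer graph holds the loop input and the constant.
outer_nodes = ["loop_input", "scalar"]
# The body's single node reads the body input and the outer constant.
body = [("add", "Add", ["body_input", "scalar"])]

# Map only the loop inputs: the reference to 'scalar' cannot be resolved.
loop_inputs_only = {"body_input": "loop_input"}
try:
    import_body(body, loop_inputs_only)
    error_message = None
except UnknownInputError as err:
    error_message = str(err)
print(error_message)

# Suggested quick fix: also map every outer node to itself.
full_map = {**{n: n for n in outer_nodes}, **loop_inputs_only}
imported = import_body(body, full_map)
print(imported)
```

In this toy, the first import fails with the same "Unknown input node" wording as the C API test, and the second succeeds once every outer node is in the map. The map grows with the size of the outer graph, which is the performance concern mentioned above.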