tensorflow: train_and_evaluate does not preserve the Dataset iterator state across train/eval
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): N/A
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): v1.8.0-0-g93bc2e2072 1.8.0
- Python version: 3.6
- Bazel version (if compiling from source): N/A
- GCC/Compiler version (if compiling from source): N/A
- CUDA/cuDNN version: N/A
- GPU model and memory: N/A
- Exact command to reproduce: see below
Describe the problem
When the input function is based on a one-shot dataset iterator, every training phase starts from the beginning of the iterator. That is, the iterator state gets reset between the train and eval phases. Therefore, if the dataset is large enough, training only ever sees the subset of the data that can be processed within eval_spec.throttle_secs seconds.
I think the issue is caused by the graph being checkpointed before transitioning to the next phase and restored upon re-entering training; since the one-shot iterator's position is not part of the checkpoint, the input pipeline starts over. In any case, I find the behaviour a bit counterintuitive, so if it is not a bug, it should be mentioned in the train_and_evaluate docs.
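A workaround that seems to help in my case is to register the training iterator as a saveable object, so that its position is written to (and restored from) the same checkpoints that train_and_evaluate saves between phases. Below is a minimal sketch, assuming tf.contrib.data.make_saveable_from_iterator is available in this version; the function and variable names are mine:

import tensorflow as tf

def train_input_fn(data):
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.batch(batch_size=1)
    iterator = dataset.make_one_shot_iterator()

    # Register the iterator state as a saveable object so the Estimator
    # checkpoints include (and later restore) the current position.
    saveable = tf.contrib.data.make_saveable_from_iterator(iterator)
    tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)

    x = iterator.get_next()
    return {"x": x}, x

With this in place, the second training phase should continue from where the first one stopped instead of re-reading the data from the start, but I have not seen this documented as the recommended approach.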
Source code / logs
Here is a small example demonstrating the issue:
import tensorflow as tf


def input_fn(data):
    dataset = tf.data.Dataset.from_tensor_slices(data)
    dataset = dataset.batch(batch_size=1)
    x = dataset.make_one_shot_iterator().get_next()
    return {"x": tf.Print(x, [x])}, x


if __name__ == "__main__":
    model = tf.estimator.LinearRegressor(feature_columns=[
        tf.feature_column.numeric_column("x")
    ])
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda: input_fn(list(range(2**20))),
        max_steps=2)
    eval_spec = tf.estimator.EvalSpec(
        input_fn=lambda: input_fn([42]),
        steps=1,
        start_delay_secs=1,
        throttle_secs=1)

    tf.logging.set_verbosity("INFO")
    tf.train.create_global_step()
    tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
The code produces the following log output (I've omitted irrelevant lines). Note that [0] is printed before both step 1 and step 2, i.e. after evaluation the training input starts again from the first element of the dataset:
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after 1 secs (eval_spec.throttle_secs) or training is finished.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
[0]
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/wr/s7brqkzj74v4pmdwkwn19321l34xg2/T/tmpgtq1ap9_/model.ckpt.
INFO:tensorflow:loss = 0.0, step = 1
INFO:tensorflow:Loss for final step: 0.0.
...
INFO:tensorflow:Starting evaluation at 2018-05-03-16:21:51
...
INFO:tensorflow:Finished evaluation at 2018-05-03-16:21:51
...
INFO:tensorflow:Restoring parameters from /var/folders/wr/s7brqkzj74v4pmdwkwn19321l34xg2/T/tmpgtq1ap9_/model.ckpt-1
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
[0]
INFO:tensorflow:Saving checkpoints for 2 into /var/folders/wr/s7brqkzj74v4pmdwkwn19321l34xg2/T/tmpgtq1ap9_/model.ckpt.
INFO:tensorflow:loss = 0.0, step = 2
INFO:tensorflow:Loss for final step: 0.0.
...
About this issue
- Original URL
- State: closed
- Created 6 years ago
- Reactions: 1
- Comments: 23 (21 by maintainers)
I’m sorry, I think I’m missing the context to understand what needs to be hijacked.