tensorflow: Can't use estimator + dataset and train for less than one epoch

TensorFlow 1.4 moves the TF Dataset API into core (tf.data.Dataset), and the docs/tutorials suggest using tf.estimator to train models.

However, as recommended at the end of this page, the Dataset object and its iterator must be instantiated inside the input_fn function. This means the iteration through the dataset starts over on every call to estimator.train(input_fn, steps). Thus, calling it with fewer steps than there are batches in one epoch trains the model on only a subset of the dataset.
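To make the restart behavior concrete, here is a plain-Python simulation (not TensorFlow code). `make_input_fn` and `train` are hypothetical stand-ins: each `train` call creates a fresh iterator, the way Estimator.train() calls input_fn again and builds a new pipeline. Without shuffling, two short calls see exactly the same batches and the rest of the data is never reached.

```python
# Plain-Python simulation (not TensorFlow code): every "train" call rebuilds
# the input pipeline from scratch, like Estimator.train() calling input_fn anew.

def make_input_fn(samples, batch_size):
    """Return a hypothetical input_fn that yields batches from the start of the data."""
    def input_fn():
        for start in range(0, len(samples), batch_size):
            yield samples[start:start + batch_size]
    return input_fn


def train(input_fn, steps):
    """Consume `steps` batches from a freshly created iterator; return the batches seen."""
    iterator = input_fn()  # a new iterator on every call
    seen = []
    for _ in range(steps):
        try:
            seen.append(next(iterator))
        except StopIteration:
            break
    return seen


samples = list(range(10))
input_fn = make_input_fn(samples, batch_size=2)

first = train(input_fn, steps=2)
second = train(input_fn, steps=2)
print(first)   # [[0, 1], [2, 3]]
print(second)  # [[0, 1], [2, 3]]  (same batches again; samples 4..9 are never seen)
```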

Hence my question: is it possible to implement something like this with Estimator + Dataset:

for i in range(num_epochs):
    # Train for valid_freq steps
    estimator.train(input_fn=train_input_fn, steps=valid_freq)

    # Evaluate on the validation set (steps=None evaluates on the full validation set)
    estimator.evaluate(input_fn=valid_input_fn)

without restarting the iteration over the training samples from scratch at each call to estimator.train(input_fn=train_input_fn, steps=valid_freq)?

For example, unlike what is done here, instantiate the Dataset and its iterator outside input_fn. I tried this, but it does not work, because the input (from the dataset iterator) and the model (from the estimator's model_fn) then do not belong to the same graph.

Thanks

About this issue

  • State: closed
  • Created 7 years ago
  • Reactions: 6
  • Comments: 22 (20 by maintainers)

Most upvoted comments

We have the same issue: we have a large dataset where one training epoch takes several hours, and we want to run evaluations in the middle of an epoch. Unfortunately, local sessions do not share resources, so an iterator cannot be shared between sessions.

So these are the possibilities I see:

  • As far as I understand, dataset.skip() always skips a fixed number of elements from the start, so it is not usable for this case. Also, the global step is not available during graph creation in the Estimator; since it is stateful, it cannot be used as the skip count.
  • Allow local sessions to share resources, and keep the dataset iterator alive.
  • Prevent the graph and session from being destroyed, thus keeping the iterator alive. This might be problematic for memory usage.
  • Provide a way to use the main graph for both evaluation and training (i.e. evaluation only uses a subgraph of the training graph; there might be some tf.cond ops involved, with a placeholder to switch the mode). The problem here is that the saved graph will contain a lot of unnecessary ops. (<- this is still my favorite, because it seems feasible)
  • Allow skipping entries on the first training step after the session is created (e.g. iterator.skip_initial_steps), thus resuming training at about the i-th step (the shuffle buffer will not be exactly the same, but at least close). This should be optimized so that the dataset does not actually read the skipped data, but only skips over it.
  • Add a simple option to evaluate after exactly one epoch, thus ensuring each sample was read. Unfortunately, this does not allow evaluations in between, so it is more of a workaround.
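The skip-on-restart idea and the steps-per-epoch arithmetic from the list above can be sketched in plain Python (a simulation, not tf.data code). Assumptions, all hypothetical: `resumed_batches` plays the role of an input_fn body; the current global step is available as a Python int when the pipeline is rebuilt (the comment notes it is not available during the Estimator's graph creation, so it would have to come from outside, e.g. the latest checkpoint); and each epoch is shuffled with a seed derived from the epoch index, so a rebuilt pipeline reproduces the same order and the resume point is exact rather than approximate.

```python
import math
import random
from itertools import islice


def steps_per_epoch(num_samples, batch_size):
    """Train steps needed to see every sample once (the last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)


def resumed_batches(samples, batch_size, global_step, seed=0):
    """Hypothetical input_fn body: rebuild the pipeline and resume at `global_step`.

    Each epoch is shuffled with a seed derived from its index, so a rebuilt
    pipeline reproduces the same order and the resume point is exact.
    """
    spe = steps_per_epoch(len(samples), batch_size)
    epoch, step_in_epoch = divmod(global_step, spe)
    while True:  # repeat indefinitely, reshuffling every epoch
        order = list(samples)
        random.Random(seed + epoch).shuffle(order)
        # Start slicing at the first unread batch instead of reading and
        # discarding the batches already trained on in this epoch.
        for start in range(step_in_epoch * batch_size, len(order), batch_size):
            yield order[start:start + batch_size]
        epoch, step_in_epoch = epoch + 1, 0


samples = list(range(10))  # 10 samples, batch size 4 -> 3 steps per epoch

# One uninterrupted run of 7 steps ...
full_run = list(islice(resumed_batches(samples, 4, global_step=0), 7))
# ... versus 4 steps, then a rebuilt pipeline that resumes at global step 4.
part_a = list(islice(resumed_batches(samples, 4, global_step=0), 4))
part_b = list(islice(resumed_batches(samples, 4, global_step=4), 3))

print(steps_per_epoch(len(samples), 4))  # 3
print(part_a + part_b == full_run)       # True
```

Resuming at global step 4 lands in the middle of the second epoch, and the concatenated batches match an uninterrupted run. With a non-deterministic shuffle buffer the resumption would only be approximate, as the comment points out.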

Are there any other suggestions?