tensorflow: ImageDataGenerator does not work with TPU
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): Yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colab
- Mobile device (e.g. iPhone 8, Pixel 2, Samsung Galaxy) if the issue happens on mobile device:
- TensorFlow installed from (source or binary): binary (pre-installed in Colab)
- TensorFlow version (use command below): 1.15
- Python version: 3.6.7
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version: None
- GPU model and memory: None
Describe the current behavior
Currently it is not possible to use fit_generator; instead we need to use the .fit function.
Using ImageDataGenerator, I tried to fit the model but got the following error:
AssertionError Traceback (most recent call last)
<ipython-input-18-46d4a00b20c5> in <module>()
5 train_gen.flow(X_train, y_train, batch_size=batch_size), # tf.data.Dataset.from_tensor_slices((X_train[:2000], y_train[:2000])).batch(batch_size).repeat(), #
6 steps_per_epoch= len(X_train)//batch_size,
----> 7 epochs=3,
8 )
2 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
725 max_queue_size=max_queue_size,
726 workers=workers,
--> 727 use_multiprocessing=use_multiprocessing)
728
729 def evaluate(self,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_distributed.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
617 validation_split=validation_split,
618 shuffle=shuffle,
--> 619 epochs=epochs)
620 if not dist_utils.is_distributing_by_cloning(model):
621 with model._distribution_strategy.scope():
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in _distribution_standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, validation_split, shuffle, epochs, allow_partial_batch)
2312 x = ds.batch(batch_size, drop_remainder=drop_remainder)
2313 else:
-> 2314 assert isinstance(x, dataset_ops.DatasetV2)
2315 training_utils.validate_dataset_input(x, y, sample_weight,
2316 validation_split)
AssertionError:
Describe the expected behavior
According to the Keras documentation, this should work.
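For reference, a minimal sketch of that documented pattern (assuming model, X_train, and y_train are defined as in the Colab notebook; the augmentation options are chosen for illustration). This is essentially the call that produces the traceback above:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

batch_size = 32
train_gen = ImageDataGenerator(rotation_range=20, horizontal_flip=True)

# Works on CPU/GPU; under a TPU distribution strategy in TF 1.15 it
# raises the AssertionError shown above.
model.fit(
    train_gen.flow(X_train, y_train, batch_size=batch_size),
    steps_per_epoch=len(X_train) // batch_size,
    epochs=3,
)
```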
Code to reproduce the issue
Here is a link to a Colab notebook with a minimal reproduction: https://colab.research.google.com/drive/1bdNfb127n-VI6ab9_sUTOgRzkil88a9n
Other info / logs
Does that mean we cannot use a generator with TensorFlow 1.15 on TPU?
About this issue
- Original URL
- State: open
- Created 5 years ago
- Comments: 52 (24 by maintainers)
The error that you're seeing is expected, but I'll admit we've not done a good job of explaining what's going on, so let me give that a shot. When you use Dataset.from_generator (or pass a generator to Keras, which will call it under the hood), the Dataset embeds the generator in a PyFunc op in its graph, and every time that op is invoked it calls next on the generator and gets the resultant bytes. (Basically treating Python as a black box.)

When everything is running on the same machine this is fine, but the trouble is that the way TPUs work is that there is a separate machine controlling the TPU (called, imaginatively, the TPU host controller ^^), and you run things on the TPU by sending it a TensorFlow graph to execute. So the graph containing that PyFunc gets sent to the TPU, and the TPU can't execute it because there is no Python on the TPU host machine. (And even if there were, it wouldn't be the same interpreter with the same state as your local machine.) So it fails by telling you it can't execute the PyFunc op, just not in a very clear way, unfortunately.
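To make this concrete, here is a minimal sketch (data and shapes are illustrative) contrasting the two kinds of input pipelines: the first embeds a PyFunc and therefore cannot be executed by the TPU host, while the second consists only of TensorFlow ops and can be shipped over as a graph:

```python
import numpy as np
import tensorflow as tf

images = np.random.rand(100, 32, 32, 3).astype("float32")
labels = np.random.randint(0, 10, size=(100,)).astype("int64")

# 1) Generator-backed: each element is produced by calling next() on a
#    Python generator, so the dataset graph contains a PyFunc op that
#    only the local Python interpreter can run.
def gen():
    for img, lbl in zip(images, labels):
        yield img, lbl

ds_pyfunc = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32, tf.int64),
    output_shapes=((32, 32, 3), ()),
)

# 2) Pure-TF ops: the whole pipeline is expressed as TensorFlow ops, so
#    the graph is self-contained and the TPU host can execute it.
ds_tf = tf.data.Dataset.from_tensor_slices((images, labels)).batch(8)
```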
Now I'm guessing the pattern that you expect is that the generator (or generator-like, in the case of a Keras Sequence) gets executed on the local machine, and the data is then streamed to the TPU for ingestion and execution, RPC-style. There is actually a Dataset that supports that functionality, called the StreamingFilesDataset (https://github.com/tensorflow/tensorflow/blob/9b82752179bd4a61a9d33a638b8d9cb06adcf9e0/tensorflow/python/tpu/datasets.py#L50), which despite the name can also stream in-memory data, but it is not a public symbol and is not actively maintained AFAIK. However, if we think this is a workflow we want to support out of the box, maybe it's worth dusting it off and integrating it into Keras. @frankchn @tomerk @alextp WDYT? (And @saeta, who wrote the thing in the first place.)
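In the meantime, the usual workaround is to drop the Python generator entirely and express the augmentation with tf.image ops inside a tf.data pipeline, so the resulting graph contains no PyFunc. A rough sketch, assuming in-memory NumPy arrays as in the snippet above, with augmentations chosen for illustration:

```python
import tensorflow as tf

def augment(image, label):
    # Pure-TF stand-ins for common ImageDataGenerator options.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    return image, label

train_ds = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .shuffle(1024)
    .map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size, drop_remainder=True)  # TPUs need static batch shapes
    .prefetch(tf.data.experimental.AUTOTUNE)
)

model.fit(train_ds, epochs=3)
```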
Hi,
I tried with the new version; the assertion is gone, but a new error appeared with both fit and fit_generator:
I am wondering whether ImageDataGenerator is compatible with TPU at all?
Hi everyone, we have pushed a commit (https://github.com/tensorflow/tensorflow/commit/7bfb8380a7c09603259f49027374a4faf4199ad2) which places all PyFuncOps in the local host's address space (defined by job/replica/task) if ops.executing_eagerly_outside_functions() returns True. This change should fix the bug in this issue.
The patch should be included in TF 2.2.0, which is scheduled to have its branch cut on 2/26.
Hi, you can do the following trick to install nightly 2.x in both the VM and the TPU worker.
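The exact commands aren't spelled out in the comment; for context, a sketch of the approach commonly used at the time (the cloud-tpu-client API below is an assumption on my part, not the commenter's exact code):

```python
# In the Colab VM:
#   !pip install tf-nightly cloud-tpu-client

from cloud_tpu_client import Client

# Ask the attached TPU worker to switch to a matching nightly runtime,
# restarting it only if necessary.
c = Client()  # in Colab, the TPU address is resolved from the environment
c.configure_tpu_version("nightly", restart_type="ifNeeded")
```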
@michaelbanfield is looking into the Cloud TPUs connection issue.
@mgmverburg can I take a look at your code sample?
Sounds good. I can give this a shot.
Pinging @jsimsa for tf.data expertise.
I think cases like this show we do need to handle sources we cannot move in tf.data; it should still be possible to do prefetching/buffering/etc. on the TPU controller to hide most of the latency, assuming the generator produces data quickly enough.
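For the buffering half of that, a deep prefetch buffer in front of the unmovable source already goes a long way. A minimal sketch, reusing the illustrative gen() from earlier:

```python
import tensorflow as tf

# The generator itself must stay on the local machine, but a deep
# prefetch buffer lets the host keep producing batches while the
# accelerator consumes earlier ones, hiding most of the latency.
ds = (
    tf.data.Dataset.from_generator(gen, output_types=(tf.float32, tf.int64))
    .batch(128)
    .prefetch(64)
)
```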