tensorflow: Problem with read and get batch from 2d array tfrecords dataset
URL(s) with the issue:
https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_using_tfdata
Description of issue (what needs changing):
Problem with read and get batch from 2d array tfrecords dataset
Clear description
Hello. I am using TensorFlow 2.0, and I have a problem reading a TFRecord file back in batches.
First, here is the script that writes the TFRecord file:
import tensorflow as tf
import os
from glob import glob
import numpy as np

def serialize_example(batch, list1, list2):
    # batch is the number of (input, target) file pairs to write
    filename = "./train_set.tfrecords"
    # The context manager closes (and flushes) the writer when the loop ends
    with tf.io.TFRecordWriter(filename) as writer:
        for i in range(batch):
            feature = {}
            feature1 = np.load(list1[i])
            feature2 = np.load(list2[i])
            print('feature1 shape {} feature2 shape {}'.format(feature1.shape, feature2.shape))
            # The 2D arrays are flattened; the (x, y) shape is not stored in the record
            feature['input'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature1.flatten()))
            feature['target'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature2.flatten()))
            features = tf.train.Features(feature=feature)
            example = tf.train.Example(features=features)
            serialized = example.SerializeToString()
            writer.write(serialized)
            print("{}th input {} target {} finished".format(i, list1[i], list2[i]))

list_inp = sorted(glob('./input/2d_magnitude/*'))
list_tar = sorted(glob('./target/2d_magnitude/*'))
print(len(list_inp))
serialize_example(len(list_inp), list_inp, list_tar)
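My inputs and targets are 2D arrays (the dataset consists of spectrograms). Therefore, my TFRecord file holds two features, input and target, and the data looks like [number_of_examples, x, y]. Each array is flattened before it is written, so the (x, y) shape itself is not stored in the file. About 100,000 examples were successfully saved to the TFRecord file.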
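To make the flattening concrete, here is a minimal sketch that round-trips a single 2D array through tf.train.Example in memory. The 4 x 3 array is a hypothetical stand-in for one spectrogram; the point is that the FloatList holds only the flattened values, so the original shape must be known and re-applied after reading.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for one spectrogram: a 4 x 3 float32 array.
spec = np.arange(12, dtype=np.float32).reshape(4, 3)

# Write side: same pattern as serialize_example, the array is flattened.
feature = {
    'input': tf.train.Feature(float_list=tf.train.FloatList(value=spec.flatten().tolist())),
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
serialized = example.SerializeToString()

# Read side: parse ONE serialized record (a single bytes object).
parsed = tf.train.Example.FromString(serialized)
values = parsed.features.feature['input'].float_list.value
print(len(values))  # 12: only the flattened values are stored

# The (4, 3) shape is not in the record, so it has to be restored by hand.
restored = np.array(list(values), dtype=np.float32).reshape(spec.shape)
print(np.array_equal(restored, spec))  # True
```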
I run into a problem when I read the TFRecord file back to get batches. This is my reading script, read_tfrecords.py:
import tensorflow as tf
import os
import numpy as np
shuffle_buffer_size = 50000
batch_size = 10
record_file = '/data2/dataset/tfrecords/train_set.tfrecords'
raw_dataset = tf.data.TFRecordDataset(record_file)
print('raw_dataset', raw_dataset) # ==> raw_dataset <TFRecordDatasetV2 shapes: (), types: tf.string>
raw_dataset = raw_dataset.repeat()
print('repeat', raw_dataset) # ==> repeat <RepeatDataset shapes: (), types: tf.string>
raw_dataset = raw_dataset.shuffle(shuffle_buffer_size)
print('shuffle', raw_dataset) # ==> shuffle <ShuffleDataset shapes: (), types: tf.string>
raw_dataset = raw_dataset.batch(batch_size, drop_remainder=True)
print('batch', raw_dataset) # ==> batch <BatchDataset shapes: (10,), types: tf.string>
raw_example = next(iter(raw_dataset))
parsed = tf.train.Example.FromString(raw_example.numpy()) # ==> read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted
print('parsed', parsed) # ==> ''
input = parsed.features.feature['input'].float_list.value
print('input', input) # ==> []
target = parsed.features.feature['target'].float_list.value
print('target', target) # ==> []
Here is the output of the code:
raw_dataset <TFRecordDatasetV2 shapes: (), types: tf.string>
repeat <RepeatDataset shapes: (), types: tf.string>
shuffle <ShuffleDataset shapes: (), types: tf.string>
batch <BatchDataset shapes: (10,), types: tf.string>
read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted
parsed = tf.train.Example.FromString(raw_example.numpy())
parsed
input []
target []
So my question is: how do I get batches from the TFRecord file for training, and what causes "read_tfrecords.py:25: RuntimeWarning: Unexpected end-group tag: Not all data was converted"? Could you give me some advice? Thank you very much.
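A minimal sketch of a batching pipeline, assuming every input and target has the same fixed size X x Y (the 4 x 3 below is a placeholder, and the small demo file in a temporary directory stands in for train_set.tfrecords). The main change from the reading script is to parse each serialized record with tf.io.parse_single_example inside Dataset.map before calling batch(), then reshape the flat FloatList back to 2D:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumed spectrogram size; replace with the real (x, y) of your arrays.
X, Y = 4, 3
batch_size = 10

def make_example(inp, tar):
    # Same layout as the writer script: flattened float lists.
    feature = {
        'input': tf.train.Feature(float_list=tf.train.FloatList(value=inp.flatten().tolist())),
        'target': tf.train.Feature(float_list=tf.train.FloatList(value=tar.flatten().tolist())),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

# Write a small demo file with 25 random examples.
record_file = os.path.join(tempfile.mkdtemp(), 'demo.tfrecords')
with tf.io.TFRecordWriter(record_file) as writer:
    for _ in range(25):
        writer.write(make_example(np.random.rand(X, Y), np.random.rand(X, Y)))

# Every record holds exactly X * Y floats per feature.
feature_description = {
    'input': tf.io.FixedLenFeature([X * Y], tf.float32),
    'target': tf.io.FixedLenFeature([X * Y], tf.float32),
}

def parse_fn(serialized):
    # Parse ONE serialized record, then restore the 2D shape.
    parsed = tf.io.parse_single_example(serialized, feature_description)
    return (tf.reshape(parsed['input'], [X, Y]),
            tf.reshape(parsed['target'], [X, Y]))

dataset = (tf.data.TFRecordDataset(record_file)
           .repeat()
           .shuffle(50)
           .map(parse_fn)  # parse individual records before batching
           .batch(batch_size, drop_remainder=True))

inputs, targets = next(iter(dataset))
print(inputs.shape, targets.shape)  # (10, 4, 3) (10, 4, 3)
```

Placing map() after shuffle() keeps the shuffle buffer (50,000 elements in the original script) filled with compact serialized strings rather than decoded float tensors. An alternative is to batch first and then map tf.io.parse_example over each batch of strings, reshaping to [-1, X, Y], which parses the whole batch in one call.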
Usage example
Maybe…
raw_dataset = tf.data.TFRecordDataset(record_file)
raw_dataset = raw_dataset.repeat()
raw_dataset = raw_dataset.shuffle(shuffle_buffer_size)
raw_dataset = raw_dataset.batch(batch_size, drop_remainder=True)
raw_example = next(iter(raw_dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())
input = parsed.features.feature['input'].float_list.value
target = parsed.features.feature['target'].float_list.value
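The snippet above passes the whole batch to tf.train.Example.FromString. After batch(batch_size), each dataset element is a tf.string tensor of shape (10,), so raw_example.numpy() is a NumPy array of 10 serialized records rather than one bytes string; that mismatch is a likely cause of the "Unexpected end-group tag" warning. For quick inspection, a minimal sketch (small in-memory records stand in for the real file) parses each record separately:

```python
import numpy as np
import tensorflow as tf

# Build a few in-memory serialized Examples (stand-ins for the TFRecord contents).
records = []
for i in range(10):
    values = np.full(6, float(i)).tolist()  # hypothetical flattened 2 x 3 array
    feature = {'input': tf.train.Feature(float_list=tf.train.FloatList(value=values))}
    records.append(tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString())

batch = next(iter(tf.data.Dataset.from_tensor_slices(records).batch(10)))
print(batch.shape)  # (10,): ten serialized strings, not one

# batch.numpy() is an array of 10 bytes objects; parse each one separately.
parsed = [tf.train.Example.FromString(s) for s in batch.numpy()]
first_input = parsed[0].features.feature['input'].float_list.value
print(len(parsed), len(first_input))  # 10 6
```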
About this issue
- State: closed
- Created 4 years ago
- Comments: 20 (10 by maintainers)
It looks like you are using FixedLenFeature to parse your features, but the features are sequences, not scalars, so you need to use FixedLenSequenceFeature instead.
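A minimal sketch of the FixedLenSequenceFeature approach, assuming a placeholder 4 x 3 spectrogram. When parsing tf.train.Example records (rather than SequenceExample), FixedLenSequenceFeature needs allow_missing=True; the parsed tensor is the flat list, which is then reshaped to 2D:

```python
import numpy as np
import tensorflow as tf

X, Y = 4, 3  # assumed spectrogram size

spec = np.random.rand(X, Y).astype(np.float32)
example = tf.train.Example(features=tf.train.Features(feature={
    'input': tf.train.Feature(float_list=tf.train.FloatList(value=spec.flatten().tolist())),
}))
serialized = example.SerializeToString()

# Variable-length list of float scalars; allow_missing=True is required here.
seq_description = {
    'input': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
}
parsed = tf.io.parse_single_example(serialized, seq_description)
print(parsed['input'].shape)  # (12,)

restored = tf.reshape(parsed['input'], [X, Y])
print(np.allclose(restored.numpy(), spec))  # True
```

Because every record in this dataset has the same length, tf.io.FixedLenFeature([X * Y], tf.float32) should also parse it; FixedLenSequenceFeature is the option that additionally tolerates lists of varying length.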