io: Training local TF model via ArrowStreamDataset raises "ValueError: Cannot take the length of shape with unknown rank."
Hello,
I am trying to use Arrow files as a data source for training a TF model. I have been following and trying to reproduce the examples in this article, but I am running into trouble.
I include the code I am using at the end of this post. It is the same code as in the article linked above, with the addition of generating a small CSV file as a data source.
I am using tensorflow version 2.1.0 and tensorflow_io version 0.11.0.
When I execute the code below, the Keras model fitting function raises `ValueError: Cannot take the length of shape with unknown rank.`
Can you please advise me on this issue? My end goal is to feed a tensorflow Estimator from an ArrowStreamDataset.
Here is the code:
```python
from functools import partial

import numpy as np
import pandas as pd
import pyarrow.csv
import tensorflow as tf
import tensorflow_io.arrow as arrow_io
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler

# Generate data
X, y = make_classification(n_samples=1_000, n_features=2, n_classes=2,
                           n_redundant=0, n_repeated=0, scale=100)
df = pd.DataFrame(data=X, columns=['x0', 'x1'])
df['label'] = y
df.to_csv(path_or_buf='./testing.csv', index=False)


# Define the model training function
def model_fit(ds):
    """Create and fit a Keras logistic regression model."""
    # Build the Keras model
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, input_shape=(2,), activation='sigmoid'))
    model.compile(optimizer='sgd', loss='mean_squared_error',
                  metrics=['accuracy'])
    # Fit the model on the given dataset
    model.fit(ds, epochs=5, shuffle=False)
    return model


# Function to read a CSV into Arrow data in batches
def read_and_process(filename):
    opts = pyarrow.csv.ReadOptions(block_size=200)
    table = pyarrow.csv.read_csv(filename, read_options=opts)

    # Fit the feature transform on the full table
    df = table.to_pandas()
    scaler = StandardScaler().fit(df[['x0', 'x1']])

    # Iterate over batches in the pyarrow.Table and apply processing
    for batch in table.to_batches():
        df = batch.to_pandas()
        # Process the batch and apply feature transforms
        X_scaled = scaler.transform(df[['x0', 'x1']])
        df_scaled = pd.DataFrame({'label': df['label'],
                                  'x0': X_scaled[:, 0],
                                  'x1': X_scaled[:, 1]})
        batch_scaled = pyarrow.RecordBatch.from_pandas(df_scaled,
                                                       preserve_index=False)
        yield batch_scaled


# Build an ArrowStreamDataset to be used with TF
def make_local_dataset(filename):
    # Read the local file and get a record batch iterator
    batch_iter = read_and_process(filename)

    # Create the Arrow Dataset as a stream from a local iterator of record batches
    ds = arrow_io.ArrowStreamDataset.from_record_batches(
        batch_iter,
        output_types=(tf.float64, tf.float64, tf.float64),
        batch_mode='auto',
        record_batch_iter_factory=partial(read_and_process, filename))

    # Map the dataset to combine feature columns into a single tensor
    ds = ds.map(lambda l, x0, x1: (tf.stack([x0, x1], axis=1), l))
    return ds


# Train a model
ds = make_local_dataset(filename='./testing.csv')
model = model_fit(ds)  # <--- The ValueError is raised here
```
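In case it helps clarify what the generator is doing: the scaling in `read_and_process` amounts to fitting the statistics once on the whole table and then applying the same `(x - mean) / std` transform to every batch. Here is a small library-free sketch of that invariant (all function names below are mine, purely illustrative, not from the article or any library):

```python
import math


def fit_scaler(values):
    """Compute the population mean and standard deviation of a column,
    analogous to StandardScaler.fit on the full table."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, math.sqrt(var)


def transform(values, mean, std):
    """Standardize one batch with statistics fit on the whole column."""
    return [(v - mean) / std for v in values]


# Fitting once on the full column and then transforming batch by batch
# gives the same result as transforming the whole column in one go.
column = [1.0, 2.0, 3.0, 4.0]
mean, std = fit_scaler(column)
batches = [column[:2], column[2:]]
per_batch = [x for b in batches for x in transform(b, mean, std)]
assert per_batch == transform(column, mean, std)
```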
Any help or advice would be highly appreciated.
Cheers.
About this issue
- Original URL
- State: closed
- Created 4 years ago
- Comments: 17 (10 by maintainers)
@JovanVeljanoski and I just found out that the nightly pip install seems to work, would be great to see that released!
@maartenbreddels tensorflow-io's master branch has already been upgraded to arrow 0.16.0. We could release another version of tensorflow-io (with arrow 0.16) almost any time.
Hi BryanCutler,
Thanks for your replies. And nice example @JovanVeljanoski, that made it easy to investigate.
Using the ArrowDataset (without the stream) gives a clearer error message, which I traced to https://github.com/apache/arrow/blob/3bc01ec94eb2e310b28402a35196e1e8c5c9aec8/cpp/src/arrow/ipc/message.cc#L54
By downgrading until it worked, I found that (py)arrow 0.14 works. What are the plans to move to 0.16 (which should be close to 1.0)? In vaex (https://github.com/vaexio/vaex/), where we plan to use this, we require 0.15, and soon 0.16. Awesome package, btw!
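For anyone hitting this in the meantime, one option is to fail fast when an incompatible pyarrow is installed instead of surfacing the opaque ValueError. The sketch below is purely illustrative (the function name and the 0.14 cutoff are my assumptions from the experiment above, not real vaex or tensorflow-io API):

```python
def arrow_version_supported(version, max_supported=(0, 14)):
    """Return True if a pyarrow version string (e.g. '0.14.1') is at or
    below the newest version known to work with this tensorflow-io release."""
    major, minor = (int(p) for p in version.split('.')[:2])
    return (major, minor) <= max_supported


# Usage sketch; in real code you would pass pyarrow.__version__ here.
assert arrow_version_supported('0.14.1')
assert not arrow_version_supported('0.16.0')
```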
Regards,
Maarten