tensorflow: Setting class_weight in model.fit() with tf.data.Dataset causes error
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 / Windows 10
- TensorFlow installed from (source or binary): binary
- TensorFlow version (use command below): 2.3.0 / 2.4.1
- Python version: 3.6 / 3.7
- CUDA/cuDNN version: 10.1 / none
- GPU model and memory: RTX2080 / none
Describe the current behavior
When a tf.data.Dataset is used in model.fit(), setting class_weight causes an error.
Describe the expected behavior
No error occurs.
Standalone code to reproduce the issue
from tensorflow import keras
import tensorflow as tf
import numpy as np


def get_model():
    inputs = keras.layers.Input(shape=(10, 10, 3))
    x = keras.layers.Flatten()(inputs)
    outputs = keras.layers.Dense(5)(x)
    model = keras.Model(inputs, outputs)
    model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    return model


def map_fun(_):
    dummy_image = np.zeros((10, 10, 3))
    dummy_label = np.array([0, 0, 1, 0, 0])
    return dummy_image, dummy_label


if __name__ == '__main__':
    # dummy dataset
    dataset = tf.data.Dataset.from_tensor_slices([1, 2])  # values are ignored, dummy data generated in map()
    dataset = dataset.map(map_func=lambda x: tf.py_function(map_fun, [x], [tf.uint8, tf.uint8])).batch(2)

    # dummy model
    model = get_model()

    # call fit() without class weights - ok
    model.fit(dataset, epochs=1)

    # define class weights
    class_weight = {idx: weight for (idx, weight) in enumerate([1., 1., 1., 1., 1.])}

    # transform dataset to iterator, call fit() with class weights - ok
    model.fit(dataset.as_numpy_iterator(), class_weight=class_weight, epochs=1)

    # call fit() with class weights on tf.data.Dataset - error
    model.fit(dataset, class_weight=class_weight, epochs=1)
Error message
Traceback (most recent call last):
File "/data/sandbox/reproduce.py", line 39, in <module>
model.fit(dataset, class_weight=class_weight, epochs=1)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
return method(self, *args, **kwargs)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1063, in fit
steps_per_execution=self._steps_per_execution)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1122, in __init__
dataset = dataset.map(_make_class_weight_map_fn(class_weight))
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1695, in map
return MapDataset(self, map_func, preserve_cardinality=True)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4045, in __init__
use_legacy_function=use_legacy_function)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3371, in __init__
self._function = wrapper_fn.get_concrete_function()
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2939, in get_concrete_function
*args, **kwargs)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2906, in _get_concrete_function_garbage_collected
graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
capture_by_value=self._capture_by_value),
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3364, in wrapper_fn
ret = _wrapper_helper(*args)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3299, in _wrapper_helper
ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 255, in wrapper
return converted_call(f, args, kwargs, options=options)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 532, in converted_call
return _call_unconverted(f, args, kwargs, options)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 339, in _call_unconverted
return f(*args, **kwargs)
File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1314, in _class_weights_map_fn
if y.shape.rank > 2:
TypeError: '>' not supported between instances of 'NoneType' and 'int'
Process finished with exit code 1
About this issue
- Original URL
- State: closed
- Created 3 years ago
- Comments: 19 (6 by maintainers)
Ok, I will move to keras repo. Just that I am not passing y separately - it is part of the BatchDataset - as you say.
The workaround for a somewhat related problem also works in this case: an additional call to map() to manually set the tensor shape makes it work.
I previously tried manually converting the outputs of the py_function to tensors and also manually setting their shapes, but it did not work, so the key here is the second call to map() to set the shapes after batch().
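A minimal sketch of that workaround, assuming the same dummy data as in the reproduction script. The helper name set_batch_shapes and the constants IMAGE_SHAPE and NUM_CLASSES are made up for this illustration, and tf.ensure_shape is just one way to attach the static shape (calling set_shape() on each tensor inside the mapped function should be equivalent here). The dummy arrays are created as uint8 so that they match the dtypes declared for tf.py_function.

```python
import numpy as np
import tensorflow as tf

IMAGE_SHAPE = (10, 10, 3)  # hypothetical constant, matches the dummy image
NUM_CLASSES = 5            # hypothetical constant, matches the one-hot label


def map_fun(_):
    # Same dummy data as in the reproduction script, with explicit dtypes.
    dummy_image = np.zeros(IMAGE_SHAPE, dtype=np.uint8)
    dummy_label = np.array([0, 0, 1, 0, 0], dtype=np.uint8)
    return dummy_image, dummy_label


def set_batch_shapes(images, labels):
    # Tensors coming out of tf.py_function have an unknown static shape,
    # so labels.shape.rank is None. Declare the batched shapes here;
    # the batch dimension is left as None.
    images = tf.ensure_shape(images, (None,) + IMAGE_SHAPE)
    labels = tf.ensure_shape(labels, (None, NUM_CLASSES))
    return images, labels


dataset = tf.data.Dataset.from_tensor_slices([1, 2])
dataset = dataset.map(
    lambda x: tuple(tf.py_function(map_fun, [x], [tf.uint8, tf.uint8])))
dataset = dataset.batch(2)
# Second call to map(), after batch(), to set the shapes.
dataset = dataset.map(set_batch_shapes)

# The labels now report a static rank instead of None.
print(dataset.element_spec)
```

With the shapes set this way, the label rank is no longer None, so the `y.shape.rank > 2` comparison from the traceback has an integer to compare against when model.fit(dataset, class_weight=class_weight, epochs=1) is called.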
@tensortorch, please take a look at this comment from a similar issue and check if it helps. Thanks!