tensorflow: Setting class_weight in model.fit() with tf.data.Dataset causes error

System information

  • Have I written custom code (as opposed to using a stock example script provided in TensorFlow): yes
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 / Windows 10
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.3.0 / 2.4.1
  • Python version: 3.6 / 3.7
  • CUDA/cuDNN version: 10.1 / none
  • GPU model and memory: RTX2080 / none

Describe the current behavior

When a tf.data.Dataset is passed to model.fit(), setting class_weight raises a TypeError.

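Describe the expected behavior

model.fit() completes without error, as it does when class_weight is omitted or when the dataset is passed as an iterator.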

Standalone code to reproduce the issue

from tensorflow import keras
import tensorflow as tf
import numpy as np


def get_model():
    inputs = keras.layers.Input(shape=(10, 10, 3))
    x = keras.layers.Flatten()(inputs)
    outputs = keras.layers.Dense(5)(x)
    model = keras.Model(inputs, outputs)
    model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(learning_rate=0.001))
    return model


def map_fun(_):
    dummy_image = np.zeros((10, 10, 3))  
    dummy_label = np.array([0, 0, 1, 0, 0]) 
    return dummy_image, dummy_label


if __name__ == '__main__':
    # dummy dataset
    dataset = tf.data.Dataset.from_tensor_slices([1, 2])  # values are ignored, dummy data generated in map()
    dataset = dataset.map(map_func=lambda x: tf.py_function(map_fun, [x], [tf.uint8, tf.uint8])).batch(2)

    # dummy model
    model = get_model()

    # call fit() without class weights - ok
    model.fit(dataset, epochs=1)

    # define class weights
    class_weight = {idx: weight for (idx, weight) in enumerate([1., 1., 1., 1., 1.])}

    # transform dataset to iterator, call fit() with class weights - ok
    model.fit(dataset.as_numpy_iterator(), class_weight=class_weight, epochs=1)

    # call fit() with class weights on tf.data.Dataset - error
    model.fit(dataset, class_weight=class_weight, epochs=1)

Error message

Traceback (most recent call last):
  File "/data/sandbox/reproduce.py", line 39, in <module>
    model.fit(dataset, class_weight=class_weight, epochs=1)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1063, in fit
    steps_per_execution=self._steps_per_execution)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1122, in __init__
    dataset = dataset.map(_make_class_weight_map_fn(class_weight))
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 1695, in map
    return MapDataset(self, map_func, preserve_cardinality=True)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 4045, in __init__
    use_legacy_function=use_legacy_function)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3371, in __init__
    self._function = wrapper_fn.get_concrete_function()
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2939, in get_concrete_function
    *args, **kwargs)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2906, in _get_concrete_function_garbage_collected
    graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3213, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 3075, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/framework/func_graph.py", line 986, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3364, in wrapper_fn
    ret = _wrapper_helper(*args)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 3299, in _wrapper_helper
    ret = autograph.tf_convert(func, ag_ctx)(*nested_args)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 255, in wrapper
    return converted_call(f, args, kwargs, options=options)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 532, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/autograph/impl/api.py", line 339, in _call_unconverted
    return f(*args, **kwargs)
  File "/data/sandbox/venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1314, in _class_weights_map_fn
    if y.shape.rank > 2:
TypeError: '>' not supported between instances of 'NoneType' and 'int'

Process finished with exit code 1
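
The last frame compares y.shape.rank with 2, so the TypeError suggests that the label tensor has no static rank. That would be consistent with tf.py_function, whose outputs carry no static shape information. The following is only a minimal sketch for checking this on a dataset built like the one in the script above; the explicit uint8 dtype on the dummy arrays is an assumption made to keep the example clean and is not part of the original script.

```python
import numpy as np
import tensorflow as tf


def map_fun(_):
    # Dummy sample; the explicit uint8 dtype is an assumption, not from the report.
    dummy_image = np.zeros((10, 10, 3), dtype=np.uint8)
    dummy_label = np.array([0, 0, 1, 0, 0], dtype=np.uint8)
    return dummy_image, dummy_label


dataset = tf.data.Dataset.from_tensor_slices([1, 2])
dataset = dataset.map(
    lambda x: tf.py_function(map_fun, [x], [tf.uint8, tf.uint8])
).batch(2)

# Static specs as seen by tf.data; inspect the rank of each component.
image_spec, label_spec = dataset.element_spec
print("image rank:", image_spec.shape.rank)
print("label rank:", label_spec.shape.rank)
```

If both ranks print as None, the comparison in _class_weights_map_fn has nothing to compare against, which matches the error above.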

About this issue

  • State: closed
  • Created 3 years ago
  • Comments: 19 (6 by maintainers)

Most upvoted comments

OK, I will move this to the Keras repo. Note, though, that I am not passing y separately; as you say, it is part of the BatchDataset.

The workaround for a somewhat related problem also works in this case: an additional call to map() that manually sets the tensor shapes.

I had previously tried converting the outputs of the py_function to tensors and setting their shapes manually, but that did not work. The key here is the second call to map(), which sets the shapes after batch().
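
A rough sketch of that workaround, under some assumptions: the helper name set_shapes is hypothetical, the shapes are taken from the dummy data in the reproduction script, and the dummy arrays are given an explicit uint8 dtype. The leading None stands for the batch dimension.

```python
import numpy as np
import tensorflow as tf


def map_fun(_):
    # Dummy sample; the explicit uint8 dtype is an assumption for this sketch.
    dummy_image = np.zeros((10, 10, 3), dtype=np.uint8)
    dummy_label = np.array([0, 0, 1, 0, 0], dtype=np.uint8)
    return dummy_image, dummy_label


def set_shapes(image, label):
    # Hypothetical helper: restore the static shapes lost by tf.py_function.
    # The leading None is the batch dimension added by batch().
    image.set_shape([None, 10, 10, 3])
    label.set_shape([None, 5])
    return image, label


dataset = tf.data.Dataset.from_tensor_slices([1, 2])
dataset = dataset.map(
    lambda x: tf.py_function(map_fun, [x], [tf.uint8, tf.uint8])
)
dataset = dataset.batch(2)
dataset = dataset.map(set_shapes)  # second map(), applied after batch()

print(dataset.element_spec)
```

With the static rank of the labels known, the dataset can then be passed to model.fit(dataset, class_weight=class_weight, epochs=1) as in the original script; whether that succeeds may still depend on the TensorFlow/Keras version in use.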

@tensortorch, please take a look at this comment from a similar issue and check whether it helps. Thanks!