tensorflow: Class Weights InvalidArgumentError

Hi, I am using class weights for my imbalanced dataset, and I am getting this error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-38-797b60fa6b57> in <module>
      6     steps_per_epoch=STEPS_PER_EPOCH,
      7     validation_data=get_validation_dataset(),
----> 8     class_weight=class_weigths
      9 )

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    817         max_queue_size=max_queue_size,
    818         workers=workers,
--> 819         use_multiprocessing=use_multiprocessing)
    820 
    821   def evaluate(self,

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
    340                 mode=ModeKeys.TRAIN,
    341                 training_context=training_context,
--> 342                 total_epochs=epochs)
    343             cbks.make_logs(model, epoch_logs, training_result, ModeKeys.TRAIN)
    344 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in run_one_epoch(model, iterator, execution_function, dataset_size, batch_size, strategy, steps_per_epoch, num_samples, mode, training_context, total_epochs)
    126         step=step, mode=mode, size=current_batch_size) as batch_logs:
    127       try:
--> 128         batch_outs = execution_function(iterator)
    129       except (StopIteration, errors.OutOfRangeError):
    130         # TODO(kaftan): File bug about tf function and errors.OutOfRangeError?

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in execution_function(input_fn)
     96     # `numpy` translates Tensors to values in Eager mode.
     97     return nest.map_structure(_non_none_constant_value,
---> 98                               distributed_function(input_fn))
     99 
    100   return execution_function

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/util/nest.py in map_structure(func, *structure, **kwargs)
    566 
    567   return pack_sequence_as(
--> 568       structure[0], [func(*x) for x in entries],
    569       expand_composites=expand_composites)
    570 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/util/nest.py in <listcomp>(.0)
    566 
    567   return pack_sequence_as(
--> 568       structure[0], [func(*x) for x in entries],
    569       expand_composites=expand_composites)
    570 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in _non_none_constant_value(v)
    128 
    129 def _non_none_constant_value(v):
--> 130   constant_value = tensor_util.constant_value(v)
    131   return constant_value if constant_value is not None else v
    132 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_util.py in constant_value(tensor, partial)
    820   """
    821   if isinstance(tensor, ops.EagerTensor):
--> 822     return tensor.numpy()
    823   if not is_tensor(tensor):
    824     return tensor

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in numpy(self)
    940     """
    941     # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors.
--> 942     maybe_arr = self._numpy()  # pylint: disable=protected-access
    943     return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr
    944 

/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in _numpy(self)
    908       return self._numpy_internal()
    909     except core._NotOkStatusException as e:
--> 910       six.raise_from(core._status_to_exception(e.code, e.message), None)
    911 
    912   @property

/opt/conda/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: {{function_node __inference_distributed_function_431665}} Compilation failure: Detected unsupported operations when trying to compile graph has_valid_nonscalar_shape_true_402714_const_0[] on XLA_TPU_JIT: DenseToDenseSetOperation (No registered 'DenseToDenseSetOperation' OpKernel for XLA_TPU_JIT devices compatible with node {{node has_invalid_dims/DenseToDenseSetOperation}}
	.  Registered:  device='CPU'; T in [DT_INT8]
  device='CPU'; T in [DT_INT16]
  device='CPU'; T in [DT_INT32]
  device='CPU'; T in [DT_INT64]
  device='CPU'; T in [DT_UINT8]
  device='CPU'; T in [DT_UINT16]
  device='CPU'; T in [DT_STRING]
){{node has_invalid_dims/DenseToDenseSetOperation}}
	 [[has_valid_nonscalar_shape]]
	 [[loss/dense_4_loss/weighted_loss/broadcast_weights/assert_broadcastable/is_valid_shape]]
	TPU compilation failed
	 [[tpu_compile_succeeded_assert/_2606253191027783421/_10]]

I am using a Kaggle Notebook with a TPU. The TensorFlow version is 2.1.0 and the Keras version is 2.2.4-tf. Here’s the Kaggle notebook to reproduce the error. Thanks.
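
For context on the setup: one common way to derive class weights for an imbalanced dataset is the "balanced" heuristic (the formula scikit-learn's `compute_class_weight` uses), weight_c = n_samples / (n_classes * count_c). The sketch below is plain Python and assumes integer labels 0..n_classes-1. The helper name `balanced_class_weights` and the toy labels are hypothetical, and the sketch does not by itself address the TPU compilation error above. It returns a dict with one key per class, which matches how `class_weight` is documented (a dict mapping class indices to weights).

```python
from collections import Counter


def balanced_class_weights(labels, num_classes):
    """Hypothetical helper: map each class id to n_samples / (num_classes * count).

    Assumes integer labels in range(num_classes). Every class id gets a key,
    because class_weight is documented as a dict of class index -> weight.
    """
    counts = Counter(labels)
    # Reject labels that could never be looked up in a 0..num_classes-1 table.
    unexpected = sorted(c for c in counts if not 0 <= c < num_classes)
    if unexpected:
        raise ValueError(f"labels outside 0..{num_classes - 1}: {unexpected}")
    n_samples = len(labels)
    weights = {}
    for c in range(num_classes):
        if counts[c] == 0:
            # A class with no samples would cause a division by zero.
            raise ValueError(f"class {c} has no samples")
        weights[c] = n_samples / (num_classes * counts[c])
    return weights


labels = [0] * 90 + [1] * 10  # toy imbalanced labels: 90% class 0
print(balanced_class_weights(labels, num_classes=2))
# {0: 0.5555555555555556, 1: 5.0}
```

The result would then be passed as `model.fit(..., class_weight=...)`; the rarer class gets the larger weight, so its samples contribute more to the loss.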

About this issue

  • State: closed
  • Created 4 years ago
  • Comments: 25

Most upvoted comments

When I change my class_weight from a dict, class_weight = {0:1.0, 1:1.0}, to a list, class_weight = [1.0, 1.0], I get the following error:

  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 815, in fit
    model=self)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1117, in __init__
    dataset = dataset.map(_make_class_weight_map_fn(class_weight))
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/data_adapter.py", line 1229, in _make_class_weight_map_fn
    class_ids = list(sorted(class_weight.keys()))
AttributeError: 'list' object has no attribute 'keys'
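
The AttributeError comes from the line `class_ids = list(sorted(class_weight.keys()))` in the traceback: in this TensorFlow version `fit` calls `.keys()` on `class_weight`, so it has to be a dict. If the weights already exist as a list indexed by class id, a minimal conversion is `dict(enumerate(...))`. The variable names below are illustrative, and `model` is not defined in this snippet.

```python
# Weights indexed by class id: position 0 -> class 0, position 1 -> class 1.
class_weight_list = [1.0, 1.0]

# fit() calls class_weight.keys(), so pass a dict {class_id: weight} instead.
class_weight = dict(enumerate(class_weight_list))
print(class_weight)  # {0: 1.0, 1: 1.0}

# model.fit(..., class_weight=class_weight)
```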

On the other hand, if I use a Dict, I get the following error:

  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 644, in _call
    return self._stateless_fn(*args, **kwds)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2420, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1665, in _filtered_call
    self.captured_inputs)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/Users/.../anaconda3/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  indices[4] = 2 is not in [0, 2)
	 [[{{node GatherV2}}]]
	 [[IteratorGetNext]] [Op:__inference_train_function_10811]
Function call stack:
train_function

I’m using TensorFlow 2.2.0.
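
The message `indices[4] = 2 is not in [0, 2)` comes from a `GatherV2` node: the dict {0: 1.0, 1: 1.0} becomes a two-entry weight table, and each sample's label is used as an index into it. A label equal to 2 (for example, a third class, or labels encoded as 1/2 instead of 0/1) is therefore out of range. The plain-Python sketch below imitates that lookup in simplified form; it is an illustration under that assumption, not TensorFlow's actual implementation, and the function name `lookup_class_weights` is hypothetical. It suggests checking that `class_weight` has one key for every label value, 0 through num_classes - 1.

```python
def lookup_class_weights(class_weight, labels):
    """Simplified imitation of turning class_weight into per-sample weights:
    sort the keys, build a weight table, then index the table by label."""
    class_ids = sorted(class_weight)
    # Keys are expected to be exactly 0, 1, ..., n-1.
    if class_ids != list(range(len(class_ids))):
        raise ValueError(f"keys must be 0..{len(class_ids) - 1}, got {class_ids}")
    table = [class_weight[c] for c in class_ids]
    weights = []
    for i, label in enumerate(labels):
        if not 0 <= label < len(table):
            # Same shape of failure as the GatherV2 error in the traceback.
            raise IndexError(f"indices[{i}] = {label} is not in [0, {len(table)})")
        weights.append(table[label])
    return weights


labels = [0, 1, 0, 1, 2]  # hypothetical batch containing a third class
try:
    lookup_class_weights({0: 1.0, 1: 1.0}, labels)
except IndexError as err:
    print(err)  # indices[4] = 2 is not in [0, 2)

# Giving every label value a key avoids the error.
print(lookup_class_weights({0: 1.0, 1: 1.0, 2: 1.0}, labels))
# [1.0, 1.0, 1.0, 1.0, 1.0]
```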

Same issue as @kevinashaw with TF 2.4.1. Does anyone have a solution?

Same issue with TF 2.3

Same issue as @kevinashaw

Also struggling with this issue with TF 2.3

I also have the same problem (with the same results as @kevinashaw) for TF 2.3. Is it possible to reopen this ticket?

Same issue with TF 2.4.1 😕

Refer to this link; it might be helpful. Update your class_weigths to a list instead of a dictionary, e.g. class_weigths = [0.176754, 9.823246]. The model trains with this change.