tensorflow-directml-plugin: `Transformers`: Error while fitting a TFBertForSequenceClassification model

Hi,

First of all, thanks a lot for making DirectML compatible with TensorFlow 2.x! Since the Transformers library needs TensorFlow >= 2.3, I hoped my TFBertForSequenceClassification model would work with my AMD GPU (RX 6800).

Unfortunately, this is not the case. (Important detail: training works on the CPU with the standard TensorFlow library, although it takes a very long time.)

My environment:
  • OS: Windows 11 Pro 64-bit, 21H2
  • Python 3.8.13
  • tensorflow-cpu 2.9.1
  • tensorflow-directml-plugin 0.0.1.dev220621
  • CPU: Ryzen 5600X
  • GPU: RX 6800

The GPU is correctly recognized:

tf.config.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

My code:

import tensorflow as tf
from transformers import TFBertForSequenceClassification

bert_model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# model_save_path and log_dir are assumed to be defined elsewhere in the notebook
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path,
                                                save_weights_only=True,
                                                monitor='val_loss',
                                                mode='min',
                                                save_best_only=True),
             tf.keras.callbacks.TensorBoard(log_dir=log_dir)]

# summary() prints the table itself and returns None, so call it on its own line
print('\nBert Model')
bert_model.summary()

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5, epsilon=1e-08)

bert_model.compile(loss=loss, optimizer=optimizer, metrics=metric)

# Output:
Model: "tf_bert_for_sequence_classification_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 bert (TFBertMainLayer)      multiple                  109482240 
                                                                 
 dropout_151 (Dropout)       multiple                  0         
                                                                 
 classifier (Dense)          multiple                  1538      
                                                                 
=================================================================
Total params: 109,483,778
Trainable params: 109,483,778
Non-trainable params: 0
_________________________________________________________________
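For reference, `SparseCategoricalCrossentropy(from_logits=True)` takes integer class labels and raw classifier scores, applies a softmax internally, and by default averages the negative log-probability of the true class over the batch. The pure-Python sketch below reproduces that computation so a loss value can be sanity-checked independently of any device; the helper name is hypothetical and this is not the Keras implementation.

```python
import math


def sparse_ce_from_logits(labels, logits):
    """Mean sparse categorical cross-entropy computed from raw logits.

    labels: list of int class indices, one per example.
    logits: list of per-example lists of unnormalized scores.
    """
    total = 0.0
    for label, row in zip(labels, logits):
        # log-sum-exp with the row maximum subtracted for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        # -log(softmax(row)[label]) == logsumexp(row) - row[label]
        total += log_sum_exp - row[label]
    return total / len(labels)


# Two examples, two classes (num_labels=2 as in the issue)
loss = sparse_ce_from_logits([0, 1], [[2.0, 0.0], [0.0, 0.0]])
print(round(loss, 4))
```

Subtracting the row maximum before exponentiating keeps `math.exp` from overflowing on large logits without changing the result.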

The model builds and compiles fine, but when it reaches the fit call:

history=bert_model.fit([X_train, Mask_Train], 
                       y_train,
                       batch_size=32,
                       epochs=EPOCHS,
                       validation_data=([X_test, Mask_test], y_test),
                       callbacks=callbacks)
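The fit call passes two arrays, `X_train` and `Mask_Train`, presumably padded token IDs and their matching attention masks. Assuming that, the sketch below shows how such pairs are typically built from already-tokenized sequences. The helper name and the example IDs are hypothetical, the pad ID of 0 is an assumption about the vocabulary, and in practice the Hugging Face tokenizer's `padding`/`truncation` options produce the same `input_ids`/`attention_mask` pair.

```python
PAD_ID = 0  # assumption: ID of the padding token in the vocabulary


def pad_and_mask(sequences, max_len, pad_id=PAD_ID):
    """Pad or truncate token-ID lists to max_len and build attention masks."""
    ids, masks = [], []
    for seq in sequences:
        seq = seq[:max_len]                          # truncate long sequences
        n_pad = max_len - len(seq)
        ids.append(seq + [pad_id] * n_pad)           # right-pad with the pad ID
        masks.append([1] * len(seq) + [0] * n_pad)   # 1 = real token, 0 = padding
    return ids, masks


# Illustrative token IDs only
X, M = pad_and_mask([[101, 2000, 102], [101, 102]], max_len=4)
print(X)  # [[101, 2000, 102, 0], [101, 102, 0, 0]]
print(M)  # [[1, 1, 1, 0], [1, 1, 0, 0]]
```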

The GPU VRAM starts filling up (I’m monitoring it with the Radeon Adrenalin software), and then the error message below suddenly appears!

The error message:

Epoch 1/3
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
File <timed exec>:2, in <module>

File ~\anaconda3\envs\P7_Bert_TF29_PYT38\lib\site-packages\keras\utils\traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File ~\anaconda3\envs\P7_Bert_TF29_PYT38\lib\site-packages\tensorflow\python\eager\execute.py:54, in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     52 try:
     53   ctx.ensure_initialized()
---> 54   tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     55                                       inputs, attrs, num_outputs)
     56 except core._NotOkStatusException as e:
     57   if name is not None:

InvalidArgumentError: Cannot assign a device for operation tf_bert_for_sequence_classification_3/bert/embeddings/Gather: 
Could not satisfy explicit device specification '' 
because the node {{colocation_node tf_bert_for_sequence_classification_3/bert/embeddings/Gather}} was colocated 
with a group of nodes that required incompatible device '/job:localhost/replica:0/task:0/device:GPU:0'. 
All available devices [/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:GPU:0]. 
Colocation Debug Info:
Colocation group had the following types and supported devices: 
Root Member(assigned_device_name_index_=2 requested_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' 
assigned_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' 
resource_device_name_='/job:localhost/replica:0/task:0/device:GPU:0' supported_device_types_=[CPU] possible_devices_=[]

StridedSlice: CPU 
Unique: GPU CPU 
Shape: GPU CPU 
_Arg: GPU CPU 
ResourceGather: GPU CPU 
Const: GPU CPU 
UnsortedSegmentSum: CPU 
Mul: GPU CPU 
ReadVariableOp: GPU CPU 
AssignVariableOp: GPU CPU 
ResourceScatterAdd: GPU CPU 
Sqrt: GPU CPU 
AddV2: GPU CPU 
RealDiv: GPU CPU 
AssignSubVariableOp: GPU CPU 
NoOp: GPU CPU 

Colocation members, user-requested devices, and framework assigned devices, if any:
  tf_bert_for_sequence_classification_3_bert_embeddings_gather_resource (_Arg)  
       framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  adam_adam_update_readvariableop_resource (_Arg)  
       framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  adam_adam_update_readvariableop_2_resource (_Arg)  
       framework assigned device=/job:localhost/replica:0/task:0/device:GPU:0
  tf_bert_for_sequence_classification_3/bert/embeddings/Gather (ResourceGather) 
  Adam/Adam/update/Unique (Unique) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/Shape (Shape) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/strided_slice/stack (Const) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/strided_slice/stack_1 (Const) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/strided_slice/stack_2 (Const) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/strided_slice (StridedSlice) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/UnsortedSegmentSum (UnsortedSegmentSum) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/mul (Mul) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/ReadVariableOp (ReadVariableOp) 
  Adam/Adam/update/mul_1 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/AssignVariableOp (AssignVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/ResourceScatterAdd (ResourceScatterAdd) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/ReadVariableOp_1 (ReadVariableOp) 
  Adam/Adam/update/mul_2 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/mul_3 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/ReadVariableOp_2 (ReadVariableOp) 
  Adam/Adam/update/mul_4 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/AssignVariableOp_1 (AssignVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/ResourceScatterAdd_1 (ResourceScatterAdd) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/ReadVariableOp_3 (ReadVariableOp) 
  Adam/Adam/update/Sqrt (Sqrt) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/mul_5 (Mul) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/add (AddV2) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/truediv (RealDiv) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/AssignSubVariableOp (AssignSubVariableOp) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/group_deps/NoOp (NoOp) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/group_deps/NoOp_1 (NoOp) /job:localhost/replica:0/task:0/device:GPU:0
  Adam/Adam/update/group_deps (NoOp) /job:localhost/replica:0/task:0/device:GPU:0

      [[{{node tf_bert_for_sequence_classification_3/bert/embeddings/Gather}}]] 
      [Op:__inference_train_function_57566]

Thanks in advance for any help! Have a good day.
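In the colocation debug info of the error above, every op in the group must be placed on the device holding the embedding variable (GPU:0), yet some op types list only CPU support. A small parser makes that explicit. It is a hypothetical helper that expects only the `Type: devices` lines copied from the log, not the full traceback.

```python
from typing import List

# The "types and supported devices" listing copied from the error message
COLOCATION_INFO = """\
StridedSlice: CPU
Unique: GPU CPU
Shape: GPU CPU
_Arg: GPU CPU
ResourceGather: GPU CPU
Const: GPU CPU
UnsortedSegmentSum: CPU
Mul: GPU CPU
ReadVariableOp: GPU CPU
AssignVariableOp: GPU CPU
ResourceScatterAdd: GPU CPU
Sqrt: GPU CPU
AddV2: GPU CPU
RealDiv: GPU CPU
AssignSubVariableOp: GPU CPU
NoOp: GPU CPU
"""


def ops_missing_device(debug_info: str, device: str = "GPU") -> List[str]:
    """Return op types whose supported-device list does not include `device`."""
    missing = []
    for line in debug_info.splitlines():
        if ":" not in line:
            continue  # skip blank lines
        op_type, devices = line.split(":", 1)
        device_list = devices.split()
        if not device_list:
            continue  # skip lines with no device list
        if device not in device_list:
            missing.append(op_type.strip())
    return missing


print(ops_missing_device(COLOCATION_INFO))  # ['StridedSlice', 'UnsortedSegmentSum']
```

Any op type in this list forces the whole colocation group off the GPU, which conflicts with the variable already living on GPU:0.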

About this issue

  • State: closed
  • Created 2 years ago
  • Comments: 25

Most upvoted comments

Hi,

Well, I’ve just tested with tensorflow-cpu 2.10, and even before reaching the .fit(…) call I get a Jupyter kernel crash (see below):

[Screenshot of the Jupyter kernel crash]

The crash seems to occur when executing this line:

tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
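To tell a plugin crash apart from a metric problem, it can help to know what the metric should return. `SparseCategoricalAccuracy` counts a sample as correct when the argmax of the predictions equals the integer label. The pure-Python reference below is a hypothetical helper, not the Keras implementation, and could be compared against results obtained on CPU.

```python
def sparse_categorical_accuracy(labels, predictions):
    """Fraction of examples whose highest-scoring class equals the integer label."""
    hits = 0
    for label, row in zip(labels, predictions):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax of the row
        hits += int(predicted == label)
    return hits / len(labels)


# Logits work as well as probabilities, since softmax does not change the argmax
acc = sparse_categorical_accuracy([0, 1, 1], [[2.0, -1.0], [0.3, 0.1], [-0.5, 1.5]])
print(acc)
```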

We have the first nightly build that should contain the fixes: https://github.com/microsoft/tensorflow-directml-plugin/actions/runs/2887210312

Please try it out and let me know how it works for you! Note that we haven’t focused on the performance of Transformer models so far, so we don’t know how well we perform on them, but this is something we’ll be focusing on in the next few months. Our main concern right now is making sure that they run to completion without crashing 😃

Hey @ImRobot777, I can’t be 100% sure without a full repro, but from the log above it looks like all that’s missing is int32 data type support for StridedSlice and UnsortedSegmentSum. We already have implementations for those ops, so adding int32 support should be all that’s needed. We are planning to release a new version of the plugin shortly after the upcoming TensorFlow 2.10 release, so you shouldn’t have to wait too long.

If I provide you with a development wheel once the change has been merged, would you be willing to try it out and let me know whether it fixes your issue? Or, if you’re comfortable building from source, that should also work.
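For context, the colocation group in the original error contains `Unique`, `Shape`, `StridedSlice` and `UnsortedSegmentSum` next to the embedding `ResourceGather` and Adam's `ResourceScatterAdd`. This is consistent with how the Keras optimizer handles the sparse gradient of an embedding lookup: duplicate row indices are merged by finding the unique indices and summing the gradient rows per unique index, with `shape(unique_indices)[0]` (a strided slice) giving the number of segments. In TensorFlow these index and shape tensors are typically int32, the data type mentioned above. Assuming that reading, the pure-Python sketch below mirrors the logic; all names in it are hypothetical.

```python
from typing import Dict, List, Tuple


def deduplicate_sparse_grad(
    values: List[List[float]], indices: List[int]
) -> Tuple[List[List[float]], List[int]]:
    """Merge gradient rows that target the same embedding row.

    Unique indices keep first-seen order; each output row is the sum of
    the input rows mapped to that index.
    """
    position: Dict[int, int] = {}
    unique_indices: List[int] = []
    segment_ids: List[int] = []
    for idx in indices:  # like Unique: map each index to a segment id
        if idx not in position:
            position[idx] = len(unique_indices)
            unique_indices.append(idx)
        segment_ids.append(position[idx])
    num_segments = len(unique_indices)  # like Shape(...)[0] via StridedSlice
    width = len(values[0]) if values else 0
    summed = [[0.0] * width for _ in range(num_segments)]
    for seg, row in zip(segment_ids, values):  # like UnsortedSegmentSum
        for j, v in enumerate(row):
            summed[seg][j] += v
    return summed, unique_indices


# Token 7 appears twice in the batch, so its two gradient rows are added together
grads, rows = deduplicate_sparse_grad([[1.0, 2.0], [0.5, 0.5], [3.0, 1.0]], [7, 3, 7])
print(rows)   # [7, 3]
print(grads)  # [[4.0, 3.0], [0.5, 0.5]]
```

Merging duplicates first means the subsequent scatter-add touches each embedding row exactly once per step.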