tensorflow: Cast int32 to bfloat16 does not run on A100 GPU

Issue Type

Feature Request

Have you reproduced the bug with TF nightly?

Yes

Source

source

Tensorflow Version

TensorFlow version 2.13.0-dev20230215

Custom Code

Yes

OS Platform and Distribution

Ubuntu 20.04

Mobile device

No response

Python version

3.8

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

11.8/8.6

GPU model and memory

single A100 80G

Current Behaviour?

When tf.cast is used to cast a tf.int32 tensor to a tf.bfloat16 tensor, the op does not run on the GPU.

When I convert int32 -> bfloat16 directly, the Cast op runs on the CPU.
When I convert int32 -> float32 -> bfloat16, it runs on the GPU. Is this expected?

Standalone code to reproduce the issue

import tensorflow as tf
from tensorflow.keras import mixed_precision

# Log the device on which every op is placed and executed.
tf.debugging.set_log_device_placement(True)
policy = mixed_precision.Policy('mixed_bfloat16')
print(policy.name)
mixed_precision.set_global_policy(policy)

class toy_layer(tf.keras.layers.Layer):
  def build(self, input_shape):
    self.kernel = self.add_weight('kernel', (input_shape[-1], 10))
  def call(self, inputs):
    # Under mixed_bfloat16 the matmul output dtype is bfloat16.
    out = tf.linalg.matmul(inputs, self.kernel)
    out2 = tf.ones((10, 10), dtype=tf.int32)
    # Uncommenting the next line (int32 -> float32 first) keeps the casts on the GPU.
    #out2 = tf.cast(out2, tf.float32, name="cast_out2_1")
    # Direct int32 -> bfloat16 cast: this is the Cast op that is placed on the CPU.
    out2 = tf.cast(out2, out.dtype, name="cast_out2_2")
    out3 = out * out2
    return out3

layer = toy_layer()
y = layer(tf.ones((10, 10)))
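
The two conversion paths can also be compared outside the Keras layer. The snippet below is a minimal sketch, assuming eager execution; the variable names are illustrative and not part of the original report. It reads the `.device` attribute of each result instead of parsing the placement log. On a machine without a GPU both results report a CPU device, so any difference would only show up on GPU hardware such as the A100 used here.

```python
import tensorflow as tf

# Small int32 tensor, same shape and values as in the reproduction above.
x = tf.ones((10, 10), dtype=tf.int32)

# Path 1: direct int32 -> bfloat16 cast (the case reported to fall back to the CPU).
direct = tf.cast(x, tf.bfloat16)

# Path 2: int32 -> float32 -> bfloat16 (the case reported to stay on the GPU).
two_step = tf.cast(tf.cast(x, tf.float32), tf.bfloat16)

# In eager mode each tensor records the device that holds it.
print("direct  :", direct.dtype.name, direct.device)
print("two-step:", two_step.dtype.name, two_step.device)
```

For the all-ones input used here both paths produce the same values, so the only observable difference is the device.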

Relevant log output

2023-02-17 16:26:34.748024: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-17 16:26:35.370677: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
mixed_bfloat16
2023-02-17 16:26:36.876507: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1635] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78915 MB memory:  -> device: 0, name: NVIDIA A100 80GB PCIe, pci bus id: 0001:00:00.0, compute capability: 8.0
input: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:36.886233: I tensorflow/core/common_runtime/placer.cc:114] input: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
_EagerConst: (_EagerConst): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:36.886254: I tensorflow/core/common_runtime/placer.cc:114] _EagerConst: (_EagerConst): /job:localhost/replica:0/task:0/device:GPU:0
output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:36.886267: I tensorflow/core/common_runtime/placer.cc:114] output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:36.889118: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
input: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.293473: I tensorflow/core/common_runtime/placer.cc:114] input: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
_EagerConst: (_EagerConst): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.293513: I tensorflow/core/common_runtime/placer.cc:114] _EagerConst: (_EagerConst): /job:localhost/replica:0/task:0/device:GPU:0
output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.293521: I tensorflow/core/common_runtime/placer.cc:114] output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.294777: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
dims: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.295227: I tensorflow/core/common_runtime/placer.cc:114] dims: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
value: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.295238: I tensorflow/core/common_runtime/placer.cc:114] value: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
Fill: (Fill): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.295250: I tensorflow/core/common_runtime/placer.cc:114] Fill: (Fill): /job:localhost/replica:0/task:0/device:GPU:0
output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.295257: I tensorflow/core/common_runtime/placer.cc:114] output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.295729: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.297303: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.297667: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.297928: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.298227: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
seed: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.298533: I tensorflow/core/common_runtime/placer.cc:114] seed: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
StatelessRandomGetKeyCounter: (StatelessRandomGetKeyCounter): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.298547: I tensorflow/core/common_runtime/placer.cc:114] StatelessRandomGetKeyCounter: (StatelessRandomGetKeyCounter): /job:localhost/replica:0/task:0/device:GPU:0
key_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.298561: I tensorflow/core/common_runtime/placer.cc:114] key_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
counter_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.298573: I tensorflow/core/common_runtime/placer.cc:114] counter_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.299212: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op StatelessRandomGetKeyCounter in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.300704: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
shape: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.300963: I tensorflow/core/common_runtime/placer.cc:114] shape: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
key: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.300974: I tensorflow/core/common_runtime/placer.cc:114] key: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
counter: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.300979: I tensorflow/core/common_runtime/placer.cc:114] counter: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
alg: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.300993: I tensorflow/core/common_runtime/placer.cc:114] alg: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
StatelessRandomUniformV2: (StatelessRandomUniformV2): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.300999: I tensorflow/core/common_runtime/placer.cc:114] StatelessRandomUniformV2: (StatelessRandomUniformV2): /job:localhost/replica:0/task:0/device:GPU:0
output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.301008: I tensorflow/core/common_runtime/placer.cc:114] output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.301526: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op StatelessRandomUniformV2 in device /job:localhost/replica:0/task:0/device:GPU:0
x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.302304: I tensorflow/core/common_runtime/placer.cc:114] x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.302316: I tensorflow/core/common_runtime/placer.cc:114] y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
Sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.302332: I tensorflow/core/common_runtime/placer.cc:114] Sub: (Sub): /job:localhost/replica:0/task:0/device:GPU:0
z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.302338: I tensorflow/core/common_runtime/placer.cc:114] z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.302706: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Sub in device /job:localhost/replica:0/task:0/device:GPU:0
x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.303361: I tensorflow/core/common_runtime/placer.cc:114] x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.303372: I tensorflow/core/common_runtime/placer.cc:114] y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
Mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.303379: I tensorflow/core/common_runtime/placer.cc:114] Mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.303391: I tensorflow/core/common_runtime/placer.cc:114] z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.303748: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0
x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.304205: I tensorflow/core/common_runtime/placer.cc:114] x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.304216: I tensorflow/core/common_runtime/placer.cc:114] y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
AddV2: (AddV2): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.304223: I tensorflow/core/common_runtime/placer.cc:114] AddV2: (AddV2): /job:localhost/replica:0/task:0/device:GPU:0
z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.304229: I tensorflow/core/common_runtime/placer.cc:114] z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.304663: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op AddV2 in device /job:localhost/replica:0/task:0/device:GPU:0
resource_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.305249: I tensorflow/core/common_runtime/placer.cc:114] resource_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
VarHandleOp: (VarHandleOp): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.305263: I tensorflow/core/common_runtime/placer.cc:114] VarHandleOp: (VarHandleOp): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.305637: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
resource: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.306104: I tensorflow/core/common_runtime/placer.cc:114] resource: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
value: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.306122: I tensorflow/core/common_runtime/placer.cc:114] value: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
AssignVariableOp: (AssignVariableOp): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.306139: I tensorflow/core/common_runtime/placer.cc:114] AssignVariableOp: (AssignVariableOp): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.306554: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.307498: I tensorflow/core/common_runtime/placer.cc:114] x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
Cast: (Cast): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.307523: I tensorflow/core/common_runtime/placer.cc:114] Cast: (Cast): /job:localhost/replica:0/task:0/device:GPU:0
y_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.307532: I tensorflow/core/common_runtime/placer.cc:114] y_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.307943: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Cast in device /job:localhost/replica:0/task:0/device:GPU:0
resource: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.308714: I tensorflow/core/common_runtime/placer.cc:114] resource: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
ReadVariableOp: (ReadVariableOp): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.308734: I tensorflow/core/common_runtime/placer.cc:114] ReadVariableOp: (ReadVariableOp): /job:localhost/replica:0/task:0/device:GPU:0
value_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.308743: I tensorflow/core/common_runtime/placer.cc:114] value_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.309154: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.309366: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Cast in device /job:localhost/replica:0/task:0/device:GPU:0
a: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.309693: I tensorflow/core/common_runtime/placer.cc:114] a: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
b: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.309710: I tensorflow/core/common_runtime/placer.cc:114] b: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.309727: I tensorflow/core/common_runtime/placer.cc:114] MatMul: (MatMul): /job:localhost/replica:0/task:0/device:GPU:0
product_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.309738: I tensorflow/core/common_runtime/placer.cc:114] product_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.310165: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op MatMul in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.897703: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.897841: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
dims: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.898261: I tensorflow/core/common_runtime/placer.cc:114] dims: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
value: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.898272: I tensorflow/core/common_runtime/placer.cc:114] value: (_Arg): /job:localhost/replica:0/task:0/device:CPU:0
Fill: (Fill): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.898280: I tensorflow/core/common_runtime/placer.cc:114] Fill: (Fill): /job:localhost/replica:0/task:0/device:GPU:0
output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.898285: I tensorflow/core/common_runtime/placer.cc:114] output_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.898934: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
x: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.899509: I tensorflow/core/common_runtime/placer.cc:114] x: (_DeviceArg): /job:localhost/replica:0/task:0/device:CPU:0
Cast: (Cast): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.899523: I tensorflow/core/common_runtime/placer.cc:114] Cast: (Cast): /job:localhost/replica:0/task:0/device:CPU:0
y_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.899530: I tensorflow/core/common_runtime/placer.cc:114] y_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.899935: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Cast in device /job:localhost/replica:0/task:0/device:CPU:0
x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.900335: I tensorflow/core/common_runtime/placer.cc:114] x: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.900348: I tensorflow/core/common_runtime/placer.cc:114] y: (_Arg): /job:localhost/replica:0/task:0/device:GPU:0
Mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.900361: I tensorflow/core/common_runtime/placer.cc:114] Mul: (Mul): /job:localhost/replica:0/task:0/device:GPU:0
z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.900369: I tensorflow/core/common_runtime/placer.cc:114] z_RetVal: (_Retval): /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.901006: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0
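
A log of this length is easier to check with a small script than by eye. The following is a rough sketch that uses only the Python standard library: it scans for the `Executing op ... in device ...` lines and counts ops per device type. The helper name `ops_by_device` and the three-line `sample` (copied from the log above) are illustrative; in practice the whole log would be read from a file.

```python
import re
from collections import Counter

# Matches the eager-execution lines emitted by set_log_device_placement(True).
EXEC_RE = re.compile(r"Executing op (\S+) in device (\S+)")

def ops_by_device(log_text):
    """Count executed ops keyed by (op name, device type such as 'CPU' or 'GPU')."""
    counts = Counter()
    for match in EXEC_RE.finditer(log_text):
        op, device = match.groups()
        # '/job:.../device:GPU:0' -> 'GPU'
        device_type = device.rsplit("device:", 1)[-1].split(":")[0]
        counts[(op, device_type)] += 1
    return counts

sample = """\
2023-02-17 16:26:37.309366: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Cast in device /job:localhost/replica:0/task:0/device:GPU:0
2023-02-17 16:26:37.899935: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Cast in device /job:localhost/replica:0/task:0/device:CPU:0
2023-02-17 16:26:37.901006: I tensorflow/core/common_runtime/eager/execute.cc:1514] Executing op Mul in device /job:localhost/replica:0/task:0/device:GPU:0
"""

counts = ops_by_device(sample)
cpu_ops = sorted(op for (op, dev) in counts if dev == "CPU")
print(cpu_ops)
```

On the sample lines the only op reported on the CPU is `Cast`, which matches the behaviour described in the report.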

About this issue

  • State: closed
  • Created a year ago
  • Comments: 18 (13 by maintainers)

Most upvoted comments

Thanks @reedwm! I was able to figure it out. I was originally confused because DEFINE_ALL_FROM and DEFINE_ALL_TO_* are actually the same, and one of them mislabeled the input type… Let me know if you have any suggestions to clean things up in the PR.

@yufang67, thank you for the confirmation. We'll consider this issue a feature request.

Hi, @sachinprasadhs

Could you please look into this issue? Thank you!