tensorflow: Masked LSTM and GRU using cuDNN 8.1+ give randomly corrupted results and crashes, even during inference.


Issue Type

Bug

Have you reproduced the bug with TF nightly?

Yes

Source

binary

Tensorflow Version

TF 2.12, TF 2.11, TF nightly, TF 2.8

Custom Code

Yes

OS Platform and Distribution

Linux Ubuntu 20.04, Colab

Mobile device

No response

Python version

3.9, 3.10

Bazel version

No response

GCC/Compiler version

No response

CUDA/cuDNN version

cuDNN 8.1 or newer

GPU model and memory

I have observed the issue on GeForce GTX 1080 Ti, GeForce RTX 3090, A40, and T4 (on Colab).

Current Behaviour?

When an LSTM or GRU with a mask uses the cuDNN 8.1+ implementation, it randomly gives corrupted results (even during inference) and sometimes even crashes with the error CUDA_ERROR_ILLEGAL_ADDRESS.

It took me quite some time to be able to reproduce the problem, but I have found a way (it manifests reliably on Colab).

Important comments:

  • Of course, the problem manifests only when a GPU is available (the CPU implementation is fine).

  • The bug manifests only with cuDNN 8.1+, because older cuDNN versions use different RNN methods. I even recompiled TF 2.11 and TF 2.8 with CUDA 11.1 and cuDNN 8.0, and they work fine. (A quick way to check which CUDA/cuDNN a given TF binary was built against is shown right after this list.)

  • The bug manifests only when a mask is passed to the GRU and LSTM (the masked RNN calls also use a specific code path).

  • The problem is not data-dependent: the same batch sometimes does and sometimes does not trigger the bug (I use the same batch in the example below).

  • Sometimes the RNN call ends with a CUDA_ERROR_ILLEGAL_ADDRESS and crashes the program.

  • Different GPU models differ in how frequently the corruption happens: cards with CC 6.1 seem to trigger it more often, while cards with CC 8.6 seem to trigger it less, but they still do.

  • When the dimensionalities of the cells are larger, the problem manifests more often.

  • Note that before the recent Colab update, Colab used cuDNN 8.0.5/8.0.6 for a very long time (I assume specific Colab packages were being built), so the bug did not manifest there; now it does.
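
For reference (see the cuDNN bullet above), the CUDA/cuDNN versions a given TF binary was built against can be checked directly from Python. This is only a quick sanity check, since the libraries actually loaded at runtime (e.g. on Colab) can still differ:

import tensorflow as tf

# Build-time CUDA/cuDNN versions of the installed TF binary. The "cuda_version"
# and "cudnn_version" keys are present in GPU builds; .get() keeps this harmless
# on CPU-only builds.
build_info = tf.sysconfig.get_build_info()
print("TF version:       ", tf.__version__)
print("built with CUDA:  ", build_info.get("cuda_version"))
print("built with cuDNN: ", build_info.get("cudnn_version"))
print("GPUs visible:     ", tf.config.list_physical_devices("GPU"))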

My initial guess is that the different RNN calls somehow share memory and sometimes overwrite it (and maybe the memory is sometimes freed in one place while still being accessed from another, causing the crash).

Standalone code to reproduce the issue

The Colab notebook showing the problem with TF 2.12.0: https://colab.research.google.com/drive/17a4AcbGf9CyCl4de_vPEbB3QlTwxlV1b?usp=sharing

The Colab notebook showing the problem with TF nightly: https://colab.research.google.com/drive/1ONQ7EBF9iLkSmmJE3yb04nSNbhW9fnlV?usp=sharing

The code triggering the problem:

import numpy as np
import tensorflow as tf
print(tf.__version__)

def create_model(use_mask: bool) -> tf.keras.Model:
    inputs = tf.keras.layers.Input(shape=[None], dtype=tf.int32)
    if use_mask:
        mask = inputs >= 0
    else:
        mask = None
    h = tf.keras.layers.Embedding(64, 2048)(tf.math.maximum(inputs, 0))
    h1 = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(2048, return_sequences=True), merge_mode="sum")(h, mask=mask)
    h2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(2048, return_sequences=True), merge_mode="sum")(h, mask=mask)
    h = tf.keras.layers.Dense(10)(h1 + h2)
    return tf.keras.Model(inputs, h)

# Data
data = tf.data.Dataset.from_tensor_slices([[j if j <= i else -1 for j in range(64)] for i in range(64)]).batch(64)

# However, when masking is used, even prediction on GPU gives different results.
# It also sometimes crashes with the error `CUDA_ERROR_ILLEGAL_ADDRESS`.
# The full error log is copied below.
# If `use_mask=False` is passed, no problem happens.

# Models
tf.keras.utils.set_random_seed(42)
model = create_model(use_mask=True)

# Run prediction
gold = None
for i in range(100):
    result = model.predict(data, verbose=0)
    if gold is None:
        gold = result
    print("Batch {}, max difference {}, mean difference {}".format(i, np.max(np.abs(gold - result)), np.mean(np.abs(gold - result))))

Relevant log output

Batch 0, max difference 0.0, mean difference 0.0
Batch 1, max difference 0.0, mean difference 0.0
Batch 2, max difference 0.0, mean difference 0.0
Batch 3, max difference 0.0, mean difference 0.0
Batch 4, max difference 0.0, mean difference 0.0
Batch 5, max difference 0.0, mean difference 0.0
Batch 6, max difference 1.6253925561904907, mean difference 0.032524894922971725
Batch 7, max difference 1.5961412191390991, mean difference 0.034154243767261505
Batch 8, max difference 0.0, mean difference 0.0
Batch 9, max difference 0.13439376652240753, mean difference 0.009293164126574993
Batch 10, max difference 0.0, mean difference 0.0
Batch 11, max difference 0.0, mean difference 0.0
Batch 12, max difference 0.0, mean difference 0.0
Batch 13, max difference 0.0, mean difference 0.0
Batch 14, max difference 0.0, mean difference 0.0
Batch 15, max difference 0.0, mean difference 0.0
Batch 16, max difference 0.0, mean difference 0.0
Batch 17, max difference 0.0, mean difference 0.0
Batch 18, max difference 0.0, mean difference 0.0
Batch 19, max difference 0.0, mean difference 0.0
Batch 20, max difference 0.0900668054819107, mean difference 0.006251291837543249
Batch 21, max difference 1.516040325164795, mean difference 0.03404051065444946
Batch 22, max difference 0.0, mean difference 0.0
Batch 23, max difference 0.09510315954685211, mean difference 0.004909028299152851
Batch 24, max difference 0.0, mean difference 0.0
Batch 25, max difference 0.0, mean difference 0.0
Batch 26, max difference 0.0, mean difference 0.0
Batch 27, max difference 0.10816850513219833, mean difference 0.006311381701380014
Batch 28, max difference 0.06991208344697952, mean difference 0.004763560835272074
Batch 29, max difference 0.0, mean difference 0.0
Batch 30, max difference 0.0, mean difference 0.0
Batch 31, max difference 0.12369412928819656, mean difference 0.0054976968094706535
Batch 32, max difference 1.114426612854004, mean difference 0.01574256643652916
Batch 33, max difference 1.1078299283981323, mean difference 0.01561223715543747
Batch 34, max difference 0.0, mean difference 0.0
Batch 35, max difference 0.07631748914718628, mean difference 0.004826296120882034
Batch 36, max difference 0.10820074379444122, mean difference 0.006503588054329157
Batch 37, max difference 0.09048224240541458, mean difference 0.006238402798771858
Batch 38, max difference 0.0, mean difference 0.0
Batch 39, max difference 0.0, mean difference 0.0
Batch 40, max difference 0.1100485697388649, mean difference 0.014095092192292213
Batch 41, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 42, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 43, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 44, max difference 0.06849975883960724, mean difference 0.004877315368503332
Batch 45, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 46, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 47, max difference 0.09048224240541458, mean difference 0.006395397242158651
Batch 48, max difference 0.12369412928819656, mean difference 0.005698652006685734
Batch 49, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 50, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 51, max difference 1.1203747987747192, mean difference 0.016059130430221558
Batch 52, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 53, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 54, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 55, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 56, max difference 0.07624031603336334, mean difference 0.0050272345542907715
Batch 57, max difference 1.6275964975357056, mean difference 0.032646626234054565
Batch 58, max difference 0.043441060930490494, mean difference 0.0046346113085746765
Batch 59, max difference 0.007873136550188065, mean difference 0.0007940558716654778
Batch 60, max difference 0.08132395148277283, mean difference 0.0063618021085858345

Then the following error appeared and crashed the program:

Mar 31, 2023, 10:30:50 PM	WARNING	2023-03-31 20:30:50.700469: E tensorflow/compiler/xla/stream_executor/cuda/cuda_event.cc:29] Error polling for event status: failed to query event: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
Mar 31, 2023, 10:30:50 PM	WARNING	2023-03-31 20:30:50.700546: F tensorflow/core/common_runtime/device/device_event_mgr.cc:223] Unexpected Event status: 1

About this issue

  • State: closed
  • Created a year ago
  • Reactions: 2
  • Comments: 24 (18 by maintainers)

Most upvoted comments

Yes, the quoted lines are mainly for resolving another, separate issue reported to our cuDNN team, which was only reproducible in multi-GPU settings. I happened to notice it and thought the root cause might be similar to this one. From what we have seen so far, the root cause should be corrupted devSeqLengths on the device side; depending on its content, it might cause wrong results or even a crash.

At this moment, we are not sure whether it is a TF implementation error or not. I would argue that we need to be very careful to use the cuDNN RNN APIs correctly (of course, before 8.9.1). For example, our current guess is that in TF, under some circumstances, when we copy the host data to prepare the device array devSeqLengths, the host memory might already be freed, since it is an async copy. To mitigate this usage pitfall, cuDNN has decided to prepare devSeqLengths on the user's behalf, so that no explicit devSeqLengths is needed anymore. I am not very familiar with how PyTorch calls cuDNN, but if it can guarantee a correct devSeqLengths, it should be fine. Hope this explains the backstory.

Can you try the latest cuDNN, e.g. 8.9.2? Release notes: https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html

Basically, we guess the root cause is that TF uses an async memcpy to prepare devSeqLengths, and by the time the copy actually happens, the host-side memory might already be freed, corrupting the contents of devSeqLengths. So we have a workaround on the cuDNN side that no longer relies on this behavior. Can you give this cuDNN version a shot and see if the issue is gone?
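
For anyone who cannot update cuDNN right away, a possible stop-gap (only a sketch based on the cuDNN-eligibility conditions documented for tf.keras.layers.LSTM/GRU, not something proposed in this thread; the helper names are just for illustration) is to make the layers ineligible for the cuDNN kernel, so that they fall back to the generic implementation at a considerable speed cost:

import tensorflow as tf

# Violating any documented cuDNN-eligibility condition makes Keras fall back to
# the generic RNN kernel. A vanishingly small recurrent_dropout is numerically
# almost a no-op but still disables the cuDNN path; expect a large slowdown on GPU.
def make_gru(units: int) -> tf.keras.layers.GRU:
    return tf.keras.layers.GRU(units, return_sequences=True,
                               recurrent_dropout=1e-12)  # != 0 => no cuDNN kernel

def make_lstm(units: int) -> tf.keras.layers.LSTM:
    return tf.keras.layers.LSTM(units, return_sequences=True,
                                recurrent_dropout=1e-12)  # != 0 => no cuDNN kernel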

Hi, thanks a lot for all your work! It is great that just an update of cuDNN will be able to fix this.

I verified that the example from this issue (forward pass only) no longer produces any differences (and has not crashed so far) when using:

  • TF 2.11, CUDA 11.2, cuDNN 8.9.1 and cuDNN 8.9.2
  • TF 2.12, CUDA 11.8, cuDNN 8.9.1 and cuDNN 8.9.2

I used both an RTX A4000 (CC 8.6) and a Quadro P5000 (CC 6.1) for the tests.

In a few days I will also test our original workload of training an RNN seq2seq model, which consists of a complex computation graph with quite a few RNN layers, and I will report back here. Leaving this open until then.

BTW, the release notes for 8.9.1 say that

Starting in cuDNN 8.9.1, the const int32_t devSeqLengths[] argument in cudnnRNNForward(), cudnnRNNBackwardData_v8(), and cudnnRNNBackwardWeights_v8() APIs will be ignored. All three functions will source variable sequence length arrays from RNN data descriptors, configured through the seqLengthArray parameter of cudnnSetRNNDataDescriptor(). The user does not need to transfer this array to device memory; the operation will be performed automatically by RNN APIs. This refinement simplifies the usage of cuDNN RNN APIs. It is also a workaround for random crashes in multi-GPU RNN training on TensorFlow.

Note that the problem occurred not only in multi-GPU training, but also in single-GPU training and single-GPU inference, and it was not just the crashes, but also “just” silently worse results (in a few tasks we trained models for ~10 hours and performed inference on millions of tokens; we did not experience any crash, just bad results: during inference, some of the batches produced corrupted results, while others worked fine).
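
Because the corruption was silent (no crash, just wrong numbers), a cheap determinism check during long runs could catch it much earlier. The repro above relies on exactly this property (repeated predictions of the same batch must match), so a minimal sketch of such a check could look like this (the callback name and usage are just illustrative, not from this issue):

import numpy as np
import tensorflow as tf

class DeterminismCanary(tf.keras.callbacks.Callback):
    """Predict a fixed batch twice after each epoch and warn if the results differ."""

    def __init__(self, fixed_batch):
        super().__init__()
        self.fixed_batch = fixed_batch

    def on_epoch_end(self, epoch, logs=None):
        first = self.model.predict(self.fixed_batch, verbose=0)
        second = self.model.predict(self.fixed_batch, verbose=0)
        max_diff = float(np.max(np.abs(first - second)))
        if max_diff > 0.0:
            print(f"WARNING: non-deterministic forward pass after epoch {epoch}: "
                  f"max difference {max_diff}")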

🙏 Lastly, out of curiosity: is the TF implementation incorrect and you just decided to fix it on the cuDNN side, or is the problem not as simple as saying “there is a bug on this line in the TF implementation”? Did this also affect, for example, PyTorch, or only TF?

Yes, I also observed the changing results. If I set CUDA_LAUNCH_BLOCKING=1, the results are all the same. I think it is a bug. Let me dig into it to confirm whether this is a cuDNN issue or a TF issue.
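
For reference, that check can be reproduced as follows; CUDA_LAUNCH_BLOCKING has to be set before CUDA is initialized in the process, so it must come before any TF GPU work (or be exported in the shell instead):

# Force synchronous CUDA kernel launches; as noted above, with this setting the
# repeated predictions in the repro were reported to come out identical, which
# points at an asynchrony/race issue rather than a numerical one.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must happen before CUDA initialization

import numpy as np
import tensorflow as tf
# ... then run the same create_model()/predict loop from the repro above.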

@kaixih A gentle reminder after two weeks: it would be great if you could find the time to look into the issue (as the original code author) 🙏

@gbaned Could I humbly ask you to assign some CUDA GPU developer to review this bug? It has been causing silent corruption in masked LSTM and GRU code on GPU since TensorFlow 2.5 (the first TF to use cuDNN 8.1 or newer), and it can be replicated.