triton: [BUG] DeepSpeed Inference with HF diffusers breaks on T4

I am trying to use DeepSpeed Inference with Diffusers on a T4 GPU, but it fails with a Triton error.

I also reported the bug on the DeepSpeed repository for better tracking: https://github.com/microsoft/DeepSpeed/issues/2702

import os, torch, diffusers, deepspeed

hf_auth_key = os.getenv("HF_AUTH_KEY")
if not hf_auth_key:
    raise ValueError("HF_AUTH_KEY is not set")

pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=hf_auth_key,
    torch_dtype=torch.float16,
    revision="fp16")

model = deepspeed.init_inference(pipe.to("cuda"), dtype=torch.float16)
model("hello from here")

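Here is the resulting traceback. It seems related to Triton's kernel caching: the first exception is a KeyError raised inside the generated _fwd_kernel launcher, and the call that fails while that KeyError is being handled ends with RuntimeError: Triton Error [CUDA]: invalid argument.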

Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-42648570729a4835b21c1c18cebedbfe-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, 'fp32', torch.float32, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 64, 128), (True, True, True, (False,), True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False, False), (False, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    model("hello from here")
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/inference/engine.py", line 524, in forward
    outputs = self.module(*inputs, **kwargs)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 387, in __call__
    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/model_implementations/diffusers/unet.py", line 41, in forward
    return self._forward(*inputs, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/model_implementations/diffusers/unet.py", line 63, in _forward
    return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 307, in forward
    sample, res_samples = downsample_block(
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 598, in forward
    hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/models/attention.py", line 202, in forward
    hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_transformer_block.py", line 99, in forward
    out_attn_1 = self.attn_1(out_norm_1)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_attention.py", line 225, in forward
    output = DeepSpeedDiffusersAttentionFunction.apply(
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_attention.py", line 115, in forward
    output = selfAttention_fp(input, context, input_mask)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_attention.py", line 79, in selfAttention_fp
    context_layer = triton_flash_attn_kernel(qkv_out[0],
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/triton_ops.py", line 119, in forward
    _fwd_kernel[grid](
  File "/content/venv/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "<string>", line 43, in _fwd_kernel
RuntimeError: Triton Error [CUDA]: invalid argument
Environment:

NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.5

Installed Python packages:
aiobotocore==2.4.2
aiohttp==3.8.3
aioitertools==0.11.0
aiosignal==1.3.1
anyio==3.6.2
arrow==1.2.3
async-timeout==4.0.2
attrs==22.2.0
beautifulsoup4==4.11.1
black==22.12.0
bleach==5.0.1
blessed==1.19.1
botocore==1.27.59
certifi==2019.11.28
cffi==1.15.1
chardet==3.0.4
charset-normalizer==2.1.1
click==8.1.3
cmake==3.25.0
commonmark==0.9.1
croniter==1.3.8
cryptography==39.0.0
dbus-python==1.2.16
deepdiff==6.2.3
deepspeed==0.7.7
diffusers==0.7.1
diffusion-with-autoscaler @ file:///content
distlib==0.3.6
dnspython==2.2.1
docker==6.0.1
docutils==0.19
email-validator==1.3.0
exceptiongroup==1.1.0
fastapi==0.89.1
filelock==3.9.0
frozenlist==1.3.3
fsspec==2022.11.0
h11==0.14.0
hjson==3.1.0
httpcore==0.16.3
httptools==0.5.0
httpx==0.23.3
huggingface-hub==0.11.1
idna==2.8
importlib-metadata==6.0.0
importlib-resources==5.10.2
iniconfig==2.0.0
inquirer==3.1.2
isort==5.11.4
itsdangerous==2.1.2
jaraco.classes==3.2.3
jeepney==0.8.0
Jinja2==3.1.2
jmespath==1.0.1
keyring==23.13.1
lightning @ https://github.com/Lightning-AI/lightning/archive/refs/tags/1.8.6.zip
lightning-api-access @ git+https://github.com/Lightning-AI/LAI-API-Access-UI-Component.git@ec3016c1bd2165f9e720b686a83376def1705a60
lightning-cloud==0.5.16
lightning-launcher==0.0.43
lightning-utilities==0.5.0
MarkupSafe==2.1.1
more-itertools==9.0.0
multidict==6.0.4
mypy-extensions==0.4.3
ninja==1.11.1
numpy==1.24.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
ordered-set==4.1.0
orjson==3.8.5
packaging==23.0
pathspec==0.10.3
Pillow==9.4.0
pkginfo==1.9.6
platformdirs==2.6.2
pluggy==1.0.0
protobuf==3.20.1
psutil==5.9.4
py-cpuinfo==9.0.0
pycparser==2.21
pydantic==1.10.4
Pygments==2.14.0
PyGObject==3.36.0
PyJWT==2.6.0
pytest==7.2.0
python-apt==2.0.0+ubuntu0.20.4.8
python-dateutil==2.8.2
python-dotenv==0.21.0
python-editor==1.0.4
python-multipart==0.0.5
PyYAML==6.0
readchar==4.0.3
readme-renderer==37.3
redis==4.4.2
regex==2022.10.31
requests==2.28.1
requests-toolbelt==0.10.1
requests-unixsocket==0.2.0
rfc3986==1.5.0
rich==13.0.1
s3fs==2022.11.0
SecretStorage==3.3.3
six==1.14.0
sniffio==1.3.0
soupsieve==2.3.2.post1
starlette==0.22.0
starsessions==1.3.0
tabulate==0.9.0
tensorboardX==2.5.1
tokenizers==0.13.2
tomli==2.0.1
torch==1.13.1
torchmetrics==0.11.0
tqdm==4.64.1
traitlets==5.8.1
transformers==4.24.0
triton==2.0.0.dev20221202
twine==4.0.2
typing_extensions==4.4.0
ujson==5.7.0
urllib3==1.26.14
uvicorn==0.20.0
uvloop==0.17.0
virtualenv==20.17.1
watchfiles==0.18.1
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.2
websockets==10.4
wrapt==1.14.1
yarl==1.8.2
zipp==3.11.0
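Most of the packages above are unrelated to the failure. A short script like the following could collect only the details that seem to matter here: GPU model, compute capability and the versions of the libraries in the call stack. This is a minimal sketch. The package selection in RELEVANT_PACKAGES and the helper names are illustrative assumptions. It uses only importlib.metadata and, if installed, torch.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional

# Illustrative selection: the libraries that appear in the traceback above.
RELEVANT_PACKAGES = ("torch", "triton", "deepspeed", "diffusers", "transformers")


def package_versions(names: Iterable[str] = RELEVANT_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


def gpu_summary() -> Optional[str]:
    """Describe the current CUDA device via torch, or return None if unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    major, minor = torch.cuda.get_device_capability()
    return f"{torch.cuda.get_device_name()} (compute capability {major}.{minor}, CUDA {torch.version.cuda})"


if __name__ == "__main__":
    print("python", platform.python_version())
    for name, version in package_versions().items():
        print(name, version or "not installed")
    print("gpu", gpu_summary() or "unavailable")
```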

About this issue

  • State: open
  • Created a year ago
  • Reactions: 1
  • Comments: 16

Most upvoted comments

If I remember correctly, the forward pass used to work on pre-Ampere hardware, but the backward pass only worked on post-Ampere hardware. It may be the case that now neither works on Turing 😄 I’m still working on A100 performance optimizations, but I agree that the forward pass should work well on all hardware. I don’t think there’s any major roadblock against this. I’ll look into it.
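The T4 is a Turing GPU with compute capability 7.5, while Ampere GPUs such as the A100 report 8.0 or higher. Until the forward pass works on Turing again, one possible workaround is to apply DeepSpeed inference only when the device is Ampere or newer. The sketch below assumes that policy; the cutoff and the helper names use_deepspeed_kernels and prepare_pipeline are hypothetical, not part of DeepSpeed. The init_inference call mirrors the one in the reproduction script above.

```python
from typing import Tuple

# Hypothetical policy: only trust the Triton attention path on Ampere (sm_80) or newer.
MIN_CAPABILITY: Tuple[int, int] = (8, 0)


def use_deepspeed_kernels(capability: Tuple[int, int]) -> bool:
    """Return True if the (major, minor) compute capability meets the cutoff."""
    return tuple(capability) >= MIN_CAPABILITY


def prepare_pipeline(pipe):
    """Move the pipeline to the GPU and wrap it with DeepSpeed only on Ampere+."""
    import torch  # imported lazily so use_deepspeed_kernels stays usable without a GPU
    import deepspeed

    pipe = pipe.to("cuda")
    if use_deepspeed_kernels(torch.cuda.get_device_capability()):
        return deepspeed.init_inference(pipe, dtype=torch.float16)
    # Turing and older (e.g. the T4, capability 7.5): fall back to plain diffusers.
    return pipe
```

With this guard, a T4 (7.5) would run the unmodified pipeline, while an A100 (8.0) would still use DeepSpeed's kernels. The trade-off is losing the DeepSpeed speedup on older GPUs in exchange for avoiding the crash.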