DirectML: RuntimeError: new(): expected key in DispatchKeySet(CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, Lazy, Meta) but got: PrivateUse1

🐛 Describe the bug

Microsoft's DirectML custom backend for PyTorch GPU acceleration in WSL raises an error in the Hugging Face Transformers .generate method.

torch-directml reference: https://learn.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows

Hugging Face Transformers reference: https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/generation/utils.py#L2424

import torch
import torch_directml
dml = torch_directml.device()

tensor1 = torch.tensor([1]).to(dml)
tensor2 = torch.tensor([2]).to(dml)

tensor1 = tensor1.new(tensor1.shape[0]).fill_(0)
tensor2 = tensor2.new(tensor2.shape[0]).fill_(0)

print("sum:", (tensor1 + tensor2).item())
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[24], line 8
      5 tensor1 = torch.tensor([1]).to(dml)
      6 tensor2 = torch.tensor([2]).to(dml)
----> 8 tensor1 = tensor1.new(tensor1.shape[0]).fill_(0)
      9 tensor2 = tensor2.new(tensor2.shape[0]).fill_(0)
     11 print("sum:", (tensor1 + tensor2).item())

RuntimeError: new(): expected key in DispatchKeySet(CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, Lazy, Meta) but got: PrivateUse1
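
A possible workaround for anyone who cannot rebuild PyTorch is to avoid the legacy Tensor.new(size) constructor and build the filled tensor with an explicit dtype and device. The sketch below is not verified on DirectML. The helper name filled_like is made up for illustration, and it assumes the backend implements the factory op behind torch.full.

```python
import torch


def filled_like(t: torch.Tensor, value: int = 0) -> torch.Tensor:
    """Rough equivalent of t.new(t.shape[0]).fill_(value).

    Hypothetical helper: it skips the legacy constructor, which is the call
    that rejects the PrivateUse1 dispatch key in the traceback above.
    """
    return torch.full((t.shape[0],), value, dtype=t.dtype, device=t.device)


# CPU is only a stand-in so the sketch runs anywhere; with torch-directml
# installed, use torch_directml.device() here instead (assumption).
device = torch.device("cpu")

tensor1 = filled_like(torch.tensor([1]).to(device))
tensor2 = filled_like(torch.tensor([2]).to(device))

total = (tensor1 + tensor2).item()
print("sum:", total)
```

This only helps in code you control. The failing call inside transformers' generate would still need a patched PyTorch or a patched copy of that file.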

Versions

Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.8.16 (default, Jan 17 2023, 23:13:24)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   48 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          32
On-line CPU(s) list:             0-31
Vendor ID:                       AuthenticAMD
Model name:                      AMD Ryzen 9 5950X 16-Core Processor
CPU family:                      25
Model:                           33
Thread(s) per core:              2
Core(s) per socket:              16
Socket(s):                       1
Stepping:                        0
BogoMIPS:                        6800.05
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr arat npt nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload umip vaes vpclmulqdq rdpid fsrm
Virtualization:                  AMD-V
Hypervisor vendor:               Microsoft
Virtualization type:             full
L1d cache:                       512 KiB (16 instances)
L1i cache:                       512 KiB (16 instances)
L2 cache:                        8 MiB (16 instances)
L3 cache:                        32 MiB (1 instance)
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==1.13.1+rocm5.2
[pip3] torch-directml==0.1.13.1.dev230119
[pip3] torchaudio==0.13.1+rocm5.2
[pip3] torchdata==0.5.1
[pip3] torchtext==0.14.1
[pip3] torchvision==0.14.1+rocm5.2
[conda] blas                      1.0                         mkl  
[conda] mkl                       2022.1.0           hc2b9512_224  
[conda] numpy                     1.22.4                   pypi_0    pypi
[conda] pytorch                   1.13.1              py3.8_cpu_0    pytorch
[conda] pytorch-mutex             1.0                         cpu    pytorch
[conda] torch                     1.13.1+rocm5.2           pypi_0    pypi
[conda] torch-directml            0.1.13.1.dev230119          pypi_0    pypi
[conda] torchaudio                0.13.1+rocm5.2           pypi_0    pypi
[conda] torchdata                 0.5.1                    pypi_0    pypi
[conda] torchtext                 0.14.1                   pypi_0    pypi
[conda] torchvision               0.14.1+rocm5.2           pypi_0    pypi

About this issue

Most upvoted comments

This should get you pretty close (Ubuntu 22.04) to a patched PyTorch build (CPU only). The wheels will end up in the artifacts directory:

sudo apt-get install -y g++-10
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 30
sudo update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 30
sudo update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-10 30

git clone --recursive --branch v1.13.1 https://github.com/pytorch/pytorch
git clone --recursive --branch release/1.13 https://github.com/pytorch/builder

cd pytorch
git cherry-pick --no-commit cbec22fe6e1b5d8716a2daf056c773252b220dea
sudo git submodule sync
sudo git submodule update --init --recursive
cd ..
mkdir -p artifacts

sudo su -p
conda create --name pytorch -y python=3.10
conda activate pytorch
conda install cmake ninja -y 

export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}
export PYTORCH_BUILD_VERSION=1.13.1+cpu
export PYTORCH_BUILD_NUMBER=1
export _GLIBCXX_USE_CXX11_ABI=0
export CXXFLAGS=-D_GLIBCXX_USE_CXX11_ABI=0
export DESIRED_PYTHON=3.10
export BUILD_SPLIT_CUDA=1
export PYTORCH_ROOT=./pytorch
export PYTORCH_FINAL_PACKAGE_DIR=./artifacts
export GPU_ARCH_TYPE=cpu
export USE_NUMA=OFF
export USE_MPI=OFF
export USE_CUDA=OFF
export USE_MKLDNN=OFF

bash ./builder/common/install_mkl.sh
bash ./builder/common/install_patchelf.sh
bash ./builder/manywheel/build.sh > build.log  2>&1
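
Once the build finishes, something like the following could be used to locate the wheel and print an install command. This is a minimal sketch: find_wheels is a hypothetical helper, and it assumes the build drops .whl files somewhere under ./artifacts, as the PYTORCH_FINAL_PACKAGE_DIR setting above suggests.

```python
from pathlib import Path


def find_wheels(artifacts_dir="./artifacts"):
    """Return .whl files under artifacts_dir, newest first (hypothetical helper)."""
    root = Path(artifacts_dir)
    if not root.is_dir():
        return []
    # Search recursively in case the build script nests its output.
    return sorted(root.rglob("*.whl"), key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    wheels = find_wheels()
    if not wheels:
        print("no wheels found; check build.log")
    for wheel in wheels:
        # Printed rather than executed so the command can be reviewed first.
        print(f"pip install --force-reinstall {wheel}")
```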

Will this be fixed in a 1.13.2 release? Probably it will take some time to get a compiled version with the fix included.

git cherry-pick --no-commit cbec22fe6e1b5d8716a2daf056c773252b220dea

It's just this line that addresses the issue (since the PyTorch team agreed to fix it 🙏). It applies this commit (https://github.com/pytorch/pytorch/commit/cbec22fe6e1b5d8716a2daf056c773252b220dea) onto the local v1.13.1 source code.

If you already have a working Windows build based on the v1.13.1 branch source, I would think this should be the only change you need.
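
To check whether a rebuilt wheel actually contains the fix, a small smoke test along these lines could be run after installing it. It is only a sketch: legacy_new_supported is a hypothetical name, and the CPU device is a stand-in so the snippet runs without torch-directml.

```python
import torch


def legacy_new_supported(device) -> bool:
    """Return True if the legacy Tensor.new(size) call works on `device`."""
    probe = torch.tensor([1]).to(device)
    try:
        # Same call pattern as the repro in the bug report.
        probe.new(probe.shape[0]).fill_(0)
    except RuntimeError as err:
        # An unpatched 1.13.1 build raises here for the PrivateUse1 key.
        print("legacy new() failed:", err)
        return False
    return True


# With torch-directml installed, pass torch_directml.device() instead (assumption).
print(legacy_new_supported(torch.device("cpu")))
```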

I did see that PyTorch 2.0 was recently released, and you are correct that the change hasn't made it in there yet (maybe next release?): https://github.com/pytorch/pytorch/compare/cbec22f..v2.0.0#diff-dafad7c78f9f0da499157bce059930b4945084a73ed7bbf52719c714d29cc0ab

If it helps @Looong01, a locally compiled wheel (Ubuntu 22.04 dependent) is here: https://file.io/jZYHjII2ocMM (one-time download)

Thank you very much! It works.

I am facing the same problem. Would you please tell me how to solve it? Thanks in advance.