DirectML: RuntimeError: new(): expected key in DispatchKeySet(CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, Lazy, Meta) but got: PrivateUse1
🐛 Describe the bug
The Microsoft DirectML custom backend for PyTorch GPU acceleration in WSL hits this error inside the Hugging Face transformers .generate() method.
torch_directml reference: https://learn.microsoft.com/en-us/windows/ai/directml/gpu-pytorch-windows
Hugging Face transformers reference: https://github.com/huggingface/transformers/blob/v4.26.1/src/transformers/generation/utils.py#L2424
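For context, the kind of Hugging Face call that reaches that line in generation/utils.py looks roughly like the sketch below; the model name (gpt2) and prompt are illustrative only and not taken from the report.

import torch_directml
from transformers import AutoModelForCausalLM, AutoTokenizer

dml = torch_directml.device()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(dml)
inputs = tokenizer("Hello", return_tensors="pt").to(dml)
# generate() reaches input_ids.new(...) during greedy search, which raises the
# PrivateUse1 dispatch error on the DirectML backend.
output = model.generate(**inputs, max_new_tokens=5)

The minimal standalone repro below distills this down to the failing Tensor.new() call.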
import torch
import torch_directml
dml = torch_directml.device()
tensor1 = torch.tensor([1]).to(dml)
tensor2 = torch.tensor([2]).to(dml)
tensor1 = tensor1.new(tensor1.shape[0]).fill_(0)
tensor2 = tensor2.new(tensor2.shape[0]).fill_(0)
print("sum:", (tensor1 + tensor2).item())
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[24], line 8
5 tensor1 = torch.tensor([1]).to(dml)
6 tensor2 = torch.tensor([2]).to(dml)
----> 8 tensor1 = tensor1.new(tensor1.shape[0]).fill_(0)
9 tensor2 = tensor2.new(tensor2.shape[0]).fill_(0)
11 print("sum:", (tensor1 + tensor2).item())
RuntimeError: new(): expected key in DispatchKeySet(CPU, CUDA, HIP, XLA, MPS, IPU, XPU, HPU, Lazy, Meta) but got: PrivateUse1
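A possible user-side workaround, pending a fixed build, is to avoid the legacy Tensor.new() constructor and allocate through the regular factory functions instead. This is a minimal sketch, assuming torch.zeros behaves on the PrivateUse1/DirectML device the way it does on the built-in backends:

import torch
import torch_directml

dml = torch_directml.device()

tensor1 = torch.tensor([1]).to(dml)
tensor2 = torch.tensor([2]).to(dml)

# torch.zeros with an explicit dtype/device (or Tensor.new_zeros) allocates the
# same zero-filled tensors without going through the legacy Tensor.new() path
# that rejects the PrivateUse1 dispatch key.
tensor1 = torch.zeros(tensor1.shape[0], dtype=tensor1.dtype, device=tensor1.device)
tensor2 = torch.zeros(tensor2.shape[0], dtype=tensor2.dtype, device=tensor2.device)

print("sum:", (tensor1 + tensor2).item())

This does not help with the call inside transformers' generate(), which still goes through Tensor.new(), so the real fix is the PyTorch patch discussed in the comments below.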
Versions
Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.2 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.8.16 (default, Jan 17 2023, 23:13:24) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.90.1-microsoft-standard-WSL2-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 9 5950X 16-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 16
Socket(s): 1
Stepping: 0
BogoMIPS: 6800.05
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl tsc_reliable nonstop_tsc cpuid extd_apicid pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr arat npt nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload umip vaes vpclmulqdq rdpid fsrm
Virtualization: AMD-V
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 8 MiB (16 instances)
L3 cache: 32 MiB (1 instance)
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.22.4
[pip3] torch==1.13.1+rocm5.2
[pip3] torch-directml==0.1.13.1.dev230119
[pip3] torchaudio==0.13.1+rocm5.2
[pip3] torchdata==0.5.1
[pip3] torchtext==0.14.1
[pip3] torchvision==0.14.1+rocm5.2
[conda] blas 1.0 mkl
[conda] mkl 2022.1.0 hc2b9512_224
[conda] numpy 1.22.4 pypi_0 pypi
[conda] pytorch 1.13.1 py3.8_cpu_0 pytorch
[conda] pytorch-mutex 1.0 cpu pytorch
[conda] torch 1.13.1+rocm5.2 pypi_0 pypi
[conda] torch-directml 0.1.13.1.dev230119 pypi_0 pypi
[conda] torchaudio 0.13.1+rocm5.2 pypi_0 pypi
[conda] torchdata 0.5.1 pypi_0 pypi
[conda] torchtext 0.14.1 pypi_0 pypi
[conda] torchvision 0.14.1+rocm5.2 pypi_0 pypi
About this issue
- State: open
- Created a year ago
- Comments: 27
This should get you pretty close (Ubuntu 22.04) for the patched PyTorch build (CPU only). Wheels will be in the artifacts directory:
Will this be fixed in a 1.13.2 release? It will probably take some time to get a compiled version with the fix included.
It's just this line that addresses the issue (since the PyTorch team agreed to fix it 🙏), which amounts to applying this commit (https://github.com/pytorch/pytorch/commit/cbec22fe6e1b5d8716a2daf056c773252b220dea) onto the local v1.13.1 source code.
If you already have a working Windows build based on the v1.13.1 branch source, I would think this should be the only change you need.
I did see that PyTorch 2.0 was recently released, and you are correct that the change hasn't made it in there yet (maybe in the next release?): https://github.com/pytorch/pytorch/compare/cbec22f..v2.0.0#diff-dafad7c78f9f0da499157bce059930b4945084a73ed7bbf52719c714d29cc0ab
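After rebuilding and installing the patched wheel, a quick sanity check is to re-run the failing call from the repro above (a minimal sketch; it assumes torch_directml is installed alongside the patched build):

import torch
import torch_directml

dml = torch_directml.device()
t = torch.tensor([1]).to(dml)

# With the cbec22fe commit applied, Tensor.new() should accept the
# PrivateUse1 (DirectML) dispatch key instead of raising the error above.
print(t.new(t.shape[0]).fill_(0))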
Thank you very much! It works.
I am facing the same problem. Could you please tell me how to solve it? Thanks in advance.