Add support for AMD GPUs (RDNA3/RDNA4) via ROCm's HIP translation layer.
Most PyTorch code works unchanged since HIP maps torch.cuda.* APIs, but
three areas needed explicit AMD handling:
device_utils.py:
- Add GPUInfo dataclass and enumerate_gpus() function
- Tries nvidia-smi, then amd-smi (ROCm 6.0+), then rocm-smi (legacy),
  then torch.cuda fallback — works on NVIDIA, AMD, or either (see the
  sketch below)
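
A minimal sketch of that fallback chain (the `GPUInfo` fields and the `_parse_smi_output` helper are illustrative assumptions, not the shipped code):

```python
# Sketch only: field names and the SMI parser are assumptions.
import shutil
import subprocess
from dataclasses import dataclass

@dataclass
class GPUInfo:
    index: int
    name: str
    vram_mb: int
    vendor: str  # "nvidia" or "amd"

def enumerate_gpus() -> list[GPUInfo]:
    # Prefer vendor CLIs: nvidia-smi, then amd-smi (ROCm 6.0+), then legacy rocm-smi.
    for tool, vendor in (("nvidia-smi", "nvidia"), ("amd-smi", "amd"), ("rocm-smi", "amd")):
        if shutil.which(tool):
            try:
                out = subprocess.run([tool], capture_output=True, text=True, check=True)
                return _parse_smi_output(tool, vendor, out.stdout)  # hypothetical parser
            except (OSError, subprocess.CalledProcessError):
                continue  # tool present but failed; try the next one
    # Last resort: torch.cuda enumerates devices on both CUDA and ROCm builds.
    import torch
    if not torch.cuda.is_available():
        return []
    return [
        GPUInfo(
            index=i,
            name=torch.cuda.get_device_name(i),
            vram_mb=torch.cuda.get_device_properties(i).total_memory // 2**20,
            vendor="amd" if torch.version.hip else "nvidia",
        )
        for i in range(torch.cuda.device_count())
    ]
```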
inference_engine.py:
- torch.compile uses "max-autotune-no-cudagraphs" on ROCm to avoid a
known HIP graph segfault on large graphs (pytorch/pytorch#155720)
- Auto-sets TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 on ROCm so SDPA
dispatches to AOTriton flash attention on RDNA3 (without this it
silently falls back to O(n^2) math backend)
- Skips torch.compile entirely on Windows ROCm where Triton kernel
  compilation hangs indefinitely (eager fallback works fine; see the
  sketch below)
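
A condensed sketch of that policy (the function name and the non-ROCm compile mode are assumptions):

```python
import os
import platform
import torch

def compile_for_backend(model):
    on_rocm = torch.version.hip is not None  # ROCm builds report a HIP version
    if on_rocm:
        # Route SDPA to AOTriton flash attention on RDNA3 instead of the
        # O(n^2) math backend.
        os.environ.setdefault("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "1")
        if platform.system() == "Windows":
            return model  # Triton compilation hangs on Windows ROCm; stay eager
        # Avoid the HIP graph segfault on large graphs (pytorch/pytorch#155720).
        return torch.compile(model, mode="max-autotune-no-cudagraphs")
    return torch.compile(model)  # assumed default path on NVIDIA
```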
pyproject.toml:
- Add "rocm" optional extra with install instructions
README.md:
- Add AMD ROCm to Hardware Requirements
- Add ROCm install instructions for Linux and Windows
- Add dedicated AMD ROCm Setup section with supported GPUs, automatic
behavior, and known limitations
Tested on an RX 7800 XT (gfx1101) on Windows with AMD ROCm 7.2:
- torch.cuda.is_available() = True
- SDPA with float16 = OK
- Raw inference = OK
- torch.compile = hangs on Windows (skipped), works on Linux
README.md (80 additions, 1 deletion)
The most recent build should work on computers with 6-8 gig of VRAM, and it can run on most M1+ Mac systems with unified memory. Yes, it might even work on your old MacBook Pro. Let us know on the Discord!

***Windows Users (NVIDIA):** To run GPU acceleration natively on Windows, your system MUST have NVIDIA drivers that support **CUDA 12.8 or higher** installed. If your drivers only support older CUDA versions, the installer will likely fall back to the CPU.
***AMD GPU Users (ROCm):** AMD Radeon RX 7000 series (RDNA3) and RX 9000 series (RDNA4) are supported via ROCm on **Linux**. Windows ROCm support is experimental (torch.compile is not yet functional). See the [AMD ROCm Setup](#amd-rocm-setup) section below.
***GVM (Optional):** Requires approximately **80 GB of VRAM** and utilizes massive Stable Video Diffusion models.
***VideoMaMa (Optional):** Natively requires a massive chunk of VRAM as well (originally 80GB+). While the community has tweaked the architecture to run at less than 24GB, those extreme memory optimizations have not yet been fully implemented in this repository.
***CorridorKey v1.0 Model (~300MB):** Downloads automatically on first run. If no `.pth` file is found in `CorridorKeyModule/checkpoints/`, the engine fetches it from [CorridorKey's HuggingFace](https://huggingface.co/nikopueringer/CorridorKey_v1.0) and saves it as `CorridorKey.pth`. No manual download needed.

**Use native MLX instead of PyTorch MPS:** MLX avoids PyTorch's MPS layer entirely and typically runs faster on Apple Silicon. See the [Backend Selection](#backend-selection) section below for setup steps.
### AMD ROCm Setup

CorridorKey supports AMD GPUs via PyTorch's ROCm/HIP backend. The `torch.cuda.*` API works transparently on AMD — HIP intercepts all CUDA calls at runtime, so the inference code runs unchanged.

**VRAM requirements:** CorridorKey inference at 2048x2048 needs ~18GB of VRAM. The RX 7900 XTX (24GB) and RX 7900 XT (20GB) run at full resolution. Cards with 16GB (RX 7800 XT, RX 9070 XT) work on Windows (which uses system RAM as overflow) but may OOM on Linux — see the notes below.
**Linux native (recommended):**

```bash
# Install AMD's ROCm torch wheels, then sync everything else
# (index URL shown for ROCm 6.2; substitute the ROCm version you have installed)
pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6.2
pip install -e ".[rocm]"  # assumed sync step using the "rocm" extra from pyproject.toml
```

**Windows native (experimental):** Windows ROCm requires Python 3.12 and the AMD Adrenalin 25.3.1+ driver. `torch.compile` does not work on Windows ROCm — inference runs in eager mode (significantly slower than Linux).
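
A quick sanity check before the first full run (this mirrors the checks used on the RX 7800 XT; the snippet is illustrative, not part of CorridorKey):

```python
import torch

print(torch.version.hip)          # a version string on ROCm builds, None on CUDA
print(torch.cuda.is_available())  # should be True with the AMD GPU visible
print(torch.cuda.get_device_name(0))

# SDPA with float16: on RDNA3 this should dispatch to flash attention
q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.float16)
out = torch.nn.functional.scaled_dot_product_attention(q, q, q)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```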
When CorridorKey detects a ROCm build, it automatically:

- Sets `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` so SDPA dispatches to flash attention kernels on RDNA3 (without this, attention falls back to a slow O(n²) path)
- Sets `MIOPEN_FIND_MODE=2` for faster convolution kernel selection (reduces warmup from 5-8 minutes to seconds)
- Uses `torch.compile(mode="max-autotune-no-cudagraphs")` on Linux to avoid a known HIP graph segfault ([pytorch/pytorch#155720](https://github.com/pytorch/pytorch/issues/155720))
- Skips `torch.compile` entirely on Windows ROCm, where Triton compilation hangs
**First-run note:** The first inference run on a new AMD GPU triggers Triton kernel autotuning (10-20 minutes). This is cached in `~/.cache/corridorkey/inductor/` and only happens once per GPU architecture. Subsequent runs start instantly.
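
How the cache lands in that directory isn't shown in this diff; a plausible sketch is pinning Inductor's cache via the standard `TORCHINDUCTOR_CACHE_DIR` variable (an assumption about CorridorKey's mechanism, not confirmed code):

```python
import os
from pathlib import Path

# Assumed mechanism: point TorchInductor's compile/autotune cache at a
# per-user directory so first-run kernels are reused on every later run.
cache_dir = Path.home() / ".cache" / "corridorkey" / "inductor"
cache_dir.mkdir(parents=True, exist_ok=True)
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", str(cache_dir))
```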
**16GB cards on Linux:** CorridorKey at 2048x2048 needs ~18GB. Windows handles this transparently via shared GPU memory (system RAM overflow). On Linux, the GPU has a hard VRAM limit. If you hit OOM on a 16GB card, install `pytorch-rocm-gtt` to enable GTT (system RAM as GPU overflow) — CorridorKey detects and uses it automatically (see the sketch below):

```bash
pip install pytorch-rocm-gtt
```

GTT memory is accessed over PCIe (~10-20x slower than VRAM), so expect slower frame times on 16GB cards than on 20-24GB cards.
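A sketch of what that automatic detection could look like (whether `pytorch_rocm_gtt` exposes a `patch()` entry point like this is an assumption):

```python
import importlib.util

def maybe_enable_gtt() -> bool:
    """Enable system-RAM overflow on Linux ROCm if pytorch-rocm-gtt is installed."""
    if importlib.util.find_spec("pytorch_rocm_gtt") is None:
        return False  # not installed; run with the hard VRAM limit
    import pytorch_rocm_gtt
    pytorch_rocm_gtt.patch()  # assumed entry point rerouting allocator overflow to GTT
    return True
```
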
**WSL2 limitation:** WSL2 cannot use GTT or shared memory — it has a hard VRAM limit. 16GB cards will OOM in WSL2 at 2048x2048. Use Windows native instead, or a card with 20GB+ VRAM.