
Commit 155642f

feat: introduce GPU Direct Storage and DiskKVCache support for improved performance
- Added `use_gds` parameter to enable GPU Direct Storage with `kvikio-cu12`, allowing layers to load directly from disk to GPU.
- Introduced `kv_cache_dir` option for offloading KV cache to SSD, supporting long contexts (50k+ tokens).
- Updated README and CHANGELOG to reflect new features and usage instructions.
- Modified Makefile and pyproject.toml to include new dependencies and installation options.
1 parent d04a5a2 commit 155642f

14 files changed

Lines changed: 987 additions & 10 deletions

.gitignore

Lines changed: 3 additions & 2 deletions
```diff
@@ -49,5 +49,6 @@ coverage.xml
 *.swo
 
 # Local model cache (downloads and split layers) — do not commit
-/models/
-/.models/
+/.models/
+
+examples
```

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
```diff
@@ -2,6 +2,20 @@
 
 All notable changes to RabbitLLM are documented here.
 
+## [1.1.0]
+
+### Added
+- **GPU Direct Storage (kvikio)**: When `kvikio-cu12` is installed, layers load directly from disk to GPU,
+  bypassing CPU and pin_memory. Install with `pip install rabbitllm[gds]`.
+- **DiskKVCache**: `kv_cache_dir` option to offload KV cache to SSD for 50k+ token contexts.
+- **example.py** in project root for quick onboarding.
+- **samples/** directory with sample text for long-context testing.
+- `use_gds` parameter (default `True`) to enable/disable kvikio when available.
+
+### Changed
+- `load_layer_to_cpu` now tries kvikio (GDS) first when available and compression is not used.
+- README documents `use_gds`, `kv_cache_dir`, and the optional `[gds]` extra.
+
 ## [1.0.1] — 2026-02-22
 
 ### Fixed
```
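The Changed entry above describes a "kvikio first, CPU fallback" load path. As a rough illustration of that pattern, here is a minimal sketch — only the behavior of `load_layer_to_cpu`, the `use_gds` flag, and the `kvikio-cu12` dependency come from this commit; the function below, its signature, and the raw-bytes handling are assumptions:

```python
import os

import torch


def load_layer_bytes(path: str, device: str, compression: str | None = None,
                     use_gds: bool = True) -> torch.Tensor:
    """Sketch of a GDS-first load: disk -> GPU via kvikio, CPU staging otherwise."""
    if use_gds and compression is None and device.startswith("cuda"):
        try:
            import kvikio  # provided by the optional [gds] extra (kvikio-cu12)

            buf = torch.empty(os.path.getsize(path), dtype=torch.uint8, device=device)
            with kvikio.CuFile(path, "r") as f:
                f.read(buf)  # direct disk -> GPU transfer; no CPU copy, no pin_memory
            return buf  # caller deserializes the raw bytes into layer tensors
        except ImportError:
            pass  # kvikio not installed: fall through to the CPU path
    with open(path, "rb") as f:
        data = bytearray(f.read())
    return torch.frombuffer(data, dtype=torch.uint8)  # classic CPU staging fallback
```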

Makefile

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 .PHONY: install dev lint format test test-cov typecheck clean bash
 
 install:
-	uv sync
+	uv sync --extra gds
 
 dev: install
 
```

README.md

Lines changed: 37 additions & 0 deletions
````diff
@@ -130,6 +130,8 @@ model = AutoModel.from_pretrained(
     max_seq_len=512,           # maximum sequence length
     prefetching=True,          # overlap layer loading with compute
     prefetch_pin_memory=True,  # faster CPU→GPU for small/medium models
+    use_gds=True,              # GPU Direct Storage (kvikio) when available
+    kv_cache_dir=None,         # path to offload KV cache for long context (50k+ tokens)
     token="hf_...",            # HuggingFace token for gated repos
     layer_shards_saving_path="/path/to/cache",  # custom split cache directory
     profiling_mode=False,      # print per-layer timing
@@ -150,6 +152,41 @@ model = AutoModel.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", compression="4bi
 
 Requires `bitsandbytes`: `pip install bitsandbytes`.
 
+### GPU Direct Storage (optional)
+
+For CUDA without compression, install `kvikio-cu12` to load layers directly from disk to GPU,
+bypassing CPU and pin_memory (can significantly speed up 70B+ models):
+
+```bash
+pip install rabbitllm[gds]
+# or: pip install kvikio-cu12
+```
+
+Set `use_gds=False` to disable.
+
+### Long context (KV cache on disk)
+
+For 50k+ token contexts, pass `kv_cache_dir` to offload KV cache to SSD:
+
+```python
+model = AutoModel.from_pretrained("Qwen/Qwen2.5-72B-Instruct", kv_cache_dir="./kv_cache")
+```
+
+### Benchmarking improvements
+
+To measure GDS and DiskKVCache improvements:
+
+```bash
+# Local: make install pulls in kvikio (--extra gds)
+make install
+uv run python scripts/benchmark_improvements.py --mode gds
+uv run python scripts/benchmark_improvements.py --mode long_context
+
+# Docker (make bash): install with GDS first
+pip install -e ".[gds]"
+python scripts/benchmark_improvements.py --mode gds
+```
+
 ### Gated models
 
 Pass a HuggingFace token for repos that require access approval:
````
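The long-context section above leaves DiskKVCache itself opaque. A minimal sketch of what a per-layer disk-backed KV cache can look like — the name `DiskKVCache` and the `kv_cache_dir` option come from this commit, while the class below, its methods, and the file layout are illustrative assumptions:

```python
import os

import torch


class DiskKVCacheSketch:
    """Illustrative KV cache that parks per-layer key/value tensors on SSD so a
    50k+ token context never has to stay resident in GPU or CPU memory at once."""

    def __init__(self, cache_dir: str):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def offload(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor) -> None:
        # After a layer's forward pass, push its KV tensors to disk.
        torch.save((key.cpu(), value.cpu()), self._path(layer_idx))

    def fetch(self, layer_idx: int, device: str) -> tuple[torch.Tensor, torch.Tensor]:
        # Just before the layer runs again, pull its KV tensors back.
        return torch.load(self._path(layer_idx), map_location=device)

    def _path(self, layer_idx: int) -> str:
        return os.path.join(self.cache_dir, f"layer_{layer_idx:03d}.pt")
```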

example.py

Lines changed: 56 additions & 0 deletions
New file:

```python
#!/usr/bin/env python3
"""
RabbitLLM example — minimal inference script.

Run: python example.py
Or: uv run python example.py

Uses a small model (Qwen2.5-0.5B) for fast testing. For larger models or long
context, see scripts/quickstart.py and the Configuration section in README.
"""

import warnings

import torch
from rabbitllm import AutoModel

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=".*CUDA.*unknown error.*", category=UserWarning)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    device=device,
    compression="4bit",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2? Answer briefly."},
]

input_text = model.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
tokens = model.tokenizer(
    [input_text], return_tensors="pt", truncation=True, max_length=512
)
input_ids = tokens["input_ids"].to(device)
attention_mask = tokens.get("attention_mask")
if attention_mask is None:
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
else:
    attention_mask = attention_mask.to(device)

output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_new_tokens=64,
    use_cache=True,
    do_sample=True,
    temperature=0.6,
    return_dict_in_generate=True,
)

input_len = tokens["input_ids"].shape[1]
print(model.tokenizer.decode(output.sequences[0][input_len:], skip_special_tokens=True))
```
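To exercise this commit's new options from the same script, the loader call can be adapted as below — a sketch with illustrative values; note the CHANGELOG states the GDS path applies only when compression is not used, so `compression="4bit"` is dropped here:

```python
# Variant of the from_pretrained call above using the new parameters (illustrative).
model = AutoModel.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    device=device,
    use_gds=True,               # try kvikio GPU Direct Storage when installed
    kv_cache_dir="./kv_cache",  # offload KV cache to SSD for long contexts
)
```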

pyproject.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,6 +36,7 @@ classifiers = [
 [project.optional-dependencies]
 compression = ["bitsandbytes"]
 flash = ["flash-attn>=2.5"]
+gds = ["kvikio-cu12"]
 server = []
 
 [tool.hatch.build.targets.wheel]
```
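Because `gds` is only an optional extra, the library has to feature-detect it at runtime rather than import it unconditionally. A common pattern — an assumption about the implementation, not code from this commit:

```python
# Probe for the optional GDS dependency at import time (sketch, not RabbitLLM's code).
try:
    import kvikio  # noqa: F401  # present only when installed via rabbitllm[gds]

    HAS_GDS = True
except ImportError:
    HAS_GDS = False  # fall back to the CPU / pin_memory load path
```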
