28 changes: 22 additions & 6 deletions openweights/jobs/inference/__init__.py
@@ -49,7 +49,28 @@ def create(
else:
return self.create_openai_inference_parallel_request(params)

base_model, lora_adapter = resolve_lora_model(params["model"])
lora_adapters = params.get("lora_adapters", [])
if lora_adapters:
# Multi-adapter path: validate that all adapters share the same rank.
# The rank check is done here (client-side) to fail fast before any GPU
# time is spent. Linear combination requires identical ranks across adapters.
base_model = params["model"]
ranks = [get_lora_rank(a) for a in lora_adapters]
if len(set(ranks)) != 1:
raise ValueError(
f"All LoRA adapters must have the same rank for linear combination. "
f"Got: {dict(zip(lora_adapters, ranks))}"
)
lora_rank = ranks[0]
lora_requires = lora_rank / 16
else:
base_model, lora_adapter = resolve_lora_model(params["model"])
if lora_adapter:
lora_rank = get_lora_rank(lora_adapter)
lora_requires = lora_rank / 16
else:
lora_requires = 0

if requires_vram_gb == "guess":
model_size = guess_model_size(base_model)
weights_require = 2 * model_size
@@ -58,11 +79,6 @@ def create(
elif "4bit" in params["model"] and not "ftjob" in base_model:
weights_require = weights_require / 4
kv_cache_requires = 15 # TODO estimate this better
if lora_adapter:
lora_rank = get_lora_rank(lora_adapter)
lora_requires = lora_rank / 16
else:
lora_requires = 0
requires_vram_gb = int(
weights_require + kv_cache_requires + 0.5 + lora_requires
)
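Reviewer note: a minimal sketch of the VRAM estimate this hunk implies, assuming `guess_model_size` returns the parameter count in billions (the helper is not shown in this diff). This is a reading of the code above, not part of the change.

```python
# Sketch only: mirrors the arithmetic in create() above, under the assumption
# that guess_model_size() returns the parameter count in billions.
def estimate_vram_gb(model_size_b: float, lora_rank: int | None, four_bit: bool) -> int:
    weights = 2.0 * model_size_b      # bf16 weights: roughly 2 GB per billion params
    if four_bit:
        weights /= 4                  # 4-bit quantization roughly quarters the weights
    kv_cache = 15                     # fixed allowance (same TODO as in the code)
    lora = lora_rank / 16 if lora_rank else 0
    return int(weights + kv_cache + 0.5 + lora)

# e.g. an 8B base model with rank-32 adapters merged into one rank-32 adapter:
# estimate_vram_gb(8, 32, four_bit=False) -> int(16 + 15 + 0.5 + 2.0) == 33
```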
172 changes: 148 additions & 24 deletions openweights/jobs/inference/cli.py
@@ -1,12 +1,14 @@
import gc
import json
import logging
import sys
import time
from pathlib import Path
from typing import List, Optional

import torch
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers import AutoModelForCausalLM
from validate import InferenceConfig
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
@@ -16,6 +18,8 @@

client = OpenWeights()

MERGED_LORA_PATH = Path("/tmp/merged_lora")


def sample(
llm,
@@ -108,6 +112,106 @@ def get_number_of_gpus():
return count


def download_adapter(adapter_id: str) -> str:
"""Download a LoRA adapter to a local path and return that path.

Handles both plain HF repo IDs (``org/repo``) and subfolder paths
(``org/repo/checkpoint-100``).
"""
parts = adapter_id.split("/")
if len(parts) > 2:
repo_id = "/".join(parts[:2])
subfolder = "/".join(parts[2:])
local_root = snapshot_download(
repo_id=repo_id, allow_patterns=f"{subfolder}/*"
)
return f"{local_root}/{subfolder}"
return snapshot_download(repo_id=adapter_id)


def merge_lora_adapters(
base_model_path: str,
adapter_ids: List[str],
output_path: Path,
) -> str:
"""Merge multiple LoRA adapters into a single adapter via linear combination.

All adapters must have the same rank (enforced on the client side before
job submission). The merge runs entirely on CPU so that GPU memory is
free for vLLM initialisation immediately afterwards.

Args:
base_model_path: Local path to the base model weights.
adapter_ids: List of HuggingFace adapter IDs to merge.
output_path: Directory where the merged adapter will be saved.

Returns:
The string path to the saved merged adapter directory.
"""
try:
from peft import PeftModel
except ImportError:
raise ImportError(
"peft is required for merging multiple LoRA adapters. "
"Install it with: pip install peft"
)

print(f"Merging {len(adapter_ids)} LoRA adapters via linear combination (CPU)…")

# Download all adapters to local paths first.
local_paths: List[str] = []
for i, adapter_id in enumerate(adapter_ids):
print(f" Downloading adapter {i + 1}/{len(adapter_ids)}: {adapter_id}")
local_paths.append(download_adapter(adapter_id))

# Load the base model on CPU only — avoids claiming GPU memory before vLLM.
print(f"Loading base model on CPU for merging: {base_model_path}")
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
torch_dtype=torch.bfloat16,
device_map="cpu",
)

# Load each adapter under a unique name.
adapter_names = [f"adapter_{i}" for i in range(len(local_paths))]
model = PeftModel.from_pretrained(
base_model, local_paths[0], adapter_name=adapter_names[0]
)
for name, local_path in zip(adapter_names[1:], local_paths[1:]):
model.load_adapter(local_path, adapter_name=name)
print(f" Loaded {name}")

# Linearly combine all adapters with equal weights.
# This keeps the merged rank identical to the input rank, which is
# required for vLLM's max_lora_rank to remain unchanged.
model.add_weighted_adapter(
adapters=adapter_names,
weights=[1.0] * len(adapter_names),
combination_type="linear",
adapter_name="combined",
)
model.set_adapter("combined")

# Delete the source adapters so that only "combined" remains.
# When multiple adapters are present, PEFT's save_pretrained creates
# subdirectories per adapter, but vLLM expects a flat layout with
# adapter_config.json at the top level.
for name in adapter_names:
model.delete_adapter(name)

# Save only the adapter weights (not the full model).
output_path.mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(output_path))
print(f"Merged adapter saved to {output_path}")

# Release CPU memory before vLLM claims the GPU.
del model
del base_model
gc.collect()

return str(output_path)


def load_jsonl_file_from_id(input_file_id):
content = client.files.content(input_file_id).decode()
rows = [json.loads(line) for line in content.split("\n") if line.strip()]
@@ -117,29 +221,60 @@ def load_jsonl_file_from_id(input_file_id):
def main(config_json: str):
cfg = InferenceConfig(**json.loads(config_json))

base_model, lora_adapter = resolve_lora_model(cfg.model)

# Only enable LoRA if we have an adapter
enable_lora = lora_adapter is not None
# ------------------------------------------------------------------
# 1️⃣ Resolve base model and adapter(s)
# ------------------------------------------------------------------
if cfg.lora_adapters:
# Multi-adapter path: cfg.model is the base model; adapters are merged
# on CPU before vLLM is initialised.
base_model = cfg.model
enable_lora = True
lora_adapter = None # resolved after merge below
else:
# Single adapter path (existing behaviour): adapter may be encoded in
# cfg.model as a HuggingFace adapter repo ID.
base_model, lora_adapter = resolve_lora_model(cfg.model)
enable_lora = lora_adapter is not None

# ------------------------------------------------------------------
# 1️⃣ Pre-download the base model to a local directory
# 2️⃣ Pre-download the base model to a local directory
# ------------------------------------------------------------------
LOCAL_MODEL_ROOT = Path("/workspace/hf_models") # pick any local path
LOCAL_MODEL_ROOT = Path("/workspace/hf_models")
LOCAL_MODEL_ROOT.mkdir(parents=True, exist_ok=True)

print(f"Downloading (or re-using) model '{base_model}' …")
print(f"Downloading (or re-using) base model '{base_model}' …")
local_base_model_path = snapshot_download(
repo_id=base_model,
local_dir=str(LOCAL_MODEL_ROOT / base_model.replace("/", "_")),
local_dir_use_symlinks=False, # real files; avoids NFS latency
)

# ------------------------------------------------------------------
# 3️⃣ Merge multiple adapters (if requested) — runs on CPU
# ------------------------------------------------------------------
if cfg.lora_adapters:
lora_path = merge_lora_adapters(
local_base_model_path, cfg.lora_adapters, MERGED_LORA_PATH
)
lora_rank = get_lora_rank(cfg.lora_adapters[0]) # all share same rank
lora_name = "+".join(cfg.lora_adapters)
elif lora_adapter is not None:
lora_path = download_adapter(lora_adapter)
lora_rank = get_lora_rank(lora_adapter)
lora_name = lora_adapter
else:
lora_path = None
lora_rank = None
lora_name = None

# ------------------------------------------------------------------
# 4️⃣ Build vLLM load kwargs
# ------------------------------------------------------------------
llm = None
load_kwargs = dict(
model=local_base_model_path,
enable_prefix_caching=True,
enable_lora=enable_lora, # Only enable if we have an adapter
enable_lora=enable_lora,
tensor_parallel_size=(
get_number_of_gpus() if cfg.load_format != "bitsandbytes" else 1
),
@@ -148,28 +283,17 @@ def main(config_json: str):
max_model_len=cfg.max_model_len,
)
if enable_lora:
load_kwargs["max_lora_rank"] = get_lora_rank(lora_adapter)
load_kwargs["max_lora_rank"] = lora_rank
if cfg.quantization is not None:
load_kwargs["quantization"] = cfg.quantization
if cfg.load_format is not None:
load_kwargs["load_format"] = cfg.load_format

# Create LoRA request only if we have an adapter
# Create a LoRA request only when an adapter (merged or single) is present.
lora_request = None
if lora_adapter is not None:
if len(lora_adapter.split("/")) > 2:
repo_id, subfolder = (
"/".join(lora_adapter.split("/")[:2]),
"/".join(lora_adapter.split("/")[2:]),
)
lora_path = (
snapshot_download(repo_id=repo_id, allow_patterns=f"{subfolder}/*")
+ f"/{subfolder}"
)
else:
lora_path = lora_adapter
if lora_path is not None:
lora_request = LoRARequest(
lora_name=lora_adapter, lora_int_id=1, lora_path=lora_path
lora_name=lora_name, lora_int_id=1, lora_path=lora_path
)

conversations = load_jsonl_file_from_id(cfg.input_file_id)
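Reviewer note: both the client-side rank check in `__init__.py` and the merge path in `cli.py` lean on `get_lora_rank`, which is not shown in this diff. A hedged sketch of what it is assumed to do, namely read the LoRA rank `r` from the adapter's PEFT config; the real helper lives elsewhere in the repo and may differ.

```python
import json
from huggingface_hub import hf_hub_download

# Sketch of the assumed behaviour of get_lora_rank: download the adapter's
# PEFT config and return its rank ("r"). Supports org/repo and
# org/repo/checkpoint-100 style IDs, like download_adapter above.
def get_lora_rank_sketch(adapter_id: str) -> int:
    parts = adapter_id.split("/")
    repo_id = "/".join(parts[:2])
    subfolder = "/".join(parts[2:]) or None
    config_path = hf_hub_download(
        repo_id=repo_id, filename="adapter_config.json", subfolder=subfolder
    )
    with open(config_path) as f:
        return json.load(f)["r"]
```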
18 changes: 18 additions & 0 deletions openweights/jobs/inference/validate.py
@@ -20,6 +20,15 @@ class Config:
use_batch: bool = Field(
True, description="Whether to use OpenAI batch API for inference"
)
lora_adapters: List[str] = Field(
[],
description=(
"Optional list of LoRA adapter HuggingFace IDs to merge before inference. "
"When provided, `model` must be the base model ID (not an adapter). "
"All adapters must target the same base model and share the same LoRA rank. "
"They are merged on-device (CPU) via PEFT linear combination before vLLM loads."
),
)

@field_validator("model")
def validate_model_format(cls, v):
@@ -29,6 +38,15 @@ def validate_model_format(cls, v):
)
return v

@field_validator("lora_adapters")
def validate_lora_adapters_length(cls, v):
if len(v) == 1:
raise ValueError(
"lora_adapters must contain at least 2 adapters. "
"For a single adapter, pass it as the `model` field instead."
)
return v

max_tokens: int = Field(600, description="Maximum number of tokens to generate")
temperature: float = Field(1.0, description="Temperature for sampling")
top_p: float = Field(1.0, description="Top P")
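Reviewer note: a hypothetical config exercising the new multi-adapter path and its validator. The adapter repo IDs and the `input_file_id` value are placeholders; fields other than `lora_adapters` follow the rest of this file.

```python
# Hypothetical usage of the new field. `model` is the base model, not an adapter.
cfg = InferenceConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    input_file_id="file-abc123",
    lora_adapters=[
        "my-org/llama3-adapter-a",   # must share the same LoRA rank
        "my-org/llama3-adapter-b",
    ],
)

# A single-element list is rejected by validate_lora_adapters_length:
#   lora_adapters=["my-org/only-one"]  ->  ValidationError
# For one adapter, keep the existing behaviour and pass it as `model` instead.
```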