Skip to content

feat: activation patching at scale (capture-sourced, continuous-batching-aware)#212

Open
RhizoNymph wants to merge 40 commits into
feat/integrationfrom
feat/activation-patching
Open

feat: activation patching at scale (capture-sourced, continuous-batching-aware)#212
RhizoNymph wants to merge 40 commits into
feat/integrationfrom
feat/activation-patching

Conversation

@RhizoNymph

Copy link
Copy Markdown
Owner

Adds activation patching — overwriting (alpha=1) or interpolating residual-stream activations at specific (layer, hook, position) sites of a destination request with vectors captured from a prior clean run. Built for performant coarse→fine causal-tracing sweeps that mix freely into a continuously-batched stream: only the patched rows are intervened on, the rest pass through untouched.

What & why

Activation patching is steering with three changes — replace/lerp instead of add, per-(request, layer, hook, position) values, and values sourced from a prior capture run — so it reuses the steering/capture machinery and inherits the hard parts of continuous batching (per-token row gating, position→row mapping under chunked prefill, CUDA-graph-safe persistent buffers).

  • Data plane: apply_patch (lerp) + apply_patch_block (two-tensor post_block that reconstructs residual + hidden_states, since vLLM defers the MLP add and replace does not commute through it). Precise-lerp (1-α)·h + α·t so α=1 is a bit-exact replacement. Folded into apply_layer_steering / apply_block_steering via a process-global slot count — zero model-file edits.
  • Injection plane: per-(layer, hook) buffers + a per-step planner (abs_row = token_offset + (dest_pos - num_computed)), ephemeral per-step slots, strict overflow. CUDA-graph-safe (no force-eager).
  • Source: run-id-keyed PatchSourceStore (whole-run LRU) populated by a patch_source capture consumer that reuses the capture pipeline. Cross-rank: local resolution under PP, rank-0→peers broadcast under TP.
  • Config / spec: --enable-patching, SamplingParams.patch, OpenAI chat+completion plumbing, prefix-cache floor, admission validation.
  • Client: examples/online_serving/openai_patch_client.pyPatchStudy (sweep / zoom / heatmap) over the HTTP API.

Also cherry-picks the post_block capture-DCE fix (keeps the capture op live under torch.compile) so post_block patching works under CUDA graphs.

GPU-validated on Qwen3-0.6B across {eager, cudagraph} and {TP1/PP1, TP2, PP2}: replace is bit-exact, single-site patches recover the clean answer (denoising), and all parallelism configs agree.

RhizoNymph and others added 19 commits July 1, 2026 11:47
Patch was only wired into the v2 runner, so any model not on the v2 allowlist
(e.g. gemma3) silently accepted patch specs without applying them. Wire the same
control plane into the v1 GPUModelRunner: PatchModelRunnerMixin, _init_patch_state,
per-step _update_patch_buffers, and add/finish hooks. Move the runner-agnostic
_patch_add_request into the base mixin (shared by both runners).

Root fix: set the process-global patch slot count before the v1 model build so
register_steering_buffers attaches patch buffers (the v2 runner already did this;
v1 did not, so no patchable layers were discovered).

GPU-validated on gemma3-4b (v1 runner) and Qwen3-0.6B (both runners), eager +
cudagraph: no-op/self-identity bit-exact, cross-run replace reproduces clean,
denoising surfaces the clean answer.
A patched request re-forwards from its patch floor and registers its computed
blocks under vanilla token hashes, so a later unpatched request with the same
prompt could be served the patched KV (GPU repro: 0.47 max logprob corruption;
only unnoticed because short validation prompts never filled a full block).

Fold a deterministic patch-spec hash into the block hashes of all blocks at or
after the lowest patched position (attention propagates the patch forward), the
same mechanism steering uses. Blocks below the floor stay shareable, preserving
the corrupt-prefix sharing that makes sweeps cheap; distinct specs get distinct
KV chains.

GPU-validated both ways: with the fix an unpatched rerun after a patched run is
bit-identical to a fresh-engine ground truth; with the fix disabled it differs
by 0.47.
Sweep cells graded the answer/foil by looking them up in the generated top-k
logprobs — an answer outside top-k graded as None (top-k boundary flicker),
silently dropping cells from the grid.

Use the engine's logprob_token_ids to score the answer/foil ids exactly on
every request: the sweep endpoint resolves answer_token/foil_token to single
token ids via the tokenizer (400 if multi-token), and PatchStudy resolves them
via /tokenize, both passing the ids through (logprob_token_ids is now exposed
on the completions API). The engine requires logprobs == len(ids) when ids are
given.

Live-validated: a token far outside top-1 is reported exactly; the full sweep
grid grades every cell (0 top-k None-mismatches, 63/63 cells).
…hed cells

A source run evicted between admission (manifest check, positively cached) and
worker resolution made the patch entry log-and-skip: the request ran UNPATCHED
and its sweep cell silently reported the corrupt baseline as a patched result.

Two layers of defense:
- Leases: the admission path leases referenced runs on the workers (throttled
  to ~one RPC per run per half-TTL); store eviction skips unexpired-leased runs,
  soft-exceeding the byte budget with a warning instead of un-patching in-flight
  requests. Live-validated: a leased run survives capture pressure that would
  previously have evicted it, and re-sweeps grade 4/4 cells.
- Backstop: any residual resolution miss is recorded per-request in a worker
  registry; the sweep endpoint drains it after each sweep (collective_rpc) and
  voids the affected cells (grid=None + skipped[] entries) instead of returning
  unpatched values.
…t a runner-set global

The process-global slot count had to be set by each runner before its model
build — the v1 runner didn't, which shipped patching as a silent no-op there.
Resolve the slot count inside maybe_register_patch_buffers from
get_current_vllm_config_or_none() (models are always built under
set_current_vllm_config, on every runner), removing the runner-side setup from
both runners; the global remains only as a test-context fallback. GPU-checked:
buffers register and patching validates on both runners with no runner code.
source_position == dest_position silently patches shifted positions when the
clean and corrupt prompts tokenize to different lengths — a plausible-looking
but wrong heatmap. Add alignment: equal lengths map identity (corresponding
positions are the causal-tracing pairing); unequal lengths map the common token
prefix by identity and the common suffix by the length delta, and skip the
differing middle loudly (skipped[] + alignment summary in the response).

The sweep endpoint takes clean_prompt and refuses a length mismatch without it
(the source run's captured prompt length is exposed via the admission cache);
PatchStudy records the clean prompt on CleanRun and aligns automatically on
both the per-cell and server-side paths. Live-validated: mismatch 400s without
clean_prompt; an 11-vs-9-token pair aligns (prefix 4, suffix 4, middle skipped)
and grades 16/16 aligned cells.
vLLM is not batch-invariant by default, so identical requests in different
batch compositions return slightly different logprobs. Rather than forcing
batch-invariant mode (a server-wide throughput tax far below causal-tracing
signal), each sweep re-runs the corrupt baseline inside the cell batch and
reports |delta| vs the solo baseline as noise_floor — grid differences at or
below it are not meaningful. Docs point at batch_invariance for exact
reproducibility.
- gpu_patch_validate gains check F: at the best denoising site, alpha in
  {0, 0.5, 1} must move the answer logprob monotonically corrupt -> clean
  (exact grading via logprob_token_ids). Validates the lerp path between its
  endpoints, which was only CPU-tested.
- Chat admission rejects patch specs on multimodal prompts: prompt positions
  include image placeholder tokens, so patch positions would target
  placeholder activations — semantically undefined and unvalidated. Documented
  text-only scope.
feat(patch): --enable-patching implies patch_source capture consumer
refactor(patch): promote PatchStudy client into the vllm package + span-based positions
feat(patch): one-call patch sweeps via server-side auto-capture
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant