feat: activation patching at scale (capture-sourced, continuous-batching-aware)#212
Open
RhizoNymph wants to merge 40 commits into
Open
feat: activation patching at scale (capture-sourced, continuous-batching-aware)#212RhizoNymph wants to merge 40 commits into
RhizoNymph wants to merge 40 commits into
Conversation
Patch was only wired into the v2 runner, so any model not on the v2 allowlist (e.g. gemma3) silently accepted patch specs without applying them. Wire the same control plane into the v1 GPUModelRunner: PatchModelRunnerMixin, _init_patch_state, per-step _update_patch_buffers, and add/finish hooks. Move the runner-agnostic _patch_add_request into the base mixin (shared by both runners). Root fix: set the process-global patch slot count before the v1 model build so register_steering_buffers attaches patch buffers (the v2 runner already did this; v1 did not, so no patchable layers were discovered). GPU-validated on gemma3-4b (v1 runner) and Qwen3-0.6B (both runners), eager + cudagraph: no-op/self-identity bit-exact, cross-run replace reproduces clean, denoising surfaces the clean answer.
A patched request re-forwards from its patch floor and registers its computed blocks under vanilla token hashes, so a later unpatched request with the same prompt could be served the patched KV (GPU repro: 0.47 max logprob corruption; only unnoticed because short validation prompts never filled a full block). Fold a deterministic patch-spec hash into the block hashes of all blocks at or after the lowest patched position (attention propagates the patch forward), the same mechanism steering uses. Blocks below the floor stay shareable, preserving the corrupt-prefix sharing that makes sweeps cheap; distinct specs get distinct KV chains. GPU-validated both ways: with the fix an unpatched rerun after a patched run is bit-identical to a fresh-engine ground truth; with the fix disabled it differs by 0.47.
Sweep cells graded the answer/foil by looking them up in the generated top-k logprobs — an answer outside top-k graded as None (top-k boundary flicker), silently dropping cells from the grid. Use the engine's logprob_token_ids to score the answer/foil ids exactly on every request: the sweep endpoint resolves answer_token/foil_token to single token ids via the tokenizer (400 if multi-token), and PatchStudy resolves them via /tokenize, both passing the ids through (logprob_token_ids is now exposed on the completions API). The engine requires logprobs == len(ids) when ids are given. Live-validated: a token far outside top-1 is reported exactly; the full sweep grid grades every cell (0 top-k None-mismatches, 63/63 cells).
…hed cells A source run evicted between admission (manifest check, positively cached) and worker resolution made the patch entry log-and-skip: the request ran UNPATCHED and its sweep cell silently reported the corrupt baseline as a patched result. Two layers of defense: - Leases: the admission path leases referenced runs on the workers (throttled to ~one RPC per run per half-TTL); store eviction skips unexpired-leased runs, soft-exceeding the byte budget with a warning instead of un-patching in-flight requests. Live-validated: a leased run survives capture pressure that would previously have evicted it, and re-sweeps grade 4/4 cells. - Backstop: any residual resolution miss is recorded per-request in a worker registry; the sweep endpoint drains it after each sweep (collective_rpc) and voids the affected cells (grid=None + skipped[] entries) instead of returning unpatched values.
…t a runner-set global The process-global slot count had to be set by each runner before its model build — the v1 runner didn't, which shipped patching as a silent no-op there. Resolve the slot count inside maybe_register_patch_buffers from get_current_vllm_config_or_none() (models are always built under set_current_vllm_config, on every runner), removing the runner-side setup from both runners; the global remains only as a test-context fallback. GPU-checked: buffers register and patching validates on both runners with no runner code.
source_position == dest_position silently patches shifted positions when the clean and corrupt prompts tokenize to different lengths — a plausible-looking but wrong heatmap. Add alignment: equal lengths map identity (corresponding positions are the causal-tracing pairing); unequal lengths map the common token prefix by identity and the common suffix by the length delta, and skip the differing middle loudly (skipped[] + alignment summary in the response). The sweep endpoint takes clean_prompt and refuses a length mismatch without it (the source run's captured prompt length is exposed via the admission cache); PatchStudy records the clean prompt on CleanRun and aligns automatically on both the per-cell and server-side paths. Live-validated: mismatch 400s without clean_prompt; an 11-vs-9-token pair aligns (prefix 4, suffix 4, middle skipped) and grades 16/16 aligned cells.
vLLM is not batch-invariant by default, so identical requests in different batch compositions return slightly different logprobs. Rather than forcing batch-invariant mode (a server-wide throughput tax far below causal-tracing signal), each sweep re-runs the corrupt baseline inside the cell batch and reports |delta| vs the solo baseline as noise_floor — grid differences at or below it are not meaningful. Docs point at batch_invariance for exact reproducibility.
- gpu_patch_validate gains check F: at the best denoising site, alpha in
{0, 0.5, 1} must move the answer logprob monotonically corrupt -> clean
(exact grading via logprob_token_ids). Validates the lerp path between its
endpoints, which was only CPU-tested.
- Chat admission rejects patch specs on multimodal prompts: prompt positions
include image placeholder tokens, so patch positions would target
placeholder activations — semantically undefined and unvalidated. Documented
text-only scope.
… add span positions
feat(patch): --enable-patching implies patch_source capture consumer
refactor(patch): promote PatchStudy client into the vllm package + span-based positions
feat(patch): one-call patch sweeps via server-side auto-capture
feat(patch): compose server-side spans + one-call auto-capture in sweeps
feat(patch): opt-in SSE streaming for /v1/patch_sweep grids
feat(patch): multi-hook sweeps + source-run lifecycle
…(slot 0 sentinel)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Adds activation patching — overwriting (alpha=1) or interpolating residual-stream activations at specific
(layer, hook, position)sites of a destination request with vectors captured from a prior clean run. Built for performant coarse→fine causal-tracing sweeps that mix freely into a continuously-batched stream: only the patched rows are intervened on, the rest pass through untouched.What & why
Activation patching is steering with three changes — replace/lerp instead of add, per-
(request, layer, hook, position)values, and values sourced from a prior capture run — so it reuses the steering/capture machinery and inherits the hard parts of continuous batching (per-token row gating, position→row mapping under chunked prefill, CUDA-graph-safe persistent buffers).apply_patch(lerp) +apply_patch_block(two-tensorpost_blockthat reconstructsresidual + hidden_states, since vLLM defers the MLP add and replace does not commute through it). Precise-lerp(1-α)·h + α·tsoα=1is a bit-exact replacement. Folded intoapply_layer_steering/apply_block_steeringvia a process-global slot count — zero model-file edits.(layer, hook)buffers + a per-step planner (abs_row = token_offset + (dest_pos - num_computed)), ephemeral per-step slots, strict overflow. CUDA-graph-safe (no force-eager).PatchSourceStore(whole-run LRU) populated by apatch_sourcecapture consumer that reuses the capture pipeline. Cross-rank: local resolution under PP, rank-0→peers broadcast under TP.--enable-patching,SamplingParams.patch, OpenAI chat+completion plumbing, prefix-cache floor, admission validation.examples/online_serving/openai_patch_client.py—PatchStudy(sweep / zoom / heatmap) over the HTTP API.Also cherry-picks the
post_blockcapture-DCE fix (keeps the capture op live undertorch.compile) so post_block patching works under CUDA graphs.GPU-validated on Qwen3-0.6B across {eager, cudagraph} and {TP1/PP1, TP2, PP2}: replace is bit-exact, single-site patches recover the clean answer (denoising), and all parallelism configs agree.