fix(steering): warmup matches runtime row-monitor specialization; single-source op args#230
Merged
Conversation
…gle-source op args
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Two related data-plane fixes to the steering kernel path.
Fix 1 — warmup compiled the wrong Triton specialization (default config)
warmup_apply_steering_kernelalways allocated full-size per-row-monitor buffers (rprobe = (table_rows, hidden),rparams = (table_rows, 2)). But withenable_row_monitor=False— the default — layers keep the registered(1, 1)/(1, 2)dummy buffers (resize_steering_row_monitor_buffersis a no-op when disabled). The kernel receives the per-row probe table's leading striderp_stride_r = probe_table.stride(0):1for the dummy vshidden_sizefor warmup's buffer. Triton specializes integer args on== 1(constexpr) and divisibility-by-16, so the cache keys differ — warmup compiled a variant the default runtime never hits, and the first real forward JIT-compiled fresh, exactly the served-window/capture-time cost warmup exists to prevent.Fix: thread
row_monitor_enabledintowarmup_apply_steering_kernel; whenFalseallocate(1, 1)/(1, 2)buffers matching the registered dummies, whenTruekeep full-size (matchingresize_steering_row_monitor_buffers). The caller insteering_model_runner_mixin.pypasses the runner's_row_monitor_enabledstate.The old warmup regression test (
test_subsequent_invocations_at_warmed_shape_no_new_variants) built its runtime-mimic with a full-size(8, 128)probe table — replicating warmup's wrong shape rather than the true default runtime shape — so it asserted the buggy behavior and passed. That test is fixed here to use the real default-config shapes and is parametrized over both row-monitor enable states, asserting the JIT cache does not grow after runtime-shaped calls in each.The sibling
steering_monitor_kernel.pywarmup has no analogous issue: none of the monitor op's tensors have a dummy-vs-full-size distinction.Fix 2 — single source of truth for the 15-arg op signature
The 15-tensor positional list was repeated at ~8 sites (emit, op impl, fake, Triton wrapper, kernel launch interleaved with ~16 stride scalars, warmup ×2, tests). All args are same-typed tensors, so a transposition type-checks and fails only behaviorally.
SteeringOpArgs(NamedTuple)(15 tensor fields in canonical order) plus a_build_steering_op_argsbuilder used by_emit_steering_opand by warmup. The registered op, fake, and Triton wrapper keep their flat signatures (torch custom-op schemas require flat tensors) as the only flat sites, each mirroring the NamedTuple order.SteeringOpArgs._fieldsequals the registered op schema's argument names in order._steering_kernel_strides, so tensor/stride pairing is generated from the tensors rather than hand-zipped at the highest-risk interleaved site. The emitted launch is identical.steering_monitor_offsubstitution used in cross-layer monitor mode). No change to the op schema, arity, or any kernel semantics.Notes
Both changes are host-side / warmup-only and produce byte-identical kernel behavior. CPU tests pass:
test_steering_op.py,test_block_steering.py,test_steering_monitor_op.py,test_steering_monitor.py,test_steering_row_monitor.py,test_steering_warmup.py— 57 passed, 5 skipped (CUDA-only).GPU confirmation of the warmup cache-idempotency assertions is pending (those parts skip without CUDA).
Sibling-PR conflicts: none expected on
steering.py/steering_kernel.py.chore/steering-row-owneralso touchessteering_model_runner_mixin.py; the one-line caller change here may conflict trivially.