Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions claas/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class TrainingConfig:
max_grad_norm: float = 1.0
kl_reg_weight: float = 0.0
teacher_top_k: int = 100
use_topk_divergence: bool = True
steps_per_batch: int = 4
feedback_repetitions: int = 1

Expand Down
1 change: 1 addition & 0 deletions claas/eval/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,6 @@ training:
max_grad_norm: 1.0
kl_reg_weight: 0.0
teacher_top_k: 100
use_topk_divergence: true
steps_per_batch: 4
feedback_repetitions: 1
1 change: 1 addition & 0 deletions claas/eval/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class TinkerDistillMetrics:
completion_len: int = 0
batch_size: int = 0
steps_per_batch_applied: int = 1
divergence_mode: str = "scalar_kl"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify how TinkerDistillMetrics is instantiated in runner.py
rg -n "TinkerDistillMetrics\(" claas/eval/runner.py -A 15

Repository: kfallah/CLaaS

Length of output: 1601


🏁 Script executed:

#!/bin/bash
# Check the TinkerDistillMetrics dataclass definition in types.py
rg -n "class TinkerDistillMetrics" claas/eval/types.py -A 20

Repository: kfallah/CLaaS

Length of output: 711


🏁 Script executed:

#!/bin/bash
# Check where metadata is populated to see if divergence_mode is passed
rg -n "divergence_mode" claas/ -B 2 -A 2

Repository: kfallah/CLaaS

Length of output: 681


🏁 Script executed:

#!/bin/bash
# Get more context around the explicit TinkerDistillMetrics instantiation
rg -n "def.*:" claas/eval/runner.py | head -20

Repository: kfallah/CLaaS

Length of output: 671


🏁 Script executed:

#!/bin/bash
# Get the function containing the TinkerDistillMetrics instantiation at line 75
sed -n '60,95p' claas/eval/runner.py

Repository: kfallah/CLaaS

Length of output: 1433


Runner.py does not extract divergence_mode from metadata.

The Tinker engine returns divergence_mode in its metadata response (either "topk_gjs" or "scalar_kl"), but the TinkerDistillMetrics instantiation at claas/eval/runner.py:75-85 doesn't extract it. The field will silently default to "scalar_kl" even if the engine returns a different value.

Update the instantiation to include:

divergence_mode=metadata.get("divergence_mode", "scalar_kl"),
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/eval/types.py` at line 134, The Tinker engine returns divergence_mode
in metadata but the TinkerDistillMetrics dataclass's divergence_mode field
(default "scalar_kl") is not being set from metadata; update the
TinkerDistillMetrics instantiation in runner.py (where TinkerDistillMetrics is
constructed) to pass divergence_mode=metadata.get("divergence_mode",
"scalar_kl") so the value from metadata (e.g., "topk_gjs" or "scalar_kl") is
used instead of always using the default.



@dataclass
Expand Down
242 changes: 192 additions & 50 deletions claas/training/engine/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import os
from datetime import datetime, timezone
from typing import Any, TypedDict
from typing import Any
from urllib.parse import quote

import httpx
Expand Down Expand Up @@ -46,6 +46,19 @@
lora_exists as state_lora_exists,
set_tinker_path,
)
from claas.training.engine.tinker.types import (
BehaviorSignal,
PreparedSample,
ScalarBehavior,
ScalarPreparedSample,
TopKBehavior,
TopKPreparedSample,
)
from claas.training.gjs import (
compute_topk_gjs,
extract_token_logprobs,
slice_completion_topk,
)
from claas.training.teacher_helpers import build_teacher_messages, teacher_messages_to_chat_template

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -187,6 +200,9 @@ async def distill(
kl_coef = payload.training.alpha
steps_per_batch = payload.training.steps_per_batch
feedback_repetitions = payload.training.feedback_repetitions
use_topk = payload.training.use_topk_divergence
teacher_top_k = payload.training.teacher_top_k
alpha = payload.training.alpha

# ── Phase 1: Setup (once per batch) ──
training_client = await self.service.create_training_client_from_state_async(
Expand All @@ -205,12 +221,14 @@ async def distill(
tokenizer=tokenizer,
teacher_sampling=teacher_sampling,
feedback_repetitions=feedback_repetitions,
use_topk=use_topk,
teacher_top_k=teacher_top_k,
)
for sample in payload.samples
]
results = await asyncio.gather(*tasks)
prepared_samples = [r[0] for r in results]
behavior_logprobs = [r[1] for r in results]
prepared_samples: list[PreparedSample] = [r[0] for r in results]
behavior_signals: list[BehaviorSignal] = [r[1] for r in results]

# ── Phase 3: Multi-step training ──
step_metrics: list[dict[str, float | int]] = []
Expand All @@ -219,8 +237,9 @@ async def distill(
datum_metrics = [
_build_sample_datum(
prepared=prepared,
student_logprobs=behavior_logprobs[sample_idx],
behavior=behavior_signals[sample_idx],
kl_coef=kl_coef,
alpha=alpha,
)
for sample_idx, prepared in enumerate(prepared_samples)
]
Expand Down Expand Up @@ -256,10 +275,19 @@ async def distill(

if step_idx < steps_per_batch - 1:
student_sampling = await training_client.save_weights_and_get_sampling_client_async()
behavior_logprobs = await _compute_student_logprobs_for_batch(
student_sampling=student_sampling,
prepared_samples=prepared_samples,
)
if use_topk:
topk_behaviors = await _compute_student_topk_for_batch(
student_sampling=student_sampling,
prepared_samples=prepared_samples,
top_k=teacher_top_k,
)
behavior_signals = list[BehaviorSignal](topk_behaviors)
else:
scalar_logprobs = await _compute_student_logprobs_for_batch(
student_sampling=student_sampling,
prepared_samples=prepared_samples,
)
behavior_signals = [ScalarBehavior(logprobs=lps) for lps in scalar_logprobs]
Comment on lines +278 to +290
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Type annotation mismatch causes static analysis failure.

The pipeline reports list[TopKBehavior] is not assignable to list[BehaviorSignal]. This is a variance issue—lists are invariant in Python's type system.

The fix is to widen the return type annotation of _compute_student_topk_for_batch from list[TopKBehavior] to list[BehaviorSignal]:

🔧 Proposed fix
 async def _compute_student_topk_for_batch(
     *,
     student_sampling: Any,
     prepared_samples: list[PreparedSample],
     top_k: int,
-) -> list[TopKBehavior]:
+) -> list[BehaviorSignal]:
🧰 Tools
🪛 GitHub Actions: CI

[error] 326-331: invalid-assignment: Object of type list[TopKBehavior] is not assignable to list[ScalarBehavior | TopKBehavior]

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/engine/tinker/engine.py` around lines 325 - 336, The type
mismatch arises because _compute_student_topk_for_batch is annotated to return
list[TopKBehavior] which is not compatible with list[BehaviorSignal]; update the
return type annotation of the _compute_student_topk_for_batch function to
list[BehaviorSignal] (or a more general Sequence[BehaviorSignal] if preferred)
so its declared output matches the variable behavior_signals and resolves the
invariant list type error involving TopKBehavior and BehaviorSignal.


# ── Phase 4: Save & return (once per request) ──
final_step = step_metrics[-1]
Expand Down Expand Up @@ -297,9 +325,10 @@ async def distill(
"lr": lr,
"loss_fn": "importance_sampling",
"timestamp": datetime.now(timezone.utc).isoformat(),
"teacher_scored_texts": [p["teacher_scored_text"] for p in prepared_samples],
"teacher_scored_texts": [p.teacher_scored_text for p in prepared_samples],
"steps_per_batch_applied": steps_per_batch,
"per_step_metrics": step_metrics,
"divergence_mode": "topk_gjs" if use_topk else "scalar_kl",
}

if final_fwd_metrics is not None:
Expand All @@ -308,14 +337,9 @@ async def distill(
return DistillResponse(lora_id=payload.lora_id, metadata=metadata)


class PreparedSample(TypedDict):
full_tokens: list[int]
input_tokens: list[int]
target_tokens: list[int]
prompt_len: int
completion_len: int
teacher_logprobs: list[float]
teacher_scored_text: str
# ------------------------------------------------------------------
# Sample preparation
# ------------------------------------------------------------------


async def _prepare_sample_inputs(
Expand All @@ -324,8 +348,10 @@ async def _prepare_sample_inputs(
tokenizer: Any,
teacher_sampling: Any,
feedback_repetitions: int,
) -> tuple[PreparedSample, list[float]]:
"""Prepare sample-invariant tensors and initial behavior logprobs."""
use_topk: bool = False,
teacher_top_k: int = 100,
) -> tuple[PreparedSample, BehaviorSignal]:
"""Prepare sample-invariant tensors and initial behavior signal."""
prompt_tokens = list(sample.prompt_token_ids)
response_tokens = list(sample.response_token_ids)
completion_len = len(response_tokens)
Expand Down Expand Up @@ -361,61 +387,125 @@ async def _prepare_sample_inputs(
teacher_full = T.ModelInput.from_ints(teacher_full_tokens)
teacher_scored_text = tokenizer.decode(teacher_full_tokens, skip_special_tokens=False)

teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full)
teacher_logprobs = _slice_completion_logprobs(
teacher_logprobs_full,
teacher_prompt_len,
completion_len,
)
initial_behavior = ScalarBehavior(logprobs=list(sample.response_logprobs))

if use_topk:
# Top-K path: fetch teacher top-K distributions via sample_async
response = await teacher_sampling.sample_async(
prompt=teacher_full,
num_samples=1,
sampling_params=T.SamplingParams(max_tokens=1),
include_prompt_logprobs=True,
topk_prompt_logprobs=teacher_top_k,
)
topk_full = response.topk_prompt_logprobs
teacher_topk = slice_completion_topk(topk_full, teacher_prompt_len, completion_len)
teacher_logprobs = extract_token_logprobs(
teacher_topk, response_tokens,
)

prepared: PreparedSample = TopKPreparedSample(
full_tokens=full_tokens,
input_tokens=input_tokens,
target_tokens=target_tokens,
prompt_len=prompt_len,
completion_len=completion_len,
teacher_scored_text=teacher_scored_text,
teacher_logprobs=teacher_logprobs,
teacher_topk=teacher_topk,
)
else:
# Scalar path: fetch teacher logprobs via compute_logprobs_async
teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full)
teacher_logprobs = _slice_completion_logprobs(
teacher_logprobs_full,
teacher_prompt_len,
completion_len,
)

prepared = ScalarPreparedSample(
full_tokens=full_tokens,
input_tokens=input_tokens,
target_tokens=target_tokens,
prompt_len=prompt_len,
completion_len=completion_len,
teacher_scored_text=teacher_scored_text,
teacher_logprobs=teacher_logprobs,
)

return prepared, initial_behavior

prepared = PreparedSample(
full_tokens=full_tokens,
input_tokens=input_tokens,
target_tokens=target_tokens,
prompt_len=prompt_len,
completion_len=completion_len,
teacher_logprobs=teacher_logprobs,
teacher_scored_text=teacher_scored_text,
)
return prepared, list(sample.response_logprobs)

# ------------------------------------------------------------------
# Datum construction
# ------------------------------------------------------------------


def _build_sample_datum(
*,
prepared: PreparedSample,
student_logprobs: list[float],
behavior: BehaviorSignal,
kl_coef: float,
alpha: float,
) -> tuple[T.Datum, dict[str, float]]:
"""Build a Tinker datum from prepared teacher signals + current behavior policy."""
completion_len = prepared["completion_len"]
completion_len = prepared.completion_len
student_logprobs = behavior.logprobs
if len(student_logprobs) != completion_len:
raise ValueError(
f"student_logprobs length ({len(student_logprobs)}) != "
f"completion_len ({completion_len})"
)

teacher_logprobs = prepared["teacher_logprobs"]
raw_kl_deltas = [t - s for s, t in zip(student_logprobs, teacher_logprobs, strict=True)]
adv_abs_mean_raw = sum(abs(d) for d in raw_kl_deltas) / max(len(raw_kl_deltas), 1)
if isinstance(prepared, TopKPreparedSample) and isinstance(behavior, TopKBehavior):
# Top-K GJS path
raw_advantages = compute_topk_gjs(
prepared.teacher_topk,
behavior.topk,
alpha,
)
elif isinstance(prepared, ScalarPreparedSample) and isinstance(behavior, ScalarBehavior):
# Scalar KL path (current)
raw_advantages = [
t - s
for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True)
]
elif isinstance(prepared, TopKPreparedSample) and isinstance(behavior, ScalarBehavior):
# Initial step in top-K mode: behavior is still scalar from rollout cache
# Fall back to scalar KL using teacher_logprobs derived from top-K
raw_advantages = [
t - s
for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True)
Comment on lines +473 to +478
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Apply GJS advantages on the first top-K distillation step

This fallback makes use_topk_divergence=True run scalar teacher_logprob - student_logprob advantages whenever behavior is still ScalarBehavior, which is always true for step 1; with steps_per_batch=1, the whole update never uses GJS even though metadata reports "divergence_mode": "topk_gjs". That means experiments configured for top-K divergence silently train with the old objective in a common single-step setting.

Useful? React with 👍 / 👎.

]
else:
raise TypeError(
f"Incompatible prepared/behavior types: "
f"{type(prepared).__name__} + {type(behavior).__name__}"
)

adv_abs_mean_raw = sum(abs(a) for a in raw_advantages) / max(len(raw_advantages), 1)

gain = 1.0
if adv_abs_mean_raw > 0:
gain = min(max(_TARGET_ADV_ABS_MEAN / adv_abs_mean_raw, 1.0), _MAX_KL_GAIN)
effective_kl_coef = kl_coef * gain
advantages = [
effective_kl_coef * (t - s)
for s, t in zip(student_logprobs, teacher_logprobs, strict=True)
advantages = [effective_kl_coef * a for a in raw_advantages]

# Also compute scalar KL for metrics (always available via teacher_logprobs)
raw_kl_deltas = [
t - s
for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True)
]

full_logprobs = [0.0] * prepared["prompt_len"] + student_logprobs
full_advantages = [0.0] * prepared["prompt_len"] + advantages
full_logprobs = [0.0] * prepared.prompt_len + student_logprobs
full_advantages = [0.0] * prepared.prompt_len + advantages
shifted_logprobs = full_logprobs[1:]
shifted_advantages = full_advantages[1:]

datum = T.Datum(
model_input=T.ModelInput.from_ints(prepared["input_tokens"]),
model_input=T.ModelInput.from_ints(prepared.input_tokens),
loss_fn_inputs={
"target_tokens": TensorData(data=prepared["target_tokens"], dtype="int64"),
"target_tokens": TensorData(data=prepared.target_tokens, dtype="int64"),
"logprobs": TensorData(data=shifted_logprobs, dtype="float32"),
"advantages": TensorData(data=shifted_advantages, dtype="float32"),
},
Expand All @@ -433,6 +523,11 @@ def _build_sample_datum(
return datum, metrics


# ------------------------------------------------------------------
# Student behavior recomputation
# ------------------------------------------------------------------


async def _compute_student_logprobs_for_batch(
*,
student_sampling: Any,
Expand All @@ -455,13 +550,60 @@ async def _compute_student_logprobs_for_sample(
prepared: PreparedSample,
) -> list[float]:
"""Compute completion logprobs for one sample under current student weights."""
student_full = T.ModelInput.from_ints(prepared["full_tokens"])
student_full = T.ModelInput.from_ints(prepared.full_tokens)
student_logprobs_full = await student_sampling.compute_logprobs_async(student_full)
return _slice_completion_logprobs(
student_logprobs_full,
prepared["prompt_len"],
prepared["completion_len"],
prepared.prompt_len,
prepared.completion_len,
)


async def _compute_student_topk_for_batch(
    *,
    student_sampling: Any,
    prepared_samples: list[PreparedSample],
    top_k: int,
) -> list[BehaviorSignal]:
    """Recompute student top-K behavior under the updated student policy.

    Args:
        student_sampling: Sampling client bound to the current student weights.
        prepared_samples: Per-sample tensors prepared once per batch.
        top_k: Number of top logprobs to request per completion position.

    Returns:
        One ``TopKBehavior`` per prepared sample, in input order. Annotated as
        ``list[BehaviorSignal]`` (not ``list[TopKBehavior]``) because ``list``
        is invariant in Python's type system and the caller assigns the result
        to a ``list[BehaviorSignal]`` variable.
    """
    tasks = [
        _compute_student_topk_for_sample(
            student_sampling=student_sampling,
            prepared=prepared,
            top_k=top_k,
        )
        for prepared in prepared_samples
    ]
    # asyncio.gather preserves argument order, so results align with samples.
    return await asyncio.gather(*tasks)


async def _compute_student_topk_for_sample(
    *,
    student_sampling: Any,
    prepared: PreparedSample,
    top_k: int,
) -> TopKBehavior:
    """Compute student top-K distributions and exact logprobs for one sample.

    Args:
        student_sampling: Sampling client bound to the current student weights.
        prepared: Sample-invariant tensors prepared once per batch.
        top_k: Number of top logprobs to request per completion position.

    Returns:
        A ``TopKBehavior`` holding the completion-sliced top-K distributions
        (for the GJS divergence) and exact per-token scalar logprobs (for the
        importance-sampling ratios).
    """
    student_full = T.ModelInput.from_ints(prepared.full_tokens)
    response = await student_sampling.sample_async(
        prompt=student_full,
        num_samples=1,
        # max_tokens=1: we only need prompt logprobs, not a real generation.
        sampling_params=T.SamplingParams(max_tokens=1),
        include_prompt_logprobs=True,
        topk_prompt_logprobs=top_k,
    )
    topk_full = response.topk_prompt_logprobs
    student_topk = slice_completion_topk(topk_full, prepared.prompt_len, prepared.completion_len)

    # Do NOT reconstruct scalar logprobs from the top-K sets: when an updated
    # student ranking drops a rollout token out of top-K, extract_token_logprobs
    # floors its logprob to a sentinel regardless of the true probability,
    # which skews importance-sampling ratios (especially for small top_k).
    # Query the exact per-token logprobs instead and keep top-K only for GJS.
    student_logprobs_full = await student_sampling.compute_logprobs_async(student_full)
    scalar_logprobs = _slice_completion_logprobs(
        student_logprobs_full,
        prepared.prompt_len,
        prepared.completion_len,
    )
    return TopKBehavior(logprobs=scalar_logprobs, topk=student_topk)


# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------


def _require_entry(lora_id: str, state_path: str) -> LoraEntry:
Expand Down
Loading
Loading