-
Notifications
You must be signed in to change notification settings - Fork 3
Add top-K GJS divergence mode to Tinker SDPO engine #41
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,7 @@ | |
| import logging | ||
| import os | ||
| from datetime import datetime, timezone | ||
| from typing import Any, TypedDict | ||
| from typing import Any | ||
| from urllib.parse import quote | ||
|
|
||
| import httpx | ||
|
|
@@ -46,6 +46,19 @@ | |
| lora_exists as state_lora_exists, | ||
| set_tinker_path, | ||
| ) | ||
| from claas.training.engine.tinker.types import ( | ||
| BehaviorSignal, | ||
| PreparedSample, | ||
| ScalarBehavior, | ||
| ScalarPreparedSample, | ||
| TopKBehavior, | ||
| TopKPreparedSample, | ||
| ) | ||
| from claas.training.gjs import ( | ||
| compute_topk_gjs, | ||
| extract_token_logprobs, | ||
| slice_completion_topk, | ||
| ) | ||
| from claas.training.teacher_helpers import build_teacher_messages, teacher_messages_to_chat_template | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -187,6 +200,9 @@ async def distill( | |
| kl_coef = payload.training.alpha | ||
| steps_per_batch = payload.training.steps_per_batch | ||
| feedback_repetitions = payload.training.feedback_repetitions | ||
| use_topk = payload.training.use_topk_divergence | ||
| teacher_top_k = payload.training.teacher_top_k | ||
| alpha = payload.training.alpha | ||
|
|
||
| # ── Phase 1: Setup (once per batch) ── | ||
| training_client = await self.service.create_training_client_from_state_async( | ||
|
|
@@ -205,12 +221,14 @@ async def distill( | |
| tokenizer=tokenizer, | ||
| teacher_sampling=teacher_sampling, | ||
| feedback_repetitions=feedback_repetitions, | ||
| use_topk=use_topk, | ||
| teacher_top_k=teacher_top_k, | ||
| ) | ||
| for sample in payload.samples | ||
| ] | ||
| results = await asyncio.gather(*tasks) | ||
| prepared_samples = [r[0] for r in results] | ||
| behavior_logprobs = [r[1] for r in results] | ||
| prepared_samples: list[PreparedSample] = [r[0] for r in results] | ||
| behavior_signals: list[BehaviorSignal] = [r[1] for r in results] | ||
|
|
||
| # ── Phase 3: Multi-step training ── | ||
| step_metrics: list[dict[str, float | int]] = [] | ||
|
|
@@ -219,8 +237,9 @@ async def distill( | |
| datum_metrics = [ | ||
| _build_sample_datum( | ||
| prepared=prepared, | ||
| student_logprobs=behavior_logprobs[sample_idx], | ||
| behavior=behavior_signals[sample_idx], | ||
| kl_coef=kl_coef, | ||
| alpha=alpha, | ||
| ) | ||
| for sample_idx, prepared in enumerate(prepared_samples) | ||
| ] | ||
|
|
@@ -256,10 +275,19 @@ async def distill( | |
|
|
||
| if step_idx < steps_per_batch - 1: | ||
| student_sampling = await training_client.save_weights_and_get_sampling_client_async() | ||
| behavior_logprobs = await _compute_student_logprobs_for_batch( | ||
| student_sampling=student_sampling, | ||
| prepared_samples=prepared_samples, | ||
| ) | ||
| if use_topk: | ||
| topk_behaviors = await _compute_student_topk_for_batch( | ||
| student_sampling=student_sampling, | ||
| prepared_samples=prepared_samples, | ||
| top_k=teacher_top_k, | ||
| ) | ||
| behavior_signals = list[BehaviorSignal](topk_behaviors) | ||
| else: | ||
| scalar_logprobs = await _compute_student_logprobs_for_batch( | ||
| student_sampling=student_sampling, | ||
| prepared_samples=prepared_samples, | ||
| ) | ||
| behavior_signals = [ScalarBehavior(logprobs=lps) for lps in scalar_logprobs] | ||
|
Comment on lines
+278
to
+290
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Type annotation mismatch causes a static analysis failure. The pipeline reports an `invalid-assignment` error: `_compute_student_topk_for_batch` is annotated as returning `list[TopKBehavior]`, which is not assignable to the variable annotated `list[BehaviorSignal]` (mutable `list` is invariant in its element type). The fix is to widen the return type annotation of `_compute_student_topk_for_batch`. 🔧 Proposed fix: async def _compute_student_topk_for_batch(
*,
student_sampling: Any,
prepared_samples: list[PreparedSample],
top_k: int,
-) -> list[TopKBehavior]:
+) -> list[BehaviorSignal]:🧰 Tools🪛 GitHub Actions: CI[error] 326-331: invalid-assignment: Object of type 🤖 Prompt for AI Agents |
||
|
|
||
| # ── Phase 4: Save & return (once per request) ── | ||
| final_step = step_metrics[-1] | ||
|
|
@@ -297,9 +325,10 @@ async def distill( | |
| "lr": lr, | ||
| "loss_fn": "importance_sampling", | ||
| "timestamp": datetime.now(timezone.utc).isoformat(), | ||
| "teacher_scored_texts": [p["teacher_scored_text"] for p in prepared_samples], | ||
| "teacher_scored_texts": [p.teacher_scored_text for p in prepared_samples], | ||
| "steps_per_batch_applied": steps_per_batch, | ||
| "per_step_metrics": step_metrics, | ||
| "divergence_mode": "topk_gjs" if use_topk else "scalar_kl", | ||
| } | ||
|
|
||
| if final_fwd_metrics is not None: | ||
|
|
@@ -308,14 +337,9 @@ async def distill( | |
| return DistillResponse(lora_id=payload.lora_id, metadata=metadata) | ||
|
|
||
|
|
||
| class PreparedSample(TypedDict): | ||
| full_tokens: list[int] | ||
| input_tokens: list[int] | ||
| target_tokens: list[int] | ||
| prompt_len: int | ||
| completion_len: int | ||
| teacher_logprobs: list[float] | ||
| teacher_scored_text: str | ||
| # ------------------------------------------------------------------ | ||
| # Sample preparation | ||
| # ------------------------------------------------------------------ | ||
|
|
||
|
|
||
| async def _prepare_sample_inputs( | ||
|
|
@@ -324,8 +348,10 @@ async def _prepare_sample_inputs( | |
| tokenizer: Any, | ||
| teacher_sampling: Any, | ||
| feedback_repetitions: int, | ||
| ) -> tuple[PreparedSample, list[float]]: | ||
| """Prepare sample-invariant tensors and initial behavior logprobs.""" | ||
| use_topk: bool = False, | ||
| teacher_top_k: int = 100, | ||
| ) -> tuple[PreparedSample, BehaviorSignal]: | ||
| """Prepare sample-invariant tensors and initial behavior signal.""" | ||
| prompt_tokens = list(sample.prompt_token_ids) | ||
| response_tokens = list(sample.response_token_ids) | ||
| completion_len = len(response_tokens) | ||
|
|
@@ -361,61 +387,125 @@ async def _prepare_sample_inputs( | |
| teacher_full = T.ModelInput.from_ints(teacher_full_tokens) | ||
| teacher_scored_text = tokenizer.decode(teacher_full_tokens, skip_special_tokens=False) | ||
|
|
||
| teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full) | ||
| teacher_logprobs = _slice_completion_logprobs( | ||
| teacher_logprobs_full, | ||
| teacher_prompt_len, | ||
| completion_len, | ||
| ) | ||
| initial_behavior = ScalarBehavior(logprobs=list(sample.response_logprobs)) | ||
|
|
||
| if use_topk: | ||
| # Top-K path: fetch teacher top-K distributions via sample_async | ||
| response = await teacher_sampling.sample_async( | ||
| prompt=teacher_full, | ||
| num_samples=1, | ||
| sampling_params=T.SamplingParams(max_tokens=1), | ||
| include_prompt_logprobs=True, | ||
| topk_prompt_logprobs=teacher_top_k, | ||
| ) | ||
| topk_full = response.topk_prompt_logprobs | ||
| teacher_topk = slice_completion_topk(topk_full, teacher_prompt_len, completion_len) | ||
| teacher_logprobs = extract_token_logprobs( | ||
| teacher_topk, response_tokens, | ||
| ) | ||
|
|
||
| prepared: PreparedSample = TopKPreparedSample( | ||
| full_tokens=full_tokens, | ||
| input_tokens=input_tokens, | ||
| target_tokens=target_tokens, | ||
| prompt_len=prompt_len, | ||
| completion_len=completion_len, | ||
| teacher_scored_text=teacher_scored_text, | ||
| teacher_logprobs=teacher_logprobs, | ||
| teacher_topk=teacher_topk, | ||
| ) | ||
| else: | ||
| # Scalar path: fetch teacher logprobs via compute_logprobs_async | ||
| teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full) | ||
| teacher_logprobs = _slice_completion_logprobs( | ||
| teacher_logprobs_full, | ||
| teacher_prompt_len, | ||
| completion_len, | ||
| ) | ||
|
|
||
| prepared = ScalarPreparedSample( | ||
| full_tokens=full_tokens, | ||
| input_tokens=input_tokens, | ||
| target_tokens=target_tokens, | ||
| prompt_len=prompt_len, | ||
| completion_len=completion_len, | ||
| teacher_scored_text=teacher_scored_text, | ||
| teacher_logprobs=teacher_logprobs, | ||
| ) | ||
|
|
||
| return prepared, initial_behavior | ||
|
|
||
| prepared = PreparedSample( | ||
| full_tokens=full_tokens, | ||
| input_tokens=input_tokens, | ||
| target_tokens=target_tokens, | ||
| prompt_len=prompt_len, | ||
| completion_len=completion_len, | ||
| teacher_logprobs=teacher_logprobs, | ||
| teacher_scored_text=teacher_scored_text, | ||
| ) | ||
| return prepared, list(sample.response_logprobs) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Datum construction | ||
| # ------------------------------------------------------------------ | ||
|
|
||
|
|
||
| def _build_sample_datum( | ||
| *, | ||
| prepared: PreparedSample, | ||
| student_logprobs: list[float], | ||
| behavior: BehaviorSignal, | ||
| kl_coef: float, | ||
| alpha: float, | ||
| ) -> tuple[T.Datum, dict[str, float]]: | ||
| """Build a Tinker datum from prepared teacher signals + current behavior policy.""" | ||
| completion_len = prepared["completion_len"] | ||
| completion_len = prepared.completion_len | ||
| student_logprobs = behavior.logprobs | ||
| if len(student_logprobs) != completion_len: | ||
| raise ValueError( | ||
| f"student_logprobs length ({len(student_logprobs)}) != " | ||
| f"completion_len ({completion_len})" | ||
| ) | ||
|
|
||
| teacher_logprobs = prepared["teacher_logprobs"] | ||
| raw_kl_deltas = [t - s for s, t in zip(student_logprobs, teacher_logprobs, strict=True)] | ||
| adv_abs_mean_raw = sum(abs(d) for d in raw_kl_deltas) / max(len(raw_kl_deltas), 1) | ||
| if isinstance(prepared, TopKPreparedSample) and isinstance(behavior, TopKBehavior): | ||
| # Top-K GJS path | ||
| raw_advantages = compute_topk_gjs( | ||
| prepared.teacher_topk, | ||
| behavior.topk, | ||
| alpha, | ||
| ) | ||
| elif isinstance(prepared, ScalarPreparedSample) and isinstance(behavior, ScalarBehavior): | ||
| # Scalar KL path (current) | ||
| raw_advantages = [ | ||
| t - s | ||
| for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True) | ||
| ] | ||
| elif isinstance(prepared, TopKPreparedSample) and isinstance(behavior, ScalarBehavior): | ||
| # Initial step in top-K mode: behavior is still scalar from rollout cache | ||
| # Fall back to scalar KL using teacher_logprobs derived from top-K | ||
| raw_advantages = [ | ||
| t - s | ||
| for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True) | ||
|
Comment on lines
+473
to
+478
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This fallback makes [comment truncated in extraction — presumably: the first training step in top-K mode compute scalar-KL advantages instead of top-K GJS, since the behavior signal from the rollout cache is still scalar, so step 0 optimizes a different objective than subsequent steps; confirm against the original review thread]. Useful? React with 👍 / 👎. |
||
| ] | ||
| else: | ||
| raise TypeError( | ||
| f"Incompatible prepared/behavior types: " | ||
| f"{type(prepared).__name__} + {type(behavior).__name__}" | ||
| ) | ||
|
|
||
| adv_abs_mean_raw = sum(abs(a) for a in raw_advantages) / max(len(raw_advantages), 1) | ||
|
|
||
| gain = 1.0 | ||
| if adv_abs_mean_raw > 0: | ||
| gain = min(max(_TARGET_ADV_ABS_MEAN / adv_abs_mean_raw, 1.0), _MAX_KL_GAIN) | ||
| effective_kl_coef = kl_coef * gain | ||
| advantages = [ | ||
| effective_kl_coef * (t - s) | ||
| for s, t in zip(student_logprobs, teacher_logprobs, strict=True) | ||
| advantages = [effective_kl_coef * a for a in raw_advantages] | ||
|
|
||
| # Also compute scalar KL for metrics (always available via teacher_logprobs) | ||
| raw_kl_deltas = [ | ||
| t - s | ||
| for s, t in zip(student_logprobs, prepared.teacher_logprobs, strict=True) | ||
| ] | ||
|
|
||
| full_logprobs = [0.0] * prepared["prompt_len"] + student_logprobs | ||
| full_advantages = [0.0] * prepared["prompt_len"] + advantages | ||
| full_logprobs = [0.0] * prepared.prompt_len + student_logprobs | ||
| full_advantages = [0.0] * prepared.prompt_len + advantages | ||
| shifted_logprobs = full_logprobs[1:] | ||
| shifted_advantages = full_advantages[1:] | ||
|
|
||
| datum = T.Datum( | ||
| model_input=T.ModelInput.from_ints(prepared["input_tokens"]), | ||
| model_input=T.ModelInput.from_ints(prepared.input_tokens), | ||
| loss_fn_inputs={ | ||
| "target_tokens": TensorData(data=prepared["target_tokens"], dtype="int64"), | ||
| "target_tokens": TensorData(data=prepared.target_tokens, dtype="int64"), | ||
| "logprobs": TensorData(data=shifted_logprobs, dtype="float32"), | ||
| "advantages": TensorData(data=shifted_advantages, dtype="float32"), | ||
| }, | ||
|
|
@@ -433,6 +523,11 @@ def _build_sample_datum( | |
| return datum, metrics | ||
|
|
||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Student behavior recomputation | ||
| # ------------------------------------------------------------------ | ||
|
|
||
|
|
||
| async def _compute_student_logprobs_for_batch( | ||
| *, | ||
| student_sampling: Any, | ||
|
|
@@ -455,13 +550,60 @@ async def _compute_student_logprobs_for_sample( | |
| prepared: PreparedSample, | ||
| ) -> list[float]: | ||
| """Compute completion logprobs for one sample under current student weights.""" | ||
| student_full = T.ModelInput.from_ints(prepared["full_tokens"]) | ||
| student_full = T.ModelInput.from_ints(prepared.full_tokens) | ||
| student_logprobs_full = await student_sampling.compute_logprobs_async(student_full) | ||
| return _slice_completion_logprobs( | ||
| student_logprobs_full, | ||
| prepared["prompt_len"], | ||
| prepared["completion_len"], | ||
| prepared.prompt_len, | ||
| prepared.completion_len, | ||
| ) | ||
|
|
||
|
|
||
| async def _compute_student_topk_for_batch( | ||
| *, | ||
| student_sampling: Any, | ||
| prepared_samples: list[PreparedSample], | ||
| top_k: int, | ||
| ) -> list[TopKBehavior]: | ||
| """Recompute student top-K behavior under the updated student policy.""" | ||
| tasks = [ | ||
| _compute_student_topk_for_sample( | ||
| student_sampling=student_sampling, | ||
| prepared=prepared, | ||
| top_k=top_k, | ||
| ) | ||
| for prepared in prepared_samples | ||
| ] | ||
| return await asyncio.gather(*tasks) | ||
|
|
||
|
|
||
| async def _compute_student_topk_for_sample( | ||
| *, | ||
| student_sampling: Any, | ||
| prepared: PreparedSample, | ||
| top_k: int, | ||
| ) -> TopKBehavior: | ||
| """Compute student top-K distributions for one sample.""" | ||
| student_full = T.ModelInput.from_ints(prepared.full_tokens) | ||
| response = await student_sampling.sample_async( | ||
| prompt=student_full, | ||
| num_samples=1, | ||
| sampling_params=T.SamplingParams(max_tokens=1), | ||
| include_prompt_logprobs=True, | ||
| topk_prompt_logprobs=top_k, | ||
| ) | ||
| topk_full = response.topk_prompt_logprobs | ||
| student_topk = slice_completion_topk(topk_full, prepared.prompt_len, prepared.completion_len) | ||
|
|
||
| # Extract response token IDs from full_tokens | ||
| response_tokens = prepared.full_tokens[prepared.prompt_len:] | ||
| scalar_logprobs = extract_token_logprobs(student_topk, response_tokens) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Here the student logprobs used downstream in [comment truncated in extraction — presumably: `_build_sample_datum` are derived from the top-K distributions via `extract_token_logprobs`, which may differ from exact per-token logprobs when a response token falls outside the student's top-K; confirm against the original review thread]. Useful? React with 👍 / 👎. |
||
| return TopKBehavior(logprobs=scalar_logprobs, topk=student_topk) | ||
|
|
||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Utilities | ||
| # ------------------------------------------------------------------ | ||
|
|
||
|
|
||
| def _require_entry(lora_id: str, state_path: str) -> LoraEntry: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 1601
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 711
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 681
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 671
🏁 Script executed:
Repository: kfallah/CLaaS
Length of output: 1433
Runner.py does not extract `divergence_mode` from metadata.
The Tinker engine returns `divergence_mode` in its metadata response (either `"topk_gjs"` or `"scalar_kl"`), but the `TinkerDistillMetrics` instantiation at `claas/eval/runner.py:75-85` doesn't extract it. The field will silently default to `"scalar_kl"` even if the engine returns a different value. Update the instantiation to include:
🤖 Prompt for AI Agents