Offline knowledge distillation from cached teacher top-20 logprobs into a student model. Implements the design in [trainer.md](./trainer.md).
Five stages, three commands:
result/*.json ──prepare.py──▶ packed/ ──train.py──▶ distilled_student/
└─eval.py (called during training)
The teacher signal (tokens, logprobs) is in Kimi-K2 token ids.
Token-aligned distillation is only valid if the student shares the teacher's tokenizer/vocab (~163,840 tokens).
prepare.py enforces this with a vocab-size + decode(tokens)==text round-trip check and aborts on mismatch.
Point --tokenizer-name at the teacher tokenizer and use a student with that vocab. gpt2 (the placeholder default) will fail the check — that is the check working.
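The two-part check described above (vocab size plus decode round-trip) can be sketched as follows. This is a minimal illustration, not the code in prepare.py: `check_tokenizer` and `ToyTokenizer` are hypothetical names, and the toy class only stands in for a real tokenizer. Hugging Face tokenizers expose the same two operations used here, `len(tokenizer)` and `tokenizer.decode(ids)`.

```python
class ToyTokenizer:
    """Hypothetical stand-in exposing only what the check needs."""

    def __init__(self, vocab):
        self.vocab = list(vocab)

    def __len__(self):
        return len(self.vocab)

    def decode(self, token_ids):
        return "".join(self.vocab[i] for i in token_ids)


def check_tokenizer(tokenizer, expected_vocab_size, token_ids, text):
    """Raise ValueError unless the tokenizer matches the cached teacher data."""
    if len(tokenizer) != expected_vocab_size:
        raise ValueError(
            f"vocab size {len(tokenizer)} != expected {expected_vocab_size}"
        )
    decoded = tokenizer.decode(token_ids)
    if decoded != text:
        raise ValueError(f"round-trip mismatch: {decoded!r} != {text!r}")


# A tokenizer that matches the cached sample passes silently.
teacher_like = ToyTokenizer(["he", "llo", " ", "world"])
check_tokenizer(teacher_like, 4, [0, 1, 2, 3], "hello world")

# A tokenizer with a different vocab is rejected before any training happens.
wrong = ToyTokenizer(["he", "llo", " "])
try:
    check_tokenizer(wrong, 4, [0, 1, 2], "hello ")
except ValueError as exc:
    mismatch_message = str(exc)
```

Failing fast here is deliberate: with a mismatched vocab the cached token ids would index unrelated student tokens, and training would run without error while learning nothing useful.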
python prepare.py \
--student-model <hf-repo-with-kimi-vocab> \
--result-glob './result/*.json' \
--packed-dir ./packed

Pools every axis, drops gold/valid, validates token↔logprob alignment (items like code.json whose arrays don't line up are dropped), tokenizes prompts, and writes packed/shard_*.pkl + meta.json.
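The alignment validation might look roughly like the sketch below. The item schema is an assumption made for illustration: the field names `tokens` and `top_logprobs`, and the idea that each position carries a list of `(token_id, logprob)` pairs, are hypothetical and may not match the real result/*.json layout.

```python
def is_aligned(item, k=20):
    """Return True if every token has a matching, non-empty top-k row.

    Assumed (hypothetical) schema:
      item["tokens"]       -> list of token ids, one per position
      item["top_logprobs"] -> list of rows, each a list of (token_id, logprob)
    """
    tokens = item.get("tokens")
    rows = item.get("top_logprobs")
    if not tokens or rows is None:
        return False
    if len(tokens) != len(rows):
        return False  # arrays don't line up, so the item is unusable
    return all(0 < len(row) <= k for row in rows)


good = {"tokens": [5, 9], "top_logprobs": [[(5, -0.1)], [(9, -0.3), (2, -1.5)]]}
bad = {"tokens": [5, 9, 7], "top_logprobs": [[(5, -0.1)], [(9, -0.3)]]}

# Keep only items whose arrays line up; the rest are dropped.
kept = [item for item in (good, bad) if is_aligned(item)]
```

Dropping misaligned items rather than truncating them avoids pairing a token with the distribution of a neighboring position.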
Quick smoke test without the real tokenizer:
python prepare.py --limit-items 20 --skip-checks

Targets are only meaningful with the correct vocab, so use this for plumbing only.
python train.py \
--student-model <hf-repo-with-kimi-vocab> \
--packed-dir ./packed \
--output-dir ./distilled_student \
--temperature 2.0 --alpha 0.5 \
--batch-size 4 --grad-accum-steps 4 --num-epochs 3 --learning-rate 2e-5 \
--bf16

Loss = alpha · (T² · forward-KL over teacher top-20) + (1 - alpha) · CE(teacher token) (loss.py).
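For a single position, the loss formula can be worked through in plain Python. This is a sketch under stated assumptions, not the code in loss.py: it assumes the teacher's cached top-k logprobs are temperature-scaled and renormalized over the top-k set, that the student's log-softmax is taken over the full vocab at the same temperature, and that the CE term uses temperature 1. `sketch_topk_distill_loss` is a hypothetical name.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax over a list of floats."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def sketch_topk_distill_loss(student_logits, topk_ids, topk_logprobs,
                             target_id, temperature=2.0, alpha=0.5):
    # Assumption: teacher mass is renormalized over its cached top-k entries.
    teacher_logp = log_softmax(topk_logprobs, temperature)
    # Student distribution over the full vocab at the same temperature.
    student_logp_t = log_softmax(student_logits, temperature)
    # Forward KL(teacher || student), summed over the teacher's top-k ids only.
    kl = sum(math.exp(tl) * (tl - student_logp_t[i])
             for i, tl in zip(topk_ids, teacher_logp))
    # Hard cross-entropy on the teacher's token at temperature 1.
    ce = -log_softmax(student_logits)[target_id]
    return alpha * temperature ** 2 * kl + (1 - alpha) * ce


teacher_logprobs = [math.log(0.5), math.log(0.3), math.log(0.2)]

# Student identical to the teacher: the KL term vanishes and only CE remains.
loss_match = sketch_topk_distill_loss(teacher_logprobs, [0, 1, 2],
                                      teacher_logprobs, target_id=0)

# Student that prefers a different token: both terms grow.
loss_off = sketch_topk_distill_loss([0.0, 0.0, 5.0], [0, 1, 2],
                                    teacher_logprobs, target_id=0)
```

The T² factor is the usual correction that keeps the gradient magnitude of the softened KL term comparable to the CE term as the temperature changes.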
Memory is dominated by the [B, S, 163k] logits tensor (batch × sequence × vocab), so keep --batch-size small and lean on --grad-accum-steps; add --gradient-checkpointing if needed.
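A quick back-of-the-envelope estimate shows why the logits tensor dominates. The sequence length of 2048 is an assumption chosen for illustration; the batch size and accumulation steps are the values from the command above, and `logits_bytes` is a hypothetical helper.

```python
def logits_bytes(batch_size, seq_len, vocab_size=163_840, bytes_per_element=2):
    """Size in bytes of one [B, S, V] logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_element


GIB = 1024 ** 3

# --batch-size 4, assumed sequence length 2048, bf16 (2 bytes per element).
bf16_gib = logits_bytes(4, 2048) / GIB  # 2.5

# The same tensor upcast to float32 (4 bytes per element).
fp32_gib = logits_bytes(4, 2048, bytes_per_element=4) / GIB  # 5.0

# Gradient accumulation raises the effective batch without growing this tensor.
effective_batch = 4 * 4  # --batch-size × --grad-accum-steps = 16
```

These figures cover a single copy of the logits only; the backward pass and any intermediate softmax buffers add to them, which is why trading batch size for accumulation steps is usually the first lever to pull.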
Evaluation runs automatically every --eval-steps steps and at each epoch end. It reports the subnet-style signal rather than perplexity: teacher-forced forward KL, top-1 agreement, and overlap@k.
To score a saved checkpoint, call evaluate() from eval.py standalone.
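The two agreement metrics can be illustrated in plain Python. The definitions below are assumptions, since the exact ones live in eval.py: top-1 agreement is taken as the fraction of positions where the student's argmax equals the teacher's top token, and overlap@k as the share of the student's top-k ids that also appear in the teacher's top-k. All function names are hypothetical.

```python
def top_k_ids(scores, k):
    """Indices of the k largest scores, highest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def top1_agreement(student_logits_per_pos, teacher_top1_ids):
    """Fraction of positions where the student argmax matches the teacher."""
    hits = sum(
        top_k_ids(logits, 1)[0] == teacher_id
        for logits, teacher_id in zip(student_logits_per_pos, teacher_top1_ids)
    )
    return hits / len(teacher_top1_ids)


def overlap_at_k(student_logits, teacher_topk_ids, k):
    """Share of the student's top-k ids also present in the teacher's top-k."""
    student_top = set(top_k_ids(student_logits, k))
    return len(student_top & set(teacher_topk_ids[:k])) / k


# Two positions over a 2-token vocab: the student agrees on the first only.
agreement = top1_agreement([[0.1, 2.0], [3.0, 1.0]], [1, 1])

# Student top-2 is {1, 2}; teacher top-2 is {1, 3}; one id in common.
overlap = overlap_at_k([0.1, 2.0, 1.5, -1.0], [1, 3], k=2)
```

Both metrics compare rankings rather than probabilities, so they complement the forward KL: a student can match the teacher's top choices while still being poorly calibrated.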
| file | role |
|---|---|
| config.py | one flat DistillConfig; every field is a CLI flag |
| prepare.py | result/*.json → validated packed shards (torch-free) |
| data.py | PackedDistillDataset + DistillCollator (sequence-aligned batches) |
| loss.py | topk_distill_loss — sparse top-k forward KL + hard CE |
| train.py | training loop (AdamW + warmup, grad-accum, save) |
| eval.py | evaluate — KL / top-1 / overlap on held-out data |