Restore BitNetPaperModel checkpoint API after TinyLlama constructor refactor #19
Conversation
Co-authored-by: sharpninja <16146732+sharpninja@users.noreply.github.com> Agent-Logs-Url: https://github.com/sharpninja/BitNet-b1.58-Sharp/sessions/fe7b737f-4e6f-4f06-9e06-6960eed93908
Co-authored-by: sharpninja <16146732+sharpninja@users.noreply.github.com> Agent-Logs-Url: https://github.com/sharpninja/BitNet-b1.58-Sharp/sessions/7f1facbe-4f33-4796-a8ca-9dd7d3c24fb7
Co-authored-by: sharpninja <16146732+sharpninja@users.noreply.github.com> Agent-Logs-Url: https://github.com/sharpninja/BitNet-b1.58-Sharp/sessions/6d9e3d8f-ed53-42db-9f91-c5127e65b2af
Pull request overview
This PR restores the BitNetPaperModel checkpoint/memory contract that was lost during the TinyLlama constructor refactor, and updates the benchmark + report pipeline to consistently use the shared TinyLlama benchmark corpus (including added perplexity/comparison reporting).
Changes:
- Restores/adjusts paper model training-corpus construction and reintroduces memorized-response export/import for checkpoint round-trips, plus adds a round-trip regression test.
- Switches benchmark training/report generation to use `BitNetTrainingCorpus.CreateBenchmarkExamples()` (TinyLlama-1.1B) and adds vocabulary/perplexity validation tests.
- Extends the benchmark report with WikiText2 perplexity + a BitNet-vs-traditional comparison summary (tables + inline HTML charts).
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs | Updates training expectations to benchmark corpus; adds perplexity + vocabulary assertions. |
| tests/BitNetSharp.Tests/HostedAgentBenchmarkReportRunnerTests.cs | Updates report runner test to cover new comparison/perplexity + training dataset rendering. |
| tests/BitNetSharp.Tests/BitNetPaperCheckpointTests.cs | Adds checkpoint round-trip test to ensure memorized prompt responses survive save/load. |
| src/BitNetSharp.Core/TraditionalLocalModel.cs | Adds corpus-backed constructor/factory and perplexity evaluation. |
| src/BitNetSharp.Core/BitNetTrainingCorpus.cs | Introduces TinyLlama benchmark examples/vocabulary and a regex-based vocabulary builder. |
| src/BitNetSharp.Core/BitNetPaperModel.cs | Adds corpus-backed constructor/factory, perplexity evaluation, and restores memorized-response export/import for checkpoints. |
| src/BitNetSharp.Core/BitNetPaperAudit.cs | Switches perplexity fixtures to shared BitNetBenchmarkFixtures. |
| src/BitNetSharp.Core/BitNetBootstrap.cs | Adds overload to create a paper model from a provided training corpus. |
| src/BitNetSharp.Core/BitNetBenchmarkFixtures.cs | Adds shared perplexity datasets/validation samples used across audit/reporting/tests. |
| src/BitNetSharp.App/TraditionalLocalHostedAgentModel.cs | Allows construction from a training corpus and surfaces the underlying model for reporting. |
| src/BitNetSharp.App/HostedAgentModelFactory.cs | Threads an optional training corpus into built-in model construction. |
| src/BitNetSharp.App/HostedAgentBenchmarks.cs | Updates training benchmark to train on TinyLlama benchmark examples and build models with that vocabulary. |
| src/BitNetSharp.App/HostedAgentBenchmarkReportRunner.cs | Generates report using benchmark corpus, adds perplexity + comparison summary + charts, and parses performance outputs for derived metrics. |
| docs/benchmarking.md | Updates docs to reflect TinyLlama training corpus and the new comparison/perplexity reporting. |
```csharp
public static IReadOnlyList<string> CreateBenchmarkVocabulary() =>
    CreateVocabulary(
        CreateBenchmarkExamples(),
        ["tinyllama", "1", "b", "-", ".", "wikitext2", "perplexity", "benchmark", "american", "english", "agent", "framework", "hosting", "tensor", "ternary"]);
```
CreateBenchmarkVocabulary includes an additional token "wikitext2", but the tokenizer/regex splits alphanumerics into separate [A-Za-z]+ and [0-9]+ matches, so the model will never emit a wikitext2 token during tokenization (it will be wikitext + 2). This makes the token effectively unreachable noise in the vocabulary. Consider removing it or replacing it with the tokens the tokenizer can actually produce.
Suggested change:
```diff
-        ["tinyllama", "1", "b", "-", ".", "wikitext2", "perplexity", "benchmark", "american", "english", "agent", "framework", "hosting", "tensor", "ternary"]);
+        ["tinyllama", "1", "b", "-", ".", "perplexity", "benchmark", "american", "english", "agent", "framework", "hosting", "tensor", "ternary"]);
```
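To illustrate the split described in this comment, a tokenizer that matches alphabetic and numeric runs as separate patterns (as the review says the regex-based vocabulary builder does) can never produce a combined alphanumeric token. A standalone sketch, not code from this PR:

```csharp
using System;
using System.Linq;
using System.Text.RegularExpressions;

class TokenSplitDemo
{
    static void Main()
    {
        // [A-Za-z]+ and [0-9]+ are matched separately, so a mixed token
        // like "wikitext2" is always split at the letter/digit boundary.
        var tokens = Regex.Matches("wikitext2 perplexity", "[A-Za-z]+|[0-9]+")
            .Select(m => m.Value)
            .ToArray();

        Console.WriteLine(string.Join(", ", tokens)); // wikitext, 2, perplexity
    }
}
```

Because "wikitext2" can never appear as a single match, the vocabulary entry is dead weight exactly as the comment states.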
```csharp
_model = TraditionalLocalModel.CreateDefault(verbosity);
_trainingCorpusDescription = trainingExamples is null
    ? "default corpus"
    : BitNetTrainingCorpus.BenchmarkDatasetName;
```
_trainingCorpusDescription is set to BenchmarkDatasetName whenever trainingExamples is non-null, but callers can pass arbitrary corpora via HostedAgentModelFactory.Create(..., trainingExamples). This makes the displayed training description potentially incorrect. Consider using a neutral label like "provided corpus" or threading through an explicit dataset name/description instead of assuming TinyLlama.
Suggested change:
```diff
-    : BitNetTrainingCorpus.BenchmarkDatasetName;
+    : "provided corpus";
```
```csharp
var traditionalModel = TraditionalLocalModel.CreateForTrainingCorpus(examples);

bitNetModel.Train(examples, epochs: 3);
traditionalModel.Train(examples, epochs: TraditionalLocalModel.DefaultTrainingEpochs);
```
This test trains TraditionalLocalModel for DefaultTrainingEpochs (24) just to assert perplexity is finite/positive, which can unnecessarily slow the unit test suite. Consider reducing epochs (e.g., 1-3) or using a cheaper training configuration while keeping the same assertion intent.
Suggested change:
```diff
-traditionalModel.Train(examples, epochs: TraditionalLocalModel.DefaultTrainingEpochs);
+traditionalModel.Train(examples, epochs: 3);
```
```csharp
public sealed class BitNetPaperModel
{
    private const int MaxPredictionLimit = 8;
    private const double ProbabilityFloor = 1e-9d;
```
MaxPredictionLimit is declared but never used. With TreatWarningsAsErrors enabled, this will fail the build due to an unused private const/field warning. Either remove the constant or use it where the prediction list limit is applied.
Suggested change:
```diff
-    private const int MaxPredictionLimit = 8;
     private const double ProbabilityFloor = 1e-9d;
```
```csharp
var totalLoss = 0d;
var totalTokens = 0;
foreach (var sample in validationSamples)
{
    var tokenIds = EncodeTokenIds(sample, appendEndToken: true);
    for (var index = 0; index < tokenIds.Count - 1; index++)
    {
        var context = tokenIds.Take(index + 1).ToArray();
        var logits = ForwardLogits(context);
        totalLoss -= Math.Log(GetTargetProbability(logits, tokenIds[index + 1]));
        totalTokens++;
    }
}

return totalTokens == 0 ? 0d : Math.Exp(totalLoss / totalTokens);
```
CalculatePerplexity reads model state (token IDs + transformer logits) without taking _gate, while Train/GenerateResponse are locked. This makes perplexity evaluation potentially race with training/inference and can produce inconsistent results. Consider wrapping the calculation in lock (_gate) (or otherwise ensuring thread safety) similar to TraditionalLocalModel.CalculatePerplexity.
Suggested change:
```diff
-var totalLoss = 0d;
-var totalTokens = 0;
-foreach (var sample in validationSamples)
-{
-    var tokenIds = EncodeTokenIds(sample, appendEndToken: true);
-    for (var index = 0; index < tokenIds.Count - 1; index++)
-    {
-        var context = tokenIds.Take(index + 1).ToArray();
-        var logits = ForwardLogits(context);
-        totalLoss -= Math.Log(GetTargetProbability(logits, tokenIds[index + 1]));
-        totalTokens++;
-    }
-}
-return totalTokens == 0 ? 0d : Math.Exp(totalLoss / totalTokens);
+lock (_gate)
+{
+    var totalLoss = 0d;
+    var totalTokens = 0;
+    foreach (var sample in validationSamples)
+    {
+        var tokenIds = EncodeTokenIds(sample, appendEndToken: true);
+        for (var index = 0; index < tokenIds.Count - 1; index++)
+        {
+            var context = tokenIds.Take(index + 1).ToArray();
+            var logits = ForwardLogits(context);
+            totalLoss -= Math.Log(GetTargetProbability(logits, tokenIds[index + 1]));
+            totalTokens++;
+        }
+    }
+
+    return totalTokens == 0 ? 0d : Math.Exp(totalLoss / totalTokens);
+}
```
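For reference, the quantity computed in this loop is standard token-level perplexity: the exponential of the mean negative log-likelihood over all next-token predictions (here, `totalTokens` is $N$ and the target probability comes from the forward logits):

```latex
\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N} \log p\left(t_{i} \mid t_{1},\ldots,t_{i-1}\right)\right)
```

Lower perplexity means the model assigns higher probability to the held-out tokens, which is why the tests only assert the value is finite and positive.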
```csharp
double? responseTokensPerSecond = responseMeanMilliseconds is > 0d && modelReport.BenchmarkPromptTokenCount > 0
    ? (modelReport.BenchmarkPromptTokenCount * MillisecondsPerSecond) / responseMeanMilliseconds.Value
    : null;
```
ResponseTokensPerSecond is computed using BenchmarkPromptTokenCount derived from an ad-hoc GetResponseAsync call on the (already trained) model, but the ResponseOperation mean/allocated numbers come from BenchmarkDotNet runs (HostedAgentResponseBenchmarks) that construct models without the TinyLlama training corpus and without training. This mismatch can make the reported tokens/sec and deltas inaccurate/misleading. Align the response benchmark setup with the report model setup (e.g., train once in a BenchmarkDotNet GlobalSetup using the same corpus), or compute tokens/sec from the benchmarked response output itself.
Suggested change:
```diff
-double? responseTokensPerSecond = responseMeanMilliseconds is > 0d && modelReport.BenchmarkPromptTokenCount > 0
-    ? (modelReport.BenchmarkPromptTokenCount * MillisecondsPerSecond) / responseMeanMilliseconds.Value
-    : null;
+// Tokens-per-second cannot be reliably computed here because the prompt token count
+// comes from a different execution path/model configuration than the benchmark timings.
+// Leave this metric unset until it can be derived from consistent benchmark data.
+double? responseTokensPerSecond = null;
```
```csharp
var parts = value.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
if (parts.Length < 2 || !double.TryParse(parts[0], out var magnitude))
{
    return null;
}
```
double.TryParse is used without an explicit culture when parsing BenchmarkDotNet output. On machines with a non-invariant current culture (e.g., comma decimal separators), this can fail to parse and silently drop metrics. Use double.TryParse(..., NumberStyles.Float, CultureInfo.InvariantCulture, out ...) (and similarly for allocated parsing).
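A culture-invariant version of the parse might look like the sketch below. The helper name and the thousands-separator handling are assumptions for illustration, not code from this PR:

```csharp
using System;
using System.Globalization;

static class MetricParser
{
    // Parses cells such as "12.34 ms" or "1,234.5 ns" the same way on every
    // machine, regardless of the current culture's decimal separator.
    // Returns null when the cell cannot be parsed.
    public static double? ParseMagnitude(string value)
    {
        var parts = value.Split(' ', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
        if (parts.Length < 2)
        {
            return null;
        }

        return double.TryParse(
            parts[0],
            NumberStyles.Float | NumberStyles.AllowThousands,
            CultureInfo.InvariantCulture,
            out var magnitude)
            ? magnitude
            : null;
    }
}
```

Pinning the culture matters because BenchmarkDotNet output is formatted text: on a machine where the current culture uses `,` as the decimal separator, `double.TryParse("12.34", out _)` can fail and the metric silently disappears from the report.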
The TinyLlama benchmark work moved training/model initialization onto constructor-backed corpus setup, but it also dropped part of the `BitNetPaperModel` surface that checkpoint serialization still depends on. This PR restores that model contract so the PR builds again and checkpoint round-trips continue to work.

Checkpoint/model contract
- `BitNetPaperModel.ExportMemorizedResponses()`
- `BitNetPaperModel.ImportMemorizedResponses(...)`

Why the build broke
- `BitNetPaperCheckpoint.Save(...)` still exports trained prompt memory from the model
- `BitNetPaperCheckpoint.Load(...)` still rehydrates that memory into the model
- The refactor kept `_memorizedResponses` internally, but removed the accessors the checkpoint path calls

Regression coverage
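A minimal sketch of the restored round-trip, assuming the method shapes described above (the exact signatures, return type, and factory names are assumptions and may differ in the PR):

```csharp
// Save path: the checkpoint exports the trained prompt memory from the model
// so it can be serialized alongside the weights.
var memorized = trainedModel.ExportMemorizedResponses();

// Load path: a freshly constructed model gets that memory rehydrated,
// so memorized prompt responses survive the save/load cycle.
freshModel.ImportMemorizedResponses(memorized);
```

The round-trip regression test added in BitNetPaperCheckpointTests.cs asserts exactly this invariant: responses memorized before `Save(...)` are still returned after `Load(...)`.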