Skip to content

Commit 8d38bb2

Browse files
authored
Merge pull request #22 from sharpninja/copilot/build-traditional-model-features
Bring `traditional-local` to feature parity for built-in model comparisons
2 parents 75c1e6b + 072a094 commit 8d38bb2

File tree

10 files changed

+299
-7
lines changed

10 files changed

+299
-7
lines changed

docs/training-and-visualization.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dotnet run --project src/BitNetSharp.App/BitNetSharp.App.csproj -- visualize
1111
dotnet run --project src/BitNetSharp.App/BitNetSharp.App.csproj -- paper-audit
1212
```
1313

14-
This command prints the current paper-model configuration and an aggregated ternary weight histogram across every `BitLinear` projection in the seeded transformer.
14+
This command prints the current model configuration and an aggregated signed-weight histogram for the selected built-in model, so the seeded transformer and `traditional-local` expose the same comparison-friendly inspection surface.
1515
The `paper-audit` command adds a structured checklist on top of that inspection output so the repository can report which paper-aligned architecture requirements are currently implemented and which end-to-end reproduction items are still pending.
1616

1717
## Inspect next-token predictions

docs/usage.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ dotnet run --project /home/runner/work/BitNet-b1.58-Sharp/BitNet-b1.58-Sharp/src
3838
dotnet run --project /home/runner/work/BitNet-b1.58-Sharp/BitNet-b1.58-Sharp/src/BitNetSharp.App/BitNetSharp.App.csproj -- paper-audit
3939
```
4040

41-
The `visualize` command prints the current model summary. When the selected model is the paper-aligned BitNet transformer, it also prints the ternary weight histogram across the transformer's `BitLinear` projections.
41+
The `visualize` command prints the current model summary. For the built-in BitNet and `traditional-local` models, it also prints a signed weight histogram so both comparison models expose the same inspection surface from the CLI.
4242

4343
The `paper-audit` command turns the paper checklist into an executable report. It confirms the implemented architecture requirements that the repository currently satisfies and also verifies the repository-local runtime surface for paper-model fine-tuning, named perplexity fixture measurements, zero-shot fixture evaluation, and checkpoint round-tripping.
4444

src/BitNetSharp.App/Program.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
}
6767
else
6868
{
69-
Console.WriteLine($"Model '{model.ModelId}' does not expose BitNet ternary weight inspection.");
69+
Console.WriteLine($"Model '{model.ModelId}' does not expose repository weight-sign inspection.");
7070
}
7171
break;
7272

@@ -119,7 +119,7 @@ static string FormatWeightHistogram(TernaryWeightStats stats)
119119
return string.Join(
120120
Environment.NewLine,
121121
[
122-
"Ternary weight distribution",
122+
"Weight sign distribution",
123123
FormatBar("-1", stats.NegativeCount, max, scale),
124124
FormatBar(" 0", stats.ZeroCount, max, scale),
125125
FormatBar("+1", stats.PositiveCount, max, scale)

src/BitNetSharp.App/TraditionalLocalHostedAgentModel.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using BitNetSharp.Core;
2+
using BitNetSharp.Core.Quantization;
23

34
namespace BitNetSharp.App;
45

5-
public sealed class TraditionalLocalHostedAgentModel : IHostedAgentModel, ITrainableHostedAgentModel
6+
public sealed class TraditionalLocalHostedAgentModel : IHostedAgentModel, IInspectableHostedAgentModel, ITrainableHostedAgentModel
67
{
78
private readonly string _trainingCorpusDescription;
89

@@ -51,6 +52,8 @@ public Task<HostedAgentModelResponse> GetResponseAsync(
5152
return Task.FromResult(new HostedAgentModelResponse(result.ResponseText, result.Diagnostics));
5253
}
5354

55+
/// <summary>
/// Forwards to <c>Model.GetTernaryWeightStats()</c> so the hosted wrapper exposes the
/// same weight statistics as the model it wraps.
/// </summary>
/// <returns>The statistics object returned by the wrapped model.</returns>
public TernaryWeightStats GetTernaryWeightStats()
{
    return Model.GetTernaryWeightStats();
}
56+
5457
public void Train(IEnumerable<TrainingExample> examples, int epochs = 1)
5558
{
5659
Model.Train(examples, Math.Max(TraditionalLocalModel.DefaultTrainingEpochs, epochs));

src/BitNetSharp.Core/BitNetVisualizer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ private static string RenderWeightHistogram(TrainingReport trainingReport)
3737
Environment.NewLine,
3838
new[]
3939
{
40-
"Ternary weight distribution",
40+
"Weight sign distribution",
4141
FormatBar("-1", trainingReport.NegativeWeights, max),
4242
FormatBar(" 0", trainingReport.ZeroWeights, max),
4343
FormatBar("+1", trainingReport.PositiveWeights, max)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
using System.Text.Json;
2+
3+
namespace BitNetSharp.Core;
4+
5+
/// <summary>
/// Result of a checkpoint round-trip comparison: the prompt that was used and the response
/// text produced before and after reloading the model.
/// </summary>
/// <param name="Prompt">Prompt submitted to both model instances.</param>
/// <param name="OriginalResponse">Response text generated by the model that was saved.</param>
/// <param name="ReloadedResponse">Response text generated by the model loaded back from the checkpoint.</param>
/// <param name="ResponsesMatch">Whether the two response texts were equal.</param>
public sealed record TraditionalLocalCheckpointValidationResult(
    string Prompt, string OriginalResponse, string ReloadedResponse, bool ResponsesMatch);
10+
11+
/// <summary>
/// Serialized shape of a traditional-local checkpoint as written to and read from JSON.
/// </summary>
/// <param name="Format">Format identifier; compared against the expected format name on load.</param>
/// <param name="ModelId">Identifier of the model that produced the checkpoint.</param>
/// <param name="Seed">Seed the model was constructed with.</param>
/// <param name="EmbeddingDimension">Embedding dimension the model was constructed with.</param>
/// <param name="ContextWindow">Context window the model was constructed with.</param>
/// <param name="Vocabulary">Vocabulary entries taken from the model options.</param>
/// <param name="TokenEmbeddings">Exported token embedding values.</param>
/// <param name="OutputWeights">Exported output weight values.</param>
/// <param name="OutputBias">Exported output bias values.</param>
/// <param name="MaxResponseTokens">Maximum response tokens taken from the model options.</param>
/// <param name="PrimaryLanguage">Primary language taken from the model options.</param>
internal sealed record TraditionalLocalCheckpointDocument(
    string Format, string ModelId, int Seed,
    int EmbeddingDimension, int ContextWindow,
    IReadOnlyList<string> Vocabulary,
    float[] TokenEmbeddings, float[] OutputWeights, float[] OutputBias,
    int MaxResponseTokens, string PrimaryLanguage);
23+
24+
public static class TraditionalLocalCheckpoint
25+
{
26+
private const string FormatName = "traditional-local.repository-checkpoint.v1";
27+
28+
public static void Save(TraditionalLocalModel model, string path)
29+
{
30+
ArgumentNullException.ThrowIfNull(model);
31+
ArgumentException.ThrowIfNullOrWhiteSpace(path);
32+
33+
var directory = Path.GetDirectoryName(path);
34+
if (!string.IsNullOrWhiteSpace(directory))
35+
{
36+
Directory.CreateDirectory(directory);
37+
}
38+
39+
var document = new TraditionalLocalCheckpointDocument(
40+
FormatName,
41+
model.ModelId,
42+
model.Seed,
43+
model.EmbeddingDimension,
44+
model.ContextWindow,
45+
model.Options.Vocabulary.ToArray(),
46+
model.ExportTokenEmbeddings(),
47+
model.ExportOutputWeights(),
48+
model.ExportOutputBias(),
49+
model.Options.MaxResponseTokens,
50+
model.Options.PrimaryLanguage);
51+
File.WriteAllText(path, JsonSerializer.Serialize(document, new JsonSerializerOptions { WriteIndented = true }));
52+
}
53+
54+
public static TraditionalLocalModel Load(string path, VerbosityLevel verbosity = VerbosityLevel.Normal)
55+
{
56+
ArgumentException.ThrowIfNullOrWhiteSpace(path);
57+
58+
var document = JsonSerializer.Deserialize<TraditionalLocalCheckpointDocument>(File.ReadAllText(path))
59+
?? throw new InvalidOperationException("Could not deserialize the traditional local checkpoint document.");
60+
if (!string.Equals(document.Format, FormatName, StringComparison.Ordinal))
61+
{
62+
throw new InvalidOperationException($"Unsupported checkpoint format '{document.Format}'.");
63+
}
64+
65+
var model = new TraditionalLocalModel(
66+
new BitNetOptions(
67+
document.Vocabulary.ToArray(),
68+
verbosity,
69+
document.MaxResponseTokens,
70+
document.PrimaryLanguage),
71+
document.EmbeddingDimension,
72+
document.ContextWindow,
73+
document.Seed);
74+
model.ImportState(document.TokenEmbeddings, document.OutputWeights, document.OutputBias);
75+
return model;
76+
}
77+
78+
public static TraditionalLocalCheckpointValidationResult ValidateRoundTrip(TraditionalLocalModel model, string prompt)
79+
{
80+
ArgumentNullException.ThrowIfNull(model);
81+
ArgumentException.ThrowIfNullOrWhiteSpace(prompt);
82+
83+
var checkpointPath = Path.Combine(Path.GetTempPath(), $"traditional-local-checkpoint-{Guid.NewGuid():N}.json");
84+
try
85+
{
86+
Save(model, checkpointPath);
87+
var reloaded = Load(checkpointPath, model.Options.Verbosity);
88+
var original = model.GenerateResponse(prompt, maxTokens: 4);
89+
var roundTripped = reloaded.GenerateResponse(prompt, maxTokens: 4);
90+
return new TraditionalLocalCheckpointValidationResult(
91+
prompt,
92+
original.ResponseText,
93+
roundTripped.ResponseText,
94+
string.Equals(original.ResponseText, roundTripped.ResponseText, StringComparison.Ordinal));
95+
}
96+
finally
97+
{
98+
if (File.Exists(checkpointPath))
99+
{
100+
File.Delete(checkpointPath);
101+
}
102+
}
103+
}
104+
}

src/BitNetSharp.Core/TraditionalLocalModel.cs

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System.Numerics.Tensors;
2+
using BitNetSharp.Core.Quantization;
23

34
namespace BitNetSharp.Core;
45

@@ -106,6 +107,8 @@ .. options.Vocabulary
106107

107108
public BitNetTokenizer Tokenizer => _tokenizer;
108109

110+
/// <summary>Value of the <c>_seed</c> field, exposed to code inside the assembly.</summary>
internal int Seed
{
    get { return _seed; }
}
111+
109112
public static TraditionalLocalModel CreateDefault(VerbosityLevel verbosity = VerbosityLevel.Normal) =>
110113
new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity));
111114

@@ -176,7 +179,14 @@ public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs =
176179
}
177180

178181
_isTrained = true;
179-
return new TrainingReport(history, totalSamples, epochs, 0, 0, 0);
182+
var stats = GetTernaryWeightStats();
183+
return new TrainingReport(
184+
history,
185+
totalSamples,
186+
epochs,
187+
stats.NegativeCount,
188+
stats.ZeroCount,
189+
stats.PositiveCount);
180190
}
181191
}
182192

@@ -269,6 +279,76 @@ public double CalculatePerplexity(IEnumerable<string> validationSamples)
269279
}
270280
}
271281

282+
/// <summary>
/// Tallies how many values across the token embeddings, output weights, and output bias
/// are negative, zero, or positive. The tally is taken while holding <c>_gate</c>.
/// </summary>
/// <returns>The negative, zero, and positive counts wrapped in a <see cref="TernaryWeightStats"/>.</returns>
public TernaryWeightStats GetTernaryWeightStats()
{
    lock (_gate)
    {
        int negativeTotal = 0, zeroTotal = 0, positiveTotal = 0;

        // Same order as before: embeddings, then output weights, then bias.
        foreach (var parameterGroup in new[] { _tokenEmbeddings, _outputWeights, _outputBias })
        {
            CountWeightSigns(parameterGroup, ref negativeTotal, ref zeroTotal, ref positiveTotal);
        }

        return new TernaryWeightStats(negativeTotal, zeroTotal, positiveTotal);
    }
}
297+
298+
/// <summary>
/// Returns a copy of the token embedding values, taken while holding <c>_gate</c>.
/// </summary>
/// <returns>A new array; changes to it do not affect the model.</returns>
internal float[] ExportTokenEmbeddings()
{
    lock (_gate)
    {
        var snapshot = (float[])_tokenEmbeddings.Clone();
        return snapshot;
    }
}
305+
306+
/// <summary>
/// Returns a copy of the output weight values, taken while holding <c>_gate</c>.
/// </summary>
/// <returns>A new array; changes to it do not affect the model.</returns>
internal float[] ExportOutputWeights()
{
    lock (_gate)
    {
        var snapshot = (float[])_outputWeights.Clone();
        return snapshot;
    }
}
313+
314+
/// <summary>
/// Returns a copy of the output bias values, taken while holding <c>_gate</c>.
/// </summary>
/// <returns>A new array; changes to it do not affect the model.</returns>
internal float[] ExportOutputBias()
{
    lock (_gate)
    {
        var snapshot = (float[])_outputBias.Clone();
        return snapshot;
    }
}
321+
322+
/// <summary>
/// Overwrites the model's token embeddings, output weights, and output bias with the
/// supplied values and marks the model as trained.
/// </summary>
/// <param name="tokenEmbeddings">Replacement token embedding values; length must match the model's.</param>
/// <param name="outputWeights">Replacement output weight values; length must match the model's.</param>
/// <param name="outputBias">Replacement output bias values; length must match the model's.</param>
/// <exception cref="ArgumentNullException">Any argument is null.</exception>
/// <exception cref="ArgumentException">Any array's length differs from the corresponding model array.</exception>
internal void ImportState(float[] tokenEmbeddings, float[] outputWeights, float[] outputBias)
{
    ArgumentNullException.ThrowIfNull(tokenEmbeddings);
    ArgumentNullException.ThrowIfNull(outputWeights);
    ArgumentNullException.ThrowIfNull(outputBias);

    lock (_gate)
    {
        // Every length is checked before the first copy, so a mismatch leaves the model untouched.
        RequireMatch(
            _tokenEmbeddings.Length == tokenEmbeddings.Length,
            $"Token embedding length {tokenEmbeddings.Length} does not match expected length {_tokenEmbeddings.Length}.",
            nameof(tokenEmbeddings));
        RequireMatch(
            _outputWeights.Length == outputWeights.Length,
            $"Output weight length {outputWeights.Length} does not match expected length {_outputWeights.Length}.",
            nameof(outputWeights));
        RequireMatch(
            _outputBias.Length == outputBias.Length,
            $"Output bias length {outputBias.Length} does not match expected length {_outputBias.Length}.",
            nameof(outputBias));

        Array.Copy(tokenEmbeddings, _tokenEmbeddings, tokenEmbeddings.Length);
        Array.Copy(outputWeights, _outputWeights, outputWeights.Length);
        Array.Copy(outputBias, _outputBias, outputBias.Length);
        _isTrained = true;
    }

    static void RequireMatch(bool lengthsMatch, string message, string parameterName)
    {
        if (!lengthsMatch)
        {
            throw new ArgumentException(message, parameterName);
        }
    }
}
351+
272352
private void EnsureTrained()
273353
{
274354
if (_isTrained)
@@ -424,6 +504,25 @@ private void ResetParameters()
424504
_isTrained = false;
425505
}
426506

507+
/// <summary>
/// Adds the number of positive, negative, and remaining entries of <paramref name="values"/>
/// to the supplied running counters. The counters are incremented, never reset.
/// </summary>
/// <param name="values">Values to classify by sign.</param>
/// <param name="negative">Running count of entries less than zero.</param>
/// <param name="zero">Running count of entries that are neither greater nor less than zero.</param>
/// <param name="positive">Running count of entries greater than zero.</param>
private static void CountWeightSigns(float[] values, ref int negative, ref int zero, ref int positive)
{
    for (var index = 0; index < values.Length; index++)
    {
        switch (values[index])
        {
            case > 0f:
                positive++;
                break;
            case < 0f:
                negative++;
                break;
            default:
                // Covers +0, -0, and NaN: none of them compare as above or below zero.
                zero++;
                break;
        }
    }
}
525+
427526
private static void FillWithDeterministicNoise(float[] values, Random random)
428527
{
429528
for (var index = 0; index < values.Length; index++)

tests/BitNetSharp.Tests/BitNetModelTests.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,40 @@ public void TraditionalLocalModelLearnsSimplePromptResponse()
134134
Assert.Contains(result.Diagnostics, diagnostic => diagnostic.Contains("tensor-based ordered-context", StringComparison.OrdinalIgnoreCase));
135135
}
136136

137+
[Fact]
public void TraditionalLocalTrainingReportIncludesWeightSignDistribution()
{
    // Arrange: a small four-word model built with an explicit seed.
    var options = new BitNetOptions(["alpha", "beta", "gamma", "delta"], VerbosityLevel.Quiet);
    var model = new TraditionalLocalModel(options, embeddingDimension: 8, contextWindow: 4, seed: 19);

    // Act
    var report = model.Train(
        [new TrainingExample("alpha beta", "gamma delta")],
        epochs: 12,
        learningRate: 0.3f);

    // Assert: both signs are present and the report's counts add up to the model's own total.
    var reportedTotal = report.NegativeWeights + report.ZeroWeights + report.PositiveWeights;
    Assert.True(report.NegativeWeights > 0);
    Assert.True(report.PositiveWeights > 0);
    Assert.Equal(reportedTotal, model.GetTernaryWeightStats().TotalCount);
}
157+
158+
[Fact]
public void TraditionalHostedAgentModelExposesInspectableWeightStats()
{
    // Arrange
    using var model = HostedAgentModelFactory.Create(
        HostedAgentModelFactory.TraditionalLocalModelId,
        VerbosityLevel.Quiet);

    // Act: the factory-created model must be usable through the inspection interface.
    var stats = Assert.IsAssignableFrom<IInspectableHostedAgentModel>(model).GetTernaryWeightStats();

    // Assert
    Assert.True(stats.TotalCount > 0);
    Assert.True(stats.NegativeCount > 0);
    Assert.True(stats.PositiveCount > 0);
}
170+
137171
[Fact]
138172
public void BenchmarkOptionsIncludePrimaryAndComparisonModels()
139173
{

tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,23 @@ public void BenchmarkModelConstructionUsesTheTinyLlamaTrainingVocabulary()
154154
Assert.Equal("tinyllama", ((TraditionalLocalHostedAgentModel)traditional).Model.Tokenizer.Normalize("tinyllama"));
155155
}
156156

157+
[Fact]
public void BuiltInModelsPreserveTrainedResponsesAcrossCheckpointRoundTrips()
{
    // Arrange: both built-in models are created from, and trained on, the same benchmark examples.
    const string prompt = "what does the paper model train on";
    var corpus = BitNetTrainingCorpus.CreateBenchmarkExamples();
    var paperModel = BitNetPaperModel.CreateForTrainingCorpus(corpus, VerbosityLevel.Quiet);
    var localModel = TraditionalLocalModel.CreateForTrainingCorpus(corpus, VerbosityLevel.Quiet);

    paperModel.Train(corpus, epochs: 1);
    localModel.Train(corpus, epochs: TraditionalLocalModel.DefaultTrainingEpochs);

    // Act
    var paperResult = BitNetPaperCheckpoint.ValidateRoundTrip(paperModel, prompt);
    var localResult = TraditionalLocalCheckpoint.ValidateRoundTrip(localModel, prompt);

    // Assert
    Assert.True(paperResult.ResponsesMatch);
    Assert.True(localResult.ResponsesMatch);
}
173+
157174
private static async Task WithBenchmarkOptionsAsync(HostedAgentBenchmarkOptions options, Func<Task> assertion)
158175
{
159176
var originalValue = Environment.GetEnvironmentVariable(HostedAgentBenchmarkOptions.EnvironmentVariableName);

0 commit comments

Comments
 (0)