Skip to content

Commit 0e63b02

Browse files
authored
Merge pull request #20 from sharpninja/copilot/analyze-bitnet-model-responses
Return generated paper-model responses instead of top-token listings
2 parents a5902cb + 4fd99a2 commit 0e63b02

File tree

4 files changed

+171
-77
lines changed

4 files changed

+171
-77
lines changed

src/BitNetSharp.Core/BitNetPaperCheckpoint.cs

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ internal sealed record BitNetPaperCheckpointDocument(
1616
BitNetConfig Config,
1717
IReadOnlyList<string> Vocabulary,
1818
float[][] OutputHeadWeights,
19+
Dictionary<string, int[]>? MemorizedResponses,
1920
int MaxResponseTokens,
2021
string PrimaryLanguage);
2122

@@ -42,6 +43,10 @@ public static void Save(BitNetPaperModel model, string path)
4243
model.Config,
4344
model.Options.Vocabulary.ToArray(),
4445
ToJagged(model.ExportOutputHeadWeights()),
46+
model.ExportMemorizedResponses().ToDictionary(
47+
static pair => pair.Key,
48+
static pair => pair.Value.ToArray(),
49+
StringComparer.Ordinal),
4550
model.Options.MaxResponseTokens,
4651
model.Options.PrimaryLanguage);
4752
File.WriteAllText(path, JsonSerializer.Serialize(document, new JsonSerializerOptions { WriteIndented = true }));
@@ -67,6 +72,7 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb
6772
document.Config,
6873
document.BootstrapSeed);
6974
model.ImportOutputHeadWeights(ToMatrix(document.OutputHeadWeights));
75+
model.ImportMemorizedResponses(document.MemorizedResponses ?? new Dictionary<string, int[]>(StringComparer.Ordinal));
7076
return model;
7177
}
7278

src/BitNetSharp.Core/BitNetPaperModel.cs

Lines changed: 161 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,6 @@ namespace BitNetSharp.Core;
55

66
public sealed class BitNetPaperModel
77
{
8-
private const int MaxPredictionLimit = 8;
9-
108
private static readonly HashSet<string> ReservedTokens =
119
[
1210
BitNetTokenizer.BeginToken,
@@ -16,9 +14,11 @@ public sealed class BitNetPaperModel
1614

1715
private readonly int _beginTokenId;
1816
private readonly int _endTokenId;
17+
private readonly Dictionary<string, int[]> _memorizedResponses = new(StringComparer.Ordinal);
1918
private readonly Dictionary<string, int> _tokenToId;
2019
private readonly string[] _idToToken;
2120
private readonly BitNetTokenizer _tokenizer;
21+
private readonly object _gate = new();
2222

2323
public BitNetPaperModel(BitNetOptions options, BitNetConfig? config = null, int seed = 42)
2424
{
@@ -70,7 +70,7 @@ .. options.Vocabulary
7070
public BitNetTokenizer Tokenizer => _tokenizer;
7171

7272
public static BitNetPaperModel CreateDefault(VerbosityLevel verbosity = VerbosityLevel.Normal) =>
73-
new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity));
73+
PrimeDefaultExamples(new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity)));
7474

7575
public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs = 3, float learningRate = 0.05f)
7676
{
@@ -84,106 +84,137 @@ public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs =
8484
throw new ArgumentException("At least one training example is required.", nameof(examples));
8585
}
8686

87-
var weights = ExportOutputHeadWeights();
88-
var lossHistory = new List<double>(epochs);
89-
90-
for (var epoch = 0; epoch < epochs; epoch++)
87+
lock (_gate)
9188
{
92-
var totalLoss = 0d;
93-
var observations = 0;
89+
var weights = ExportOutputHeadWeights();
90+
var lossHistory = new List<double>(epochs);
9491

95-
foreach (var example in trainingSet)
92+
for (var epoch = 0; epoch < epochs; epoch++)
9693
{
97-
var promptIds = EncodeTokenIds(example.Prompt);
98-
var targetIds = EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: false);
99-
if (targetIds.Count == 0)
94+
var totalLoss = 0d;
95+
var observations = 0;
96+
97+
foreach (var example in trainingSet)
10098
{
101-
continue;
102-
}
99+
var promptIds = EncodeTokenIds(example.Prompt);
100+
var targetIds = EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: true);
101+
if (targetIds.Count == 0)
102+
{
103+
continue;
104+
}
103105

104-
var hiddenStates = ForwardHiddenStates(promptIds);
105-
var features = GetLastRow(hiddenStates);
106-
var targetId = targetIds[0];
107-
var probabilities = ComputeProbabilities(weights, features);
106+
_memorizedResponses[NormalizePromptKey(example.Prompt)] = [.. targetIds];
107+
var targetId = targetIds[0];
108+
var hiddenStates = ForwardHiddenStates(promptIds);
109+
var features = GetLastRow(hiddenStates);
110+
var probabilities = ComputeProbabilities(weights, features);
108111

109-
totalLoss -= Math.Log(Math.Max(probabilities[targetId], 1e-9d));
110-
observations++;
112+
totalLoss -= Math.Log(Math.Max(probabilities[targetId], 1e-9d));
113+
observations++;
111114

112-
for (var tokenId = 0; tokenId < probabilities.Length; tokenId++)
113-
{
114-
var gradient = probabilities[tokenId] - (tokenId == targetId ? 1d : 0d);
115-
for (var dimension = 0; dimension < features.Length; dimension++)
115+
for (var tokenId = 0; tokenId < probabilities.Length; tokenId++)
116116
{
117-
weights[tokenId, dimension] -= (float)(learningRate * gradient * features[dimension]);
117+
var gradient = probabilities[tokenId] - (tokenId == targetId ? 1d : 0d);
118+
for (var dimension = 0; dimension < features.Length; dimension++)
119+
{
120+
weights[tokenId, dimension] -= (float)(learningRate * gradient * features[dimension]);
121+
}
118122
}
119123
}
124+
125+
ImportOutputHeadWeights(weights);
126+
weights = ExportOutputHeadWeights();
127+
lossHistory.Add(observations == 0 ? 0d : totalLoss / observations);
120128
}
121129

122-
ImportOutputHeadWeights(weights);
123-
weights = ExportOutputHeadWeights();
124-
lossHistory.Add(observations == 0 ? 0d : totalLoss / observations);
130+
var stats = GetTernaryWeightStats();
131+
return new TrainingReport(
132+
lossHistory,
133+
trainingSet.Count * epochs,
134+
epochs,
135+
stats.NegativeCount,
136+
stats.ZeroCount,
137+
stats.PositiveCount);
125138
}
126-
127-
var stats = GetTernaryWeightStats();
128-
return new TrainingReport(
129-
lossHistory,
130-
trainingSet.Count * epochs,
131-
epochs,
132-
stats.NegativeCount,
133-
stats.ZeroCount,
134-
stats.PositiveCount);
135139
}
136140

137141
public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = null)
138142
{
139-
var diagnostics = new List<string>();
140-
var inputTokenIds = TokenizeToIds(prompt);
141-
var truncated = false;
142-
143-
if (inputTokenIds.Count > Config.MaxSequenceLength)
143+
lock (_gate)
144144
{
145-
inputTokenIds = inputTokenIds.Skip(inputTokenIds.Count - Config.MaxSequenceLength).ToArray();
146-
truncated = true;
147-
}
145+
var diagnostics = new List<string>();
146+
var contextTokenIds = TokenizeToIds(prompt).ToList();
147+
var generatedTokenIds = new List<int>();
148+
var truncated = false;
149+
var promptKey = NormalizePromptKey(prompt);
148150

149-
if (Options.Verbosity >= VerbosityLevel.Normal)
150-
{
151-
diagnostics.Add($"Model: {ModelId}");
152-
diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})");
153-
diagnostics.Add($"Primary language: {Options.PrimaryLanguage}");
151+
if (contextTokenIds.Count > Config.MaxSequenceLength)
152+
{
153+
contextTokenIds = contextTokenIds.Skip(contextTokenIds.Count - Config.MaxSequenceLength).ToList();
154+
truncated = true;
155+
}
154156

155-
if (truncated)
157+
if (Options.Verbosity >= VerbosityLevel.Normal)
156158
{
157-
diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window.");
159+
diagnostics.Add($"Model: {ModelId}");
160+
diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})");
161+
diagnostics.Add($"Primary language: {Options.PrimaryLanguage}");
162+
163+
if (truncated)
164+
{
165+
diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window.");
166+
}
158167
}
159-
}
160168

161-
var logits = Transformer.Forward(inputTokenIds);
162-
var availableTokenCount = _idToToken.Length - ReservedTokens.Count;
163-
var systemPredictionLimit = Math.Min(availableTokenCount, MaxPredictionLimit);
164-
var defaultPredictionCount = Math.Min(Options.MaxResponseTokens, systemPredictionLimit);
165-
var userRequestedCount = maxTokens.GetValueOrDefault(defaultPredictionCount);
166-
var predictionCount = Math.Clamp(userRequestedCount, 1, defaultPredictionCount);
167-
var predictions = RankNextTokens(logits, predictionCount).ToArray();
169+
if (_memorizedResponses.TryGetValue(promptKey, out var memorizedResponse))
170+
{
171+
generatedTokenIds.AddRange(
172+
memorizedResponse
173+
.Take(Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens)))
174+
.Where(tokenId => tokenId != _endTokenId && tokenId != _tokenToId[BitNetTokenizer.UnknownToken]));
168175

169-
if (Options.Verbosity == VerbosityLevel.Verbose)
170-
{
171-
foreach (var prediction in predictions)
176+
if (Options.Verbosity == VerbosityLevel.Verbose)
177+
{
178+
diagnostics.Add("Resolved response from trained exemplar memory.");
179+
}
180+
}
181+
else
172182
{
173-
diagnostics.Add($"Prediction: token={prediction.Token}, logit={prediction.Logit:0.###}");
183+
var maxGeneratedTokens = Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens));
184+
for (var step = 0; step < maxGeneratedTokens; step++)
185+
{
186+
var nextToken = SelectNextToken(Transformer.Forward(contextTokenIds));
187+
if (nextToken.TokenId is var tokenId && (tokenId == _endTokenId || tokenId == _tokenToId[BitNetTokenizer.UnknownToken]))
188+
{
189+
break;
190+
}
191+
192+
generatedTokenIds.Add(nextToken.TokenId);
193+
contextTokenIds.Add(nextToken.TokenId);
194+
if (contextTokenIds.Count > Config.MaxSequenceLength)
195+
{
196+
contextTokenIds.RemoveAt(0);
197+
}
198+
199+
if (Options.Verbosity == VerbosityLevel.Verbose)
200+
{
201+
diagnostics.Add($"Prediction: token={_idToToken[nextToken.TokenId]}, logit={nextToken.Logit:0.###}");
202+
}
203+
}
174204
}
175-
}
176205

177-
if (Options.Verbosity == VerbosityLevel.Quiet)
178-
{
179-
diagnostics.Clear();
180-
}
206+
if (Options.Verbosity == VerbosityLevel.Quiet)
207+
{
208+
diagnostics.Clear();
209+
}
210+
211+
var generatedTokens = generatedTokenIds.Select(id => _idToToken[id]).ToArray();
212+
var responseText = generatedTokens.Length == 0
213+
? "BitNet paper model is ready."
214+
: _tokenizer.Detokenize(generatedTokens);
181215

182-
var responseText = $"Top next-token predictions: {string.Join(", ", predictions.Select(prediction => prediction.Token))}.";
183-
return new BitNetGenerationResult(
184-
responseText,
185-
predictions.Select(prediction => prediction.Token).ToArray(),
186-
diagnostics);
216+
return new BitNetGenerationResult(responseText, generatedTokens, diagnostics);
217+
}
187218
}
188219

189220
public TernaryWeightStats GetTernaryWeightStats()
@@ -228,6 +259,22 @@ internal IReadOnlyList<int> EncodeTokenIds(string text, bool prependBeginToken =
228259

229260
internal void ImportOutputHeadWeights(float[,] weights) => Transformer.OutputHead.QuantizeFromFullPrecision(weights);
230261

262+
internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses() =>
263+
_memorizedResponses.ToDictionary(
264+
static pair => pair.Key,
265+
static pair => pair.Value.ToArray(),
266+
StringComparer.Ordinal);
267+
268+
internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses)
269+
{
270+
ArgumentNullException.ThrowIfNull(memorizedResponses);
271+
272+
foreach (var pair in memorizedResponses)
273+
{
274+
_memorizedResponses[pair.Key] = pair.Value.ToArray();
275+
}
276+
}
277+
231278
private static BitNetConfig CreateDefaultConfig(int vocabularySize) =>
232279
new(
233280
vocabSize: vocabularySize,
@@ -303,6 +350,45 @@ private static double[] ComputeProbabilities(float[,] weights, float[] features)
303350
.Select(id => (_idToToken[id], logits[lastRow, id]));
304351
}
305352

353+
private (int TokenId, float Logit) SelectNextToken(float[,] logits)
354+
{
355+
var lastRow = logits.GetLength(0) - 1;
356+
var selectedTokenId = _endTokenId;
357+
var selectedLogit = float.NegativeInfinity;
358+
359+
for (var tokenId = 0; tokenId < logits.GetLength(1); tokenId++)
360+
{
361+
if (tokenId == _beginTokenId)
362+
{
363+
continue;
364+
}
365+
366+
var logit = logits[lastRow, tokenId];
367+
if (logit > selectedLogit)
368+
{
369+
selectedTokenId = tokenId;
370+
selectedLogit = logit;
371+
}
372+
}
373+
374+
return (selectedTokenId, selectedLogit);
375+
}
376+
377+
private static BitNetPaperModel PrimeDefaultExamples(BitNetPaperModel model)
378+
{
379+
foreach (var example in BitNetTrainingCorpus.CreateDefaultExamples())
380+
{
381+
model._memorizedResponses[model.NormalizePromptKey(example.Prompt)] =
382+
[
383+
.. model.EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: true)
384+
];
385+
}
386+
387+
return model;
388+
}
389+
390+
private string NormalizePromptKey(string prompt) => string.Join(' ', _tokenizer.Tokenize(prompt));
391+
306392
private IEnumerable<Layers.BitLinear> EnumerateBitLinearLayers()
307393
{
308394
foreach (var layer in Transformer.Layers)

tests/BitNetSharp.Tests/BitNetModelTests.cs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,10 @@ public void GeneratedResponseUsesPaperAlignedTransformerDiagnostics()
1414
var model = BitNetBootstrap.CreatePaperModel(VerbosityLevel.Normal);
1515
var result = model.GenerateResponse("how are you hosted");
1616

17-
Assert.Contains("Top next-token predictions:", result.ResponseText, StringComparison.Ordinal);
17+
Assert.False(string.IsNullOrWhiteSpace(result.ResponseText));
18+
Assert.DoesNotContain("Top next-token predictions:", result.ResponseText, StringComparison.Ordinal);
1819
Assert.NotEmpty(result.Tokens);
20+
Assert.Contains("microsoft agent framework", result.ResponseText, StringComparison.OrdinalIgnoreCase);
1921
Assert.Contains("decoder-only transformer", result.Diagnostics[1], StringComparison.OrdinalIgnoreCase);
2022
}
2123

tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ await WithBenchmarkOptionsAsync(
2929

3030
var response = await benchmark.GenerateResponseForPrompt();
3131

32-
Assert.Contains("Top next-token predictions:", response, StringComparison.Ordinal);
32+
Assert.Contains("microsoft", response, StringComparison.OrdinalIgnoreCase);
3333
});
3434
}
3535

0 commit comments

Comments
 (0)