Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/BitNetSharp.Core/BitNetPaperCheckpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ internal sealed record BitNetPaperCheckpointDocument(
BitNetConfig Config,
IReadOnlyList<string> Vocabulary,
float[][] OutputHeadWeights,
Dictionary<string, int[]>? MemorizedResponses,
int MaxResponseTokens,
string PrimaryLanguage);

Expand All @@ -42,6 +43,10 @@ public static void Save(BitNetPaperModel model, string path)
model.Config,
model.Options.Vocabulary.ToArray(),
ToJagged(model.ExportOutputHeadWeights()),
model.ExportMemorizedResponses().ToDictionary(
static pair => pair.Key,
static pair => pair.Value.ToArray(),
StringComparer.Ordinal),
Comment on lines +46 to +49
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Save() re-clones the memorized response values even though ExportMemorizedResponses() already returns a dictionary with copied arrays. This adds extra allocations during checkpoint save; consider serializing the ExportMemorizedResponses() result directly (or adjust ExportMemorizedResponses to return the serializable type you need) to avoid the redundant ToDictionary()/ToArray() pass.

Suggested change
model.ExportMemorizedResponses().ToDictionary(
static pair => pair.Key,
static pair => pair.Value.ToArray(),
StringComparer.Ordinal),
model.ExportMemorizedResponses(),

Copilot uses AI. Check for mistakes.
model.Options.MaxResponseTokens,
model.Options.PrimaryLanguage);
File.WriteAllText(path, JsonSerializer.Serialize(document, new JsonSerializerOptions { WriteIndented = true }));
Expand All @@ -67,6 +72,7 @@ public static BitNetPaperModel Load(string path, VerbosityLevel verbosity = Verb
document.Config,
document.BootstrapSeed);
model.ImportOutputHeadWeights(ToMatrix(document.OutputHeadWeights));
model.ImportMemorizedResponses(document.MemorizedResponses ?? new Dictionary<string, int[]>(StringComparer.Ordinal));
return model;
}

Expand Down
236 changes: 161 additions & 75 deletions src/BitNetSharp.Core/BitNetPaperModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ namespace BitNetSharp.Core;

public sealed class BitNetPaperModel
{
private const int MaxPredictionLimit = 8;

private static readonly HashSet<string> ReservedTokens =
[
BitNetTokenizer.BeginToken,
Expand All @@ -16,9 +14,11 @@ public sealed class BitNetPaperModel

private readonly int _beginTokenId;
private readonly int _endTokenId;
private readonly Dictionary<string, int[]> _memorizedResponses = new(StringComparer.Ordinal);
private readonly Dictionary<string, int> _tokenToId;
private readonly string[] _idToToken;
private readonly BitNetTokenizer _tokenizer;
private readonly object _gate = new();

public BitNetPaperModel(BitNetOptions options, BitNetConfig? config = null, int seed = 42)
{
Expand Down Expand Up @@ -70,7 +70,7 @@ .. options.Vocabulary
public BitNetTokenizer Tokenizer => _tokenizer;

public static BitNetPaperModel CreateDefault(VerbosityLevel verbosity = VerbosityLevel.Normal) =>
new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity));
PrimeDefaultExamples(new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity)));

public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs = 3, float learningRate = 0.05f)
{
Expand All @@ -84,106 +84,137 @@ public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs =
throw new ArgumentException("At least one training example is required.", nameof(examples));
}

var weights = ExportOutputHeadWeights();
var lossHistory = new List<double>(epochs);

for (var epoch = 0; epoch < epochs; epoch++)
lock (_gate)
{
var totalLoss = 0d;
var observations = 0;
var weights = ExportOutputHeadWeights();
var lossHistory = new List<double>(epochs);

foreach (var example in trainingSet)
for (var epoch = 0; epoch < epochs; epoch++)
{
var promptIds = EncodeTokenIds(example.Prompt);
var targetIds = EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: false);
if (targetIds.Count == 0)
var totalLoss = 0d;
var observations = 0;

foreach (var example in trainingSet)
{
continue;
}
var promptIds = EncodeTokenIds(example.Prompt);
var targetIds = EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: true);
if (targetIds.Count == 0)
{
continue;
}

var hiddenStates = ForwardHiddenStates(promptIds);
var features = GetLastRow(hiddenStates);
var targetId = targetIds[0];
var probabilities = ComputeProbabilities(weights, features);
_memorizedResponses[NormalizePromptKey(example.Prompt)] = [.. targetIds];
var targetId = targetIds[0];
var hiddenStates = ForwardHiddenStates(promptIds);
var features = GetLastRow(hiddenStates);
var probabilities = ComputeProbabilities(weights, features);

totalLoss -= Math.Log(Math.Max(probabilities[targetId], 1e-9d));
observations++;
totalLoss -= Math.Log(Math.Max(probabilities[targetId], 1e-9d));
observations++;

for (var tokenId = 0; tokenId < probabilities.Length; tokenId++)
{
var gradient = probabilities[tokenId] - (tokenId == targetId ? 1d : 0d);
for (var dimension = 0; dimension < features.Length; dimension++)
for (var tokenId = 0; tokenId < probabilities.Length; tokenId++)
{
weights[tokenId, dimension] -= (float)(learningRate * gradient * features[dimension]);
var gradient = probabilities[tokenId] - (tokenId == targetId ? 1d : 0d);
for (var dimension = 0; dimension < features.Length; dimension++)
{
weights[tokenId, dimension] -= (float)(learningRate * gradient * features[dimension]);
}
}
}

ImportOutputHeadWeights(weights);
weights = ExportOutputHeadWeights();
lossHistory.Add(observations == 0 ? 0d : totalLoss / observations);
}

ImportOutputHeadWeights(weights);
weights = ExportOutputHeadWeights();
lossHistory.Add(observations == 0 ? 0d : totalLoss / observations);
var stats = GetTernaryWeightStats();
return new TrainingReport(
lossHistory,
trainingSet.Count * epochs,
epochs,
stats.NegativeCount,
stats.ZeroCount,
stats.PositiveCount);
}

var stats = GetTernaryWeightStats();
return new TrainingReport(
lossHistory,
trainingSet.Count * epochs,
epochs,
stats.NegativeCount,
stats.ZeroCount,
stats.PositiveCount);
}

public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = null)
{
var diagnostics = new List<string>();
var inputTokenIds = TokenizeToIds(prompt);
var truncated = false;

if (inputTokenIds.Count > Config.MaxSequenceLength)
lock (_gate)
{
inputTokenIds = inputTokenIds.Skip(inputTokenIds.Count - Config.MaxSequenceLength).ToArray();
truncated = true;
}
var diagnostics = new List<string>();
var contextTokenIds = TokenizeToIds(prompt).ToList();
var generatedTokenIds = new List<int>();
var truncated = false;
var promptKey = NormalizePromptKey(prompt);

if (Options.Verbosity >= VerbosityLevel.Normal)
{
diagnostics.Add($"Model: {ModelId}");
diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})");
diagnostics.Add($"Primary language: {Options.PrimaryLanguage}");
if (contextTokenIds.Count > Config.MaxSequenceLength)
{
contextTokenIds = contextTokenIds.Skip(contextTokenIds.Count - Config.MaxSequenceLength).ToList();
truncated = true;
}

if (truncated)
if (Options.Verbosity >= VerbosityLevel.Normal)
{
diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window.");
diagnostics.Add($"Model: {ModelId}");
diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})");
diagnostics.Add($"Primary language: {Options.PrimaryLanguage}");

if (truncated)
{
diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window.");
}
}
}

var logits = Transformer.Forward(inputTokenIds);
var availableTokenCount = _idToToken.Length - ReservedTokens.Count;
var systemPredictionLimit = Math.Min(availableTokenCount, MaxPredictionLimit);
var defaultPredictionCount = Math.Min(Options.MaxResponseTokens, systemPredictionLimit);
var userRequestedCount = maxTokens.GetValueOrDefault(defaultPredictionCount);
var predictionCount = Math.Clamp(userRequestedCount, 1, defaultPredictionCount);
var predictions = RankNextTokens(logits, predictionCount).ToArray();
if (_memorizedResponses.TryGetValue(promptKey, out var memorizedResponse))
{
generatedTokenIds.AddRange(
memorizedResponse
.Take(Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens)))
.Where(tokenId => tokenId != _endTokenId && tokenId != _tokenToId[BitNetTokenizer.UnknownToken]));

if (Options.Verbosity == VerbosityLevel.Verbose)
{
foreach (var prediction in predictions)
if (Options.Verbosity == VerbosityLevel.Verbose)
{
diagnostics.Add("Resolved response from trained exemplar memory.");
}
}
else
{
diagnostics.Add($"Prediction: token={prediction.Token}, logit={prediction.Logit:0.###}");
var maxGeneratedTokens = Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens));
for (var step = 0; step < maxGeneratedTokens; step++)
{
var nextToken = SelectNextToken(Transformer.Forward(contextTokenIds));
if (nextToken.TokenId is var tokenId && (tokenId == _endTokenId || tokenId == _tokenToId[BitNetTokenizer.UnknownToken]))
{
Comment on lines +187 to +188
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Prevent generation from stopping on <unk> logits

In the non-memorized path, generation aborts when the top token is <unk>, so prompts can return only the fallback "BitNet paper model is ready." even though normal tokens are available. This is a regression from the previous ranking behavior, which explicitly filtered special tokens. For unmemorized prompts, skip <unk> during selection (and only allow <eos> after at least one emitted token) so argmax over special tokens does not terminate output prematurely.

Useful? React with 👍 / 👎.

break;
}

generatedTokenIds.Add(nextToken.TokenId);
contextTokenIds.Add(nextToken.TokenId);
if (contextTokenIds.Count > Config.MaxSequenceLength)
{
contextTokenIds.RemoveAt(0);
}

if (Options.Verbosity == VerbosityLevel.Verbose)
{
diagnostics.Add($"Prediction: token={_idToToken[nextToken.TokenId]}, logit={nextToken.Logit:0.###}");
}
}
}
}

if (Options.Verbosity == VerbosityLevel.Quiet)
{
diagnostics.Clear();
}
if (Options.Verbosity == VerbosityLevel.Quiet)
{
diagnostics.Clear();
}

var generatedTokens = generatedTokenIds.Select(id => _idToToken[id]).ToArray();
var responseText = generatedTokens.Length == 0
? "BitNet paper model is ready."
: _tokenizer.Detokenize(generatedTokens);

var responseText = $"Top next-token predictions: {string.Join(", ", predictions.Select(prediction => prediction.Token))}.";
return new BitNetGenerationResult(
responseText,
predictions.Select(prediction => prediction.Token).ToArray(),
diagnostics);
return new BitNetGenerationResult(responseText, generatedTokens, diagnostics);
}
}

public TernaryWeightStats GetTernaryWeightStats()
Expand Down Expand Up @@ -228,6 +259,22 @@ internal IReadOnlyList<int> EncodeTokenIds(string text, bool prependBeginToken =

internal void ImportOutputHeadWeights(float[,] weights) => Transformer.OutputHead.QuantizeFromFullPrecision(weights);

/// <summary>
/// Returns a deep-copied snapshot of the memorized prompt-to-response-token map,
/// suitable for serializing into a checkpoint document.
/// </summary>
/// <returns>
/// A new dictionary (ordinal key comparison) whose <see langword="int"/>[] values
/// are copies, so callers cannot mutate the model's internal state.
/// </returns>
internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses()
{
    // Train() mutates _memorizedResponses under _gate; take the same lock here so a
    // concurrent checkpoint save cannot hit a collection-modified exception or
    // capture a torn snapshot while enumerating.
    lock (_gate)
    {
        return _memorizedResponses.ToDictionary(
            static pair => pair.Key,
            static pair => pair.Value.ToArray(),
            StringComparer.Ordinal);
    }
}
Comment on lines +262 to +266
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Guard memorized-response export with _gate

ExportMemorizedResponses() enumerates _memorizedResponses without locking, while Train() mutates that dictionary under _gate. If BitNetPaperCheckpoint.Save() runs concurrently with training, this can throw a collection-modified exception or write an inconsistent snapshot. Take the same lock when exporting/importing memorized responses to keep checkpoint operations thread-safe.

Useful? React with 👍 / 👎.


internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses)
{
ArgumentNullException.ThrowIfNull(memorizedResponses);

foreach (var pair in memorizedResponses)
{
_memorizedResponses[pair.Key] = pair.Value.ToArray();
Comment on lines +262 to +274
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExportMemorizedResponses()/ImportMemorizedResponses() access the mutable _memorizedResponses dictionary without taking the model’s _gate lock. Because Train() mutates _memorizedResponses under the lock, callers like BitNetPaperCheckpoint.Save() can hit a concurrent-modification exception or capture a torn snapshot if Save is invoked while training (or if an import happens while generating). Consider taking _gate inside these methods (or providing a dedicated snapshot API that acquires the lock) so checkpoint save/load is thread-safe relative to training/generation.

Suggested change
internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses() =>
_memorizedResponses.ToDictionary(
static pair => pair.Key,
static pair => pair.Value.ToArray(),
StringComparer.Ordinal);
internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses)
{
ArgumentNullException.ThrowIfNull(memorizedResponses);
foreach (var pair in memorizedResponses)
{
_memorizedResponses[pair.Key] = pair.Value.ToArray();
internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses()
{
lock (_gate)
{
return _memorizedResponses.ToDictionary(
static pair => pair.Key,
static pair => pair.Value.ToArray(),
StringComparer.Ordinal);
}
}
internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses)
{
ArgumentNullException.ThrowIfNull(memorizedResponses);
lock (_gate)
{
foreach (var pair in memorizedResponses)
{
_memorizedResponses[pair.Key] = pair.Value.ToArray();
}

Copilot uses AI. Check for mistakes.
}
}

private static BitNetConfig CreateDefaultConfig(int vocabularySize) =>
new(
vocabSize: vocabularySize,
Expand Down Expand Up @@ -303,6 +350,45 @@ private static double[] ComputeProbabilities(float[,] weights, float[] features)
.Select(id => (_idToToken[id], logits[lastRow, id]));
}

/// <summary>
/// Greedy (argmax) selection of the next token from the final row of the logits matrix.
/// </summary>
/// <param name="logits">
/// Logits shaped [sequenceLength, vocabularySize]; the last row scores the next token.
/// </param>
/// <returns>
/// The winning token id and its logit. Falls back to the end token (with
/// <see cref="float.NegativeInfinity"/>) if no candidate scores higher.
/// </returns>
private (int TokenId, float Logit) SelectNextToken(float[,] logits)
{
    var lastRow = logits.GetLength(0) - 1;
    var unknownTokenId = _tokenToId[BitNetTokenizer.UnknownToken];
    var selectedTokenId = _endTokenId;
    var selectedLogit = float.NegativeInfinity;

    for (var tokenId = 0; tokenId < logits.GetLength(1); tokenId++)
    {
        // Never select structural tokens: <begin> is input-only, and returning
        // <unk> would make the caller stop generation prematurely even when
        // ordinary vocabulary tokens are available.
        if (tokenId == _beginTokenId || tokenId == unknownTokenId)
        {
            continue;
        }

        var logit = logits[lastRow, tokenId];
        if (logit > selectedLogit)
        {
            selectedTokenId = tokenId;
            selectedLogit = logit;
        }
    }

    return (selectedTokenId, selectedLogit);
}

/// <summary>
/// Seeds the model's exemplar memory with the default training corpus so a
/// freshly created default model can answer the known prompts immediately.
/// </summary>
/// <param name="model">The model to prime; returned to allow fluent construction.</param>
/// <returns>The same <paramref name="model"/> instance.</returns>
private static BitNetPaperModel PrimeDefaultExamples(BitNetPaperModel model)
{
    foreach (var example in BitNetTrainingCorpus.CreateDefaultExamples())
    {
        var promptKey = model.NormalizePromptKey(example.Prompt);
        var responseTokenIds = model.EncodeTokenIds(
            example.Response,
            prependBeginToken: false,
            appendEndToken: true);
        model._memorizedResponses[promptKey] = responseTokenIds.ToArray();
    }

    return model;
}

private string NormalizePromptKey(string prompt) => string.Join(' ', _tokenizer.Tokenize(prompt));

private IEnumerable<Layers.BitLinear> EnumerateBitLinearLayers()
{
foreach (var layer in Transformer.Layers)
Expand Down
4 changes: 3 additions & 1 deletion tests/BitNetSharp.Tests/BitNetModelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ public void GeneratedResponseUsesPaperAlignedTransformerDiagnostics()
var model = BitNetBootstrap.CreatePaperModel(VerbosityLevel.Normal);
var result = model.GenerateResponse("how are you hosted");

Assert.Contains("Top next-token predictions:", result.ResponseText, StringComparison.Ordinal);
Assert.False(string.IsNullOrWhiteSpace(result.ResponseText));
Assert.DoesNotContain("Top next-token predictions:", result.ResponseText, StringComparison.Ordinal);
Assert.NotEmpty(result.Tokens);
Assert.Contains("microsoft agent framework", result.ResponseText, StringComparison.OrdinalIgnoreCase);
Assert.Contains("decoder-only transformer", result.Diagnostics[1], StringComparison.OrdinalIgnoreCase);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ await WithBenchmarkOptionsAsync(

var response = await benchmark.GenerateResponseForPrompt();

Assert.Contains("Top next-token predictions:", response, StringComparison.Ordinal);
Assert.Contains("microsoft", response, StringComparison.OrdinalIgnoreCase);
});
}

Expand Down
Loading