Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/training-and-visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dotnet run --project src/BitNetSharp.App/BitNetSharp.App.csproj -- visualize
dotnet run --project src/BitNetSharp.App/BitNetSharp.App.csproj -- paper-audit
```

This command prints the current paper-model configuration and an aggregated ternary weight histogram across every `BitLinear` projection in the seeded transformer.
This command prints the current model configuration and an aggregated signed-weight histogram for the selected built-in model. As a result, the seeded transformer and `traditional-local` expose the same comparison-friendly inspection surface.
The `paper-audit` command adds a structured checklist on top of that inspection output so the repository can report which paper-aligned architecture requirements are currently implemented and which end-to-end reproduction items are still pending.

## Inspect next-token predictions
Expand Down
2 changes: 1 addition & 1 deletion docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dotnet run --project /home/runner/work/BitNet-b1.58-Sharp/BitNet-b1.58-Sharp/src
dotnet run --project /home/runner/work/BitNet-b1.58-Sharp/BitNet-b1.58-Sharp/src/BitNetSharp.App/BitNetSharp.App.csproj -- paper-audit
```

The `visualize` command prints the current model summary. When the selected model is the paper-aligned BitNet transformer, it also prints the ternary weight histogram across the transformer's `BitLinear` projections.
The `visualize` command prints the current model summary. For the built-in BitNet and `traditional-local` models, it also prints a signed weight histogram so both comparison models expose the same inspection surface from the CLI.

The `paper-audit` command turns the paper checklist into an executable report. It confirms the implemented architecture requirements that the repository currently satisfies and also verifies the repository-local runtime surface for paper-model fine-tuning, named perplexity fixture measurements, zero-shot fixture evaluation, and checkpoint round-tripping.

Expand Down
4 changes: 2 additions & 2 deletions src/BitNetSharp.App/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
}
else
{
Console.WriteLine($"Model '{model.ModelId}' does not expose BitNet ternary weight inspection.");
Console.WriteLine($"Model '{model.ModelId}' does not expose repository weight-sign inspection.");
}
break;

Expand Down Expand Up @@ -119,7 +119,7 @@ static string FormatWeightHistogram(TernaryWeightStats stats)
return string.Join(
Environment.NewLine,
[
"Ternary weight distribution",
"Weight sign distribution",
FormatBar("-1", stats.NegativeCount, max, scale),
FormatBar(" 0", stats.ZeroCount, max, scale),
FormatBar("+1", stats.PositiveCount, max, scale)
Expand Down
5 changes: 4 additions & 1 deletion src/BitNetSharp.App/TraditionalLocalHostedAgentModel.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using BitNetSharp.Core;
using BitNetSharp.Core.Quantization;

namespace BitNetSharp.App;

public sealed class TraditionalLocalHostedAgentModel : IHostedAgentModel, ITrainableHostedAgentModel
public sealed class TraditionalLocalHostedAgentModel : IHostedAgentModel, IInspectableHostedAgentModel, ITrainableHostedAgentModel
{
private readonly string _trainingCorpusDescription;

Expand Down Expand Up @@ -51,6 +52,8 @@ public Task<HostedAgentModelResponse> GetResponseAsync(
return Task.FromResult(new HostedAgentModelResponse(result.ResponseText, result.Diagnostics));
}

/// <summary>Delegates to the wrapped model's weight-sign statistics so this hosted agent is inspectable.</summary>
public TernaryWeightStats GetTernaryWeightStats() => Model.GetTernaryWeightStats();

public void Train(IEnumerable<TrainingExample> examples, int epochs = 1)
{
Model.Train(examples, Math.Max(TraditionalLocalModel.DefaultTrainingEpochs, epochs));
Expand Down
2 changes: 1 addition & 1 deletion src/BitNetSharp.Core/BitNetVisualizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private static string RenderWeightHistogram(TrainingReport trainingReport)
Environment.NewLine,
new[]
{
"Ternary weight distribution",
"Weight sign distribution",
FormatBar("-1", trainingReport.NegativeWeights, max),
FormatBar(" 0", trainingReport.ZeroWeights, max),
FormatBar("+1", trainingReport.PositiveWeights, max)
Expand Down
104 changes: 104 additions & 0 deletions src/BitNetSharp.Core/TraditionalLocalCheckpoint.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
using System.Text.Json;

namespace BitNetSharp.Core;

/// <summary>
/// Result of a checkpoint save/load round-trip check.
/// </summary>
/// <param name="Prompt">Prompt fed to both the original and reloaded models.</param>
/// <param name="OriginalResponse">Response generated by the model before checkpointing.</param>
/// <param name="ReloadedResponse">Response generated by the model reloaded from the checkpoint.</param>
/// <param name="ResponsesMatch">True when both responses are ordinally identical.</param>
public sealed record TraditionalLocalCheckpointValidationResult(
    string Prompt,
    string OriginalResponse,
    string ReloadedResponse,
    bool ResponsesMatch);

/// <summary>
/// JSON-serializable snapshot of a <c>TraditionalLocalModel</c>: its identity,
/// construction parameters (seed, embedding dimension, context window,
/// vocabulary, response limits, language) and learned parameter buffers
/// (token embeddings, output weights, output bias).
/// <c>Format</c> is a version tag checked on load to reject incompatible files.
/// </summary>
internal sealed record TraditionalLocalCheckpointDocument(
    string Format,
    string ModelId,
    int Seed,
    int EmbeddingDimension,
    int ContextWindow,
    IReadOnlyList<string> Vocabulary,
    float[] TokenEmbeddings,
    float[] OutputWeights,
    float[] OutputBias,
    int MaxResponseTokens,
    string PrimaryLanguage);

/// <summary>
/// Saves and restores <see cref="TraditionalLocalModel"/> instances as JSON
/// checkpoint files, and validates that a save/load round trip preserves
/// generated responses.
/// </summary>
public static class TraditionalLocalCheckpoint
{
    private const string FormatName = "traditional-local.repository-checkpoint.v1";

    // Cached and reused: allocating a new JsonSerializerOptions per call defeats
    // the serializer's internal metadata caching (analyzer rule CA1869).
    // JsonSerializerOptions instances are safe to share once configured.
    private static readonly JsonSerializerOptions SerializerOptions = new() { WriteIndented = true };

    /// <summary>
    /// Serializes the model's configuration and learned parameter buffers to an
    /// indented JSON checkpoint at <paramref name="path"/>, creating the target
    /// directory when it does not already exist.
    /// </summary>
    /// <param name="model">Model whose state is exported. Must not be null.</param>
    /// <param name="path">Destination file path. Must not be null or whitespace.</param>
    public static void Save(TraditionalLocalModel model, string path)
    {
        ArgumentNullException.ThrowIfNull(model);
        ArgumentException.ThrowIfNullOrWhiteSpace(path);

        var directory = Path.GetDirectoryName(path);
        if (!string.IsNullOrWhiteSpace(directory))
        {
            Directory.CreateDirectory(directory);
        }

        var document = new TraditionalLocalCheckpointDocument(
            FormatName,
            model.ModelId,
            model.Seed,
            model.EmbeddingDimension,
            model.ContextWindow,
            model.Options.Vocabulary.ToArray(),
            model.ExportTokenEmbeddings(),
            model.ExportOutputWeights(),
            model.ExportOutputBias(),
            model.Options.MaxResponseTokens,
            model.Options.PrimaryLanguage);
        File.WriteAllText(path, JsonSerializer.Serialize(document, SerializerOptions));
    }

    /// <summary>
    /// Loads a checkpoint written by <see cref="Save"/> and reconstructs an
    /// equivalent <see cref="TraditionalLocalModel"/>.
    /// </summary>
    /// <param name="path">Checkpoint file path. Must not be null or whitespace.</param>
    /// <param name="verbosity">Verbosity for the reconstructed model (not persisted in the checkpoint).</param>
    /// <exception cref="InvalidOperationException">
    /// The file cannot be deserialized, or its format tag is not <c>traditional-local.repository-checkpoint.v1</c>.
    /// </exception>
    public static TraditionalLocalModel Load(string path, VerbosityLevel verbosity = VerbosityLevel.Normal)
    {
        ArgumentException.ThrowIfNullOrWhiteSpace(path);

        var document = JsonSerializer.Deserialize<TraditionalLocalCheckpointDocument>(File.ReadAllText(path))
            ?? throw new InvalidOperationException("Could not deserialize the traditional local checkpoint document.");
        if (!string.Equals(document.Format, FormatName, StringComparison.Ordinal))
        {
            throw new InvalidOperationException($"Unsupported checkpoint format '{document.Format}'.");
        }

        var model = new TraditionalLocalModel(
            new BitNetOptions(
                document.Vocabulary.ToArray(),
                verbosity,
                document.MaxResponseTokens,
                document.PrimaryLanguage),
            document.EmbeddingDimension,
            document.ContextWindow,
            document.Seed);
        model.ImportState(document.TokenEmbeddings, document.OutputWeights, document.OutputBias);
        return model;
    }

    /// <summary>
    /// Saves <paramref name="model"/> to a temporary file, reloads it, and
    /// compares short (4-token) responses to <paramref name="prompt"/> from the
    /// original and reloaded models. The temporary checkpoint file is deleted
    /// even when generation throws.
    /// </summary>
    public static TraditionalLocalCheckpointValidationResult ValidateRoundTrip(TraditionalLocalModel model, string prompt)
    {
        ArgumentNullException.ThrowIfNull(model);
        ArgumentException.ThrowIfNullOrWhiteSpace(prompt);

        // Guid-suffixed name avoids collisions between concurrent validations.
        var checkpointPath = Path.Combine(Path.GetTempPath(), $"traditional-local-checkpoint-{Guid.NewGuid():N}.json");
        try
        {
            Save(model, checkpointPath);
            var reloaded = Load(checkpointPath, model.Options.Verbosity);
            var original = model.GenerateResponse(prompt, maxTokens: 4);
            var roundTripped = reloaded.GenerateResponse(prompt, maxTokens: 4);
            return new TraditionalLocalCheckpointValidationResult(
                prompt,
                original.ResponseText,
                roundTripped.ResponseText,
                string.Equals(original.ResponseText, roundTripped.ResponseText, StringComparison.Ordinal));
        }
        finally
        {
            if (File.Exists(checkpointPath))
            {
                File.Delete(checkpointPath);
            }
        }
    }
}
101 changes: 100 additions & 1 deletion src/BitNetSharp.Core/TraditionalLocalModel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Numerics.Tensors;
using BitNetSharp.Core.Quantization;

namespace BitNetSharp.Core;

Expand Down Expand Up @@ -106,6 +107,8 @@ .. options.Vocabulary

public BitNetTokenizer Tokenizer => _tokenizer;

internal int Seed => _seed;

public static TraditionalLocalModel CreateDefault(VerbosityLevel verbosity = VerbosityLevel.Normal) =>
new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity));

Expand Down Expand Up @@ -176,7 +179,14 @@ public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs =
}

_isTrained = true;
return new TrainingReport(history, totalSamples, epochs, 0, 0, 0);
var stats = GetTernaryWeightStats();
return new TrainingReport(
history,
totalSamples,
epochs,
stats.NegativeCount,
stats.ZeroCount,
stats.PositiveCount);
}
}

Expand Down Expand Up @@ -269,6 +279,76 @@ public double CalculatePerplexity(IEnumerable<string> validationSamples)
}
}

/// <summary>
/// Counts negative, zero, and positive values across every learned parameter
/// buffer (token embeddings, output weights, output bias). Taken under the
/// model lock so training cannot mutate the buffers mid-count.
/// </summary>
public TernaryWeightStats GetTernaryWeightStats()
{
    lock (_gate)
    {
        var below = 0;
        var exact = 0;
        var above = 0;

        foreach (var buffer in new[] { _tokenEmbeddings, _outputWeights, _outputBias })
        {
            CountWeightSigns(buffer, ref below, ref exact, ref above);
        }

        return new TernaryWeightStats(below, exact, above);
    }
}

/// <summary>Returns a defensive copy of the token-embedding buffer, taken under the model lock.</summary>
internal float[] ExportTokenEmbeddings()
{
    lock (_gate)
    {
        return (float[])_tokenEmbeddings.Clone();
    }
}

/// <summary>Returns a defensive copy of the output-weight buffer, taken under the model lock.</summary>
internal float[] ExportOutputWeights()
{
    lock (_gate)
    {
        return (float[])_outputWeights.Clone();
    }
}

/// <summary>Returns a defensive copy of the output-bias buffer, taken under the model lock.</summary>
internal float[] ExportOutputBias()
{
    lock (_gate)
    {
        return (float[])_outputBias.Clone();
    }
}

/// <summary>
/// Overwrites the model's learned parameter buffers with externally supplied
/// values (e.g. from a checkpoint) and marks the model as trained. Each buffer
/// must exactly match the length of the buffer it replaces.
/// </summary>
/// <exception cref="ArgumentNullException">Any buffer is null.</exception>
/// <exception cref="ArgumentException">Any buffer length differs from the model's.</exception>
internal void ImportState(float[] tokenEmbeddings, float[] outputWeights, float[] outputBias)
{
    ArgumentNullException.ThrowIfNull(tokenEmbeddings);
    ArgumentNullException.ThrowIfNull(outputWeights);
    ArgumentNullException.ThrowIfNull(outputBias);

    // Rejects a candidate buffer whose length disagrees with the live buffer.
    static void RequireMatchingLength(float[] candidate, float[] current, string label, string paramName)
    {
        if (candidate.Length != current.Length)
        {
            throw new ArgumentException($"{label} length {candidate.Length} does not match expected length {current.Length}.", paramName);
        }
    }

    lock (_gate)
    {
        RequireMatchingLength(tokenEmbeddings, _tokenEmbeddings, "Token embedding", nameof(tokenEmbeddings));
        RequireMatchingLength(outputWeights, _outputWeights, "Output weight", nameof(outputWeights));
        RequireMatchingLength(outputBias, _outputBias, "Output bias", nameof(outputBias));

        Array.Copy(tokenEmbeddings, _tokenEmbeddings, tokenEmbeddings.Length);
        Array.Copy(outputWeights, _outputWeights, outputWeights.Length);
        Array.Copy(outputBias, _outputBias, outputBias.Length);
        _isTrained = true;
    }
}

private void EnsureTrained()
{
if (_isTrained)
Expand Down Expand Up @@ -424,6 +504,25 @@ private void ResetParameters()
_isTrained = false;
}

/// <summary>
/// Accumulates the sign distribution of <paramref name="values"/> into the
/// three ref counters. Exact zeros fall into the zero bucket; so does NaN,
/// which fails both ordered comparisons — identical bucketing to a
/// greater-than / less-than / else chain.
/// </summary>
private static void CountWeightSigns(float[] values, ref int negative, ref int zero, ref int positive)
{
    for (var index = 0; index < values.Length; index++)
    {
        var weight = values[index];
        if (weight < 0f)
        {
            negative++;
        }
        else if (weight > 0f)
        {
            positive++;
        }
        else
        {
            zero++;
        }
    }
}

private static void FillWithDeterministicNoise(float[] values, Random random)
{
for (var index = 0; index < values.Length; index++)
Expand Down
34 changes: 34 additions & 0 deletions tests/BitNetSharp.Tests/BitNetModelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,40 @@ public void TraditionalLocalModelLearnsSimplePromptResponse()
Assert.Contains(result.Diagnostics, diagnostic => diagnostic.Contains("tensor-based ordered-context", StringComparison.OrdinalIgnoreCase));
}

[Fact]
public void TraditionalLocalTrainingReportIncludesWeightSignDistribution()
{
    // Arrange: a tiny deterministic model so weight signs are reproducible.
    var options = new BitNetOptions(["alpha", "beta", "gamma", "delta"], VerbosityLevel.Quiet);
    var model = new TraditionalLocalModel(options, embeddingDimension: 8, contextWindow: 4, seed: 19);
    TrainingExample[] examples = [new("alpha beta", "gamma delta")];

    // Act
    var report = model.Train(examples, epochs: 12, learningRate: 0.3f);

    // Assert: both signs appear and the report totals agree with the live stats.
    Assert.True(report.NegativeWeights > 0);
    Assert.True(report.PositiveWeights > 0);
    var reportedTotal = report.NegativeWeights + report.ZeroWeights + report.PositiveWeights;
    Assert.Equal(reportedTotal, model.GetTernaryWeightStats().TotalCount);
}

[Fact]
public void TraditionalHostedAgentModelExposesInspectableWeightStats()
{
    // The traditional comparison model must implement the inspection interface
    // and report a non-trivial sign distribution over its weights.
    using var model = HostedAgentModelFactory.Create(
        HostedAgentModelFactory.TraditionalLocalModelId,
        VerbosityLevel.Quiet);

    var stats = Assert.IsAssignableFrom<IInspectableHostedAgentModel>(model).GetTernaryWeightStats();

    Assert.True(stats.TotalCount > 0);
    Assert.True(stats.NegativeCount > 0);
    Assert.True(stats.PositiveCount > 0);
}

[Fact]
public void BenchmarkOptionsIncludePrimaryAndComparisonModels()
{
Expand Down
17 changes: 17 additions & 0 deletions tests/BitNetSharp.Tests/HostedAgentBenchmarksExecutionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ public void BenchmarkModelConstructionUsesTheTinyLlamaTrainingVocabulary()
Assert.Equal("tinyllama", ((TraditionalLocalHostedAgentModel)traditional).Model.Tokenizer.Normalize("tinyllama"));
}

[Fact]
public void BuiltInModelsPreserveTrainedResponsesAcrossCheckpointRoundTrips()
{
    const string prompt = "what does the paper model train on";
    var corpus = BitNetTrainingCorpus.CreateBenchmarkExamples();

    // Build and train both built-in comparison models on the same corpus.
    var paperModel = BitNetPaperModel.CreateForTrainingCorpus(corpus, VerbosityLevel.Quiet);
    var localModel = TraditionalLocalModel.CreateForTrainingCorpus(corpus, VerbosityLevel.Quiet);
    paperModel.Train(corpus, epochs: 1);
    localModel.Train(corpus, epochs: TraditionalLocalModel.DefaultTrainingEpochs);

    // Each checkpoint round trip must reproduce the trained model's response.
    Assert.True(BitNetPaperCheckpoint.ValidateRoundTrip(paperModel, prompt).ResponsesMatch);
    Assert.True(TraditionalLocalCheckpoint.ValidateRoundTrip(localModel, prompt).ResponsesMatch);
}

private static async Task WithBenchmarkOptionsAsync(HostedAgentBenchmarkOptions options, Func<Task> assertion)
{
var originalValue = Environment.GetEnvironmentVariable(HostedAgentBenchmarkOptions.EnvironmentVariableName);
Expand Down
Loading
Loading