Skip to content

Commit 2a080dd

Browse files
sharpninjaclaude
andcommitted
Optimize benchmark perplexity calculation from O(L^3) to O(L^2)
The previous CalculatePerplexity implementation was calling ForwardLogitsPerplexityStep once per target token, slicing the prefix and running a full forward pass of the whole transformer each time. For a sample of length L that was L forward passes of cost O(L^2) each — O(L^3) per sample — which made the full WikiText2 validation set take hours per model on a consumer CPU. Because the attention is causal (UsesCausalAttentionMask=true, and every head only attends to positions 0..targetPosition), a single forward pass already emits next-token logits for every position in the sequence: row i of the returned [seq_len, vocab_size] matrix predicts token i+1. The new implementation does one forward pass per chunk of at most MaxSequenceLength tokens and reads all L-1 per-row predictions in a single sweep. That drops per-sample cost from O(L^3) to O(L^2), roughly an L-fold speedup (e.g. ~150x on a typical 150-token sample). Additional changes: - Parallelize sample tokenization in CalculatePerplexity via Parallel.For (the tokenizer is stateless/read-only). - Add per-phase and per-model progress logging to HostedAgentBenchmarkReportRunner, with intermediate JSON dumps of model and performance reports so long runs can be analyzed mid-flight (artifacts/benchmark-report/progress.log and per-model JSON files). - Add --perplexity-sample-percent CLI option to the benchmark-report command (default 10 in code) and plumb it through to the perplexity sample selection via stride sampling. - Add PerplexitySamplePercent pipeline variable to both Azure Pipelines and GitHub Actions workflows, default 100% so the optimized path is exercised on the full validation set. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b0be2d7 commit 2a080dd

File tree

5 files changed

+266
-19
lines changed

5 files changed

+266
-19
lines changed

.github/workflows/benchmark-report.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@ on:
1010
- "src/BitNetSharp.Core/**"
1111
- "tests/BitNetSharp.Tests/**"
1212
workflow_dispatch:
13+
inputs:
14+
perplexity_sample_percent:
15+
description: "Percentage (0, 100] of the WikiText2 validation set to evaluate for perplexity (default 100)"
16+
required: false
17+
default: "100"
18+
19+
env:
20+
PERPLEXITY_SAMPLE_PERCENT: ${{ github.event.inputs.perplexity_sample_percent || '100' }}
1321

1422
permissions:
1523
contents: read
@@ -51,6 +59,7 @@ jobs:
5159
--compare-model=traditional-local
5260
--commit=${{ github.sha }}
5361
--output="${{ github.workspace }}/artifacts/benchmark-report"
62+
--perplexity-sample-percent=${{ env.PERPLEXITY_SAMPLE_PERCENT }}
5463
5564
- name: Upload benchmark report artifact
5665
uses: actions/upload-artifact@v4

azure-pipelines-benchmark-report.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,17 @@ variables:
1616
SolutionFile: BitNet-b1.58-Sharp.slnx
1717
BenchmarkArtifactName: benchmark-report
1818
AzureStaticWebAppsApiToken: ""
19+
# Percentage (0, 100] of the WikiText2 validation set to evaluate for perplexity in the
20+
# benchmark report. Stride-sampled evenly across the corpus. Default 100% (full coverage).
21+
PerplexitySamplePercent: "100"
1922

2023
stages:
2124
- stage: benchmark
2225
displayName: Build benchmark report
2326
jobs:
2427
- job: benchmark
2528
displayName: Run slow-lane validation and generate report
26-
timeoutInMinutes: 45
29+
timeoutInMinutes: 0
2730
pool:
2831
name: Default
2932
steps:
@@ -58,6 +61,7 @@ stages:
5861
--compare-model=traditional-local
5962
--commit=$(Build.SourceVersion)
6063
--output="$(Build.ArtifactStagingDirectory)/benchmark-report"
64+
--perplexity-sample-percent=$(PerplexitySamplePercent)
6165
displayName: Generate benchmark comparison report
6266
6367
- task: PublishPipelineArtifact@1

src/BitNetSharp.App/HostedAgentBenchmarkReportRunner.cs

Lines changed: 148 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,29 +84,64 @@ public static class HostedAgentBenchmarkReportRunner
8484
private const double MillisecondsPerSecond = 1_000d;
8585
private const double BytesPerMegabyte = 1024d * 1024d;
8686
private const double KilobytesPerMegabyte = 1024d;
87+
public const double DefaultPerplexitySamplePercent = 10d;
88+
8789
public static async Task<string> RunAsync(
8890
HostedAgentBenchmarkOptions options,
8991
string? outputDirectory,
9092
string? commitHash = null,
93+
double perplexitySamplePercent = DefaultPerplexitySamplePercent,
9194
CancellationToken cancellationToken = default)
9295
{
9396
ArgumentNullException.ThrowIfNull(options);
97+
if (perplexitySamplePercent <= 0d || perplexitySamplePercent > 100d)
98+
{
99+
throw new ArgumentOutOfRangeException(
100+
nameof(perplexitySamplePercent),
101+
perplexitySamplePercent,
102+
"perplexitySamplePercent must be in the range (0, 100].");
103+
}
94104

95105
var originalWorkingDirectory = Directory.GetCurrentDirectory();
96106
var reportDirectory = Path.GetFullPath(
97107
string.IsNullOrWhiteSpace(outputDirectory)
98108
? Path.Combine(originalWorkingDirectory, "artifacts", "benchmark-report")
99109
: outputDirectory);
100110
Directory.CreateDirectory(reportDirectory);
111+
112+
var progressLogPath = Path.Combine(reportDirectory, "progress.log");
113+
LogProgress(progressLogPath, $"RunAsync started (perplexitySamplePercent={perplexitySamplePercent:0.##}%)");
114+
115+
LogProgress(progressLogPath, "Phase 1: BenchmarkDotNet suite starting");
101116
HostedAgentBenchmarkRunner.Run(options);
117+
LogProgress(progressLogPath, "Phase 1: BenchmarkDotNet suite complete");
118+
119+
LogProgress(progressLogPath, "Phase 2: Copying BDN artifacts");
102120
CopyArtifactsDirectory(
103121
Path.Combine(originalWorkingDirectory, "BenchmarkDotNet.Artifacts"),
104122
Path.Combine(reportDirectory, "BenchmarkDotNet.Artifacts"));
123+
LogProgress(progressLogPath, "Phase 2: BDN artifacts copied");
105124

125+
LogProgress(progressLogPath, "Phase 3: Creating benchmark examples");
106126
var trainingExamples = BitNetTrainingCorpus.CreateBenchmarkExamples();
107-
var modelReports = await CreateModelReportsAsync(options, trainingExamples, cancellationToken);
127+
LogProgress(progressLogPath, $"Phase 3: Created {trainingExamples.Count} training examples");
128+
129+
LogProgress(progressLogPath, "Phase 4: CreateModelReportsAsync starting");
130+
var modelReports = await CreateModelReportsAsync(options, trainingExamples, progressLogPath, perplexitySamplePercent, cancellationToken);
131+
LogProgress(progressLogPath, $"Phase 4: CreateModelReportsAsync complete ({modelReports.Count} reports)");
132+
133+
// Save intermediate model reports in case the rest fails.
134+
SaveIntermediate(reportDirectory, "model-reports.json", modelReports);
135+
136+
LogProgress(progressLogPath, "Phase 5: Parsing performance rows");
108137
var performanceRows = ParsePerformanceRows(reportDirectory);
138+
LogProgress(progressLogPath, $"Phase 5: Parsed {performanceRows.Count} performance rows");
139+
SaveIntermediate(reportDirectory, "performance-rows.json", performanceRows);
140+
141+
LogProgress(progressLogPath, "Phase 6: Creating comparison summary");
109142
var comparisonSummary = CreateComparisonSummary(modelReports, performanceRows);
143+
LogProgress(progressLogPath, "Phase 6: Comparison summary complete");
144+
110145
var report = new HostedAgentBenchmarkComparisonReport(
111146
DateTimeOffset.UtcNow,
112147
trainingExamples.Select(static example => example.Prompt).ToArray(),
@@ -115,10 +150,43 @@ public static async Task<string> RunAsync(
115150
comparisonSummary,
116151
BitNetTrainingCorpus.BenchmarkDatasetName);
117152

153+
LogProgress(progressLogPath, "Phase 7: Writing report site");
118154
WriteReportSite(reportDirectory, report, commitHash);
155+
LogProgress(progressLogPath, "Phase 7: Report site written");
156+
157+
LogProgress(progressLogPath, "RunAsync complete");
119158
return reportDirectory;
120159
}
121160

161+
private static void LogProgress(string path, string message)
162+
{
163+
var timestamp = DateTimeOffset.UtcNow.ToString("yyyy-MM-ddTHH:mm:ss.fffZ");
164+
var line = $"{timestamp} {message}";
165+
try
166+
{
167+
File.AppendAllText(path, line + Environment.NewLine);
168+
}
169+
catch
170+
{
171+
// Ignore log failures.
172+
}
173+
174+
Console.WriteLine($"[PROGRESS] {line}");
175+
}
176+
177+
private static void SaveIntermediate<T>(string reportDirectory, string fileName, T data)
178+
{
179+
try
180+
{
181+
var path = Path.Combine(reportDirectory, fileName);
182+
File.WriteAllText(path, JsonSerializer.Serialize(data, new JsonSerializerOptions { WriteIndented = true }));
183+
}
184+
catch (Exception ex)
185+
{
186+
Console.WriteLine($"[PROGRESS] Failed to save intermediate {fileName}: {ex.Message}");
187+
}
188+
}
189+
122190
public static IReadOnlyList<HostedAgentBenchmarkPerformanceRow> ParsePerformanceRows(string reportDirectory)
123191
{
124192
var resultsDirectory = Path.Combine(reportDirectory, "BenchmarkDotNet.Artifacts", "results");
@@ -211,37 +279,54 @@ public static void WriteReportSite(string outputDirectory, HostedAgentBenchmarkC
211279
private static async Task<IReadOnlyList<HostedAgentBenchmarkModelReport>> CreateModelReportsAsync(
212280
HostedAgentBenchmarkOptions options,
213281
IReadOnlyList<TrainingExample> trainingExamples,
282+
string progressLogPath,
283+
double perplexitySamplePercent,
214284
CancellationToken cancellationToken)
215285
{
216286
var reports = new List<HostedAgentBenchmarkModelReport>();
287+
var reportDirectory = Path.GetDirectoryName(progressLogPath) ?? string.Empty;
288+
217289
foreach (var modelSpecifier in options.ModelSpecifiers)
218290
{
291+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': creating prepared model");
219292
using var model = HostedAgentBenchmarkModelBootstrap.CreatePreparedModel(modelSpecifier, options, trainingExamples);
293+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': prepared model created");
294+
220295
var trainingSupported = model is ITrainableHostedAgentModel;
221296
var trainingCompleted = false;
222297
var trainingEpochs = 0;
223298
if (options.EnableBucketing && model is BitNetHostedAgentModel preTrainedBitNetModel)
224299
{
300+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': mining pre-training buckets");
225301
preTrainedBitNetModel.Model.MineAndLoadBuckets(trainingExamples);
302+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': pre-training buckets mined");
226303
}
227304

228305
if (trainingSupported)
229306
{
230307
trainingEpochs = GetTrainingEpochs(model);
308+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': training starting ({trainingEpochs} epochs)");
231309
((ITrainableHostedAgentModel)model).Train(trainingExamples, trainingEpochs);
232310
trainingCompleted = true;
311+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': training complete");
233312

234313
if (options.EnableBucketing && model is BitNetHostedAgentModel trainedBitNetModel)
235314
{
315+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': mining post-training buckets");
236316
trainedBitNetModel.Model.MineAndLoadBuckets(trainingExamples);
317+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': post-training buckets mined");
237318
}
238319
}
239320

321+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': running query examples");
240322
var queryResults = new List<HostedAgentBenchmarkQueryResult>(trainingExamples.Count);
241323
var attemptedChainTokens = 0;
242324
var acceptedChainTokens = 0;
325+
var queryIndex = 0;
243326
foreach (var example in trainingExamples)
244327
{
328+
queryIndex++;
329+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': query {queryIndex}/{trainingExamples.Count}");
245330
string responseText;
246331
if (model is BitNetHostedAgentModel bitNetModel)
247332
{
@@ -269,13 +354,34 @@ private static async Task<IReadOnlyList<HostedAgentBenchmarkModelReport>> Create
269354
CalculateExpectedTokenRecall(responseText, example.Response)));
270355
}
271356

357+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': queries complete");
358+
272359
double? chainBucketAcceptanceRate = null;
273360
if (attemptedChainTokens > 0)
274361
{
275362
chainBucketAcceptanceRate = acceptedChainTokens / (double)attemptedChainTokens;
276363
}
277364

278-
reports.Add(new HostedAgentBenchmarkModelReport(
365+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': computing perplexity ({perplexitySamplePercent:0.##}% of WikiText2 validation)");
366+
var perplexity = GetWikiText2Perplexity(model, perplexitySamplePercent);
367+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': perplexity = {perplexity}");
368+
369+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': computing benchmark prompt token count");
370+
var promptTokenCount = await GetBenchmarkPromptTokenCountAsync(model, options, cancellationToken);
371+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': prompt token count = {promptTokenCount}");
372+
373+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': estimating resident model megabytes");
374+
var residentMb = GetEstimatedResidentModelMegabytes(model);
375+
376+
BitNetPaperAuditReport? auditReport = null;
377+
if (model is BitNetHostedAgentModel auditBitNetModel)
378+
{
379+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': running paper alignment audit");
380+
auditReport = BitNetPaperAuditor.CreateReport(auditBitNetModel.Model);
381+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': audit complete");
382+
}
383+
384+
var modelReport = new HostedAgentBenchmarkModelReport(
279385
model.ModelId,
280386
model.DisplayName,
281387
trainingSupported,
@@ -287,11 +393,20 @@ private static async Task<IReadOnlyList<HostedAgentBenchmarkModelReport>> Create
287393
queryResults.Count(static result => result.ExactMatch),
288394
queryResults.Count == 0 ? 0d : queryResults.Average(static result => result.ExpectedTokenRecall),
289395
queryResults,
290-
model is BitNetHostedAgentModel auditBitNetModel ? BitNetPaperAuditor.CreateReport(auditBitNetModel.Model) : null,
291-
GetWikiText2Perplexity(model),
292-
await GetBenchmarkPromptTokenCountAsync(model, options, cancellationToken),
293-
GetEstimatedResidentModelMegabytes(model),
294-
chainBucketAcceptanceRate));
396+
auditReport,
397+
perplexity,
398+
promptTokenCount,
399+
residentMb,
400+
chainBucketAcceptanceRate);
401+
402+
reports.Add(modelReport);
403+
LogProgress(progressLogPath, $"Model '{modelSpecifier}': report added ({reports.Count}/{options.ModelSpecifiers.Count})");
404+
405+
// Persist incremental reports in case a subsequent model crashes.
406+
if (!string.IsNullOrEmpty(reportDirectory))
407+
{
408+
SaveIntermediate(reportDirectory, $"model-report-{reports.Count:D2}-{modelSpecifier.Replace('/', '_').Replace('.', '_')}.json", modelReport);
409+
}
295410
}
296411

297412
return reports;
@@ -584,13 +699,35 @@ private static async Task<int> GetBenchmarkPromptTokenCountAsync(
584699
return CountResponseTokens(model, benchmarkResponse.Text);
585700
}
586701

587-
private static double? GetWikiText2Perplexity(IHostedAgentModel model) =>
588-
model switch
702+
// Use the configured percentage of the WikiText2 validation set for the benchmark-report
703+
// perplexity calculation, stride-sampled evenly across the full validation set so coverage
704+
// is representative of the entire corpus. The full 3,760-sample set takes hours to evaluate
705+
// on a consumer CPU; 10% (376 entries) runs in a few minutes and is sufficient for relative
706+
// comparison between models.
707+
private static IReadOnlyList<string> GetBenchmarkWikiText2ValidationSamples(double samplePercent)
708+
{
709+
var all = BitNetBenchmarkFixtures.WikiText2ValidationSamples;
710+
var targetCount = Math.Max(1, (int)Math.Ceiling(all.Count * (samplePercent / 100d)));
711+
var stride = Math.Max(1, all.Count / targetCount);
712+
var samples = new List<string>(targetCount);
713+
for (var i = 0; i < all.Count && samples.Count < targetCount; i += stride)
714+
{
715+
samples.Add(all[i]);
716+
}
717+
718+
return samples;
719+
}
720+
721+
private static double? GetWikiText2Perplexity(IHostedAgentModel model, double samplePercent)
722+
{
723+
var samples = GetBenchmarkWikiText2ValidationSamples(samplePercent);
724+
return model switch
589725
{
590-
BitNetHostedAgentModel bitNetModel => bitNetModel.Model.CalculatePerplexity(BitNetBenchmarkFixtures.WikiText2ValidationSamples),
591-
TraditionalLocalHostedAgentModel traditionalModel => traditionalModel.Model.CalculatePerplexity(BitNetBenchmarkFixtures.WikiText2ValidationSamples),
726+
BitNetHostedAgentModel bitNetModel => bitNetModel.Model.CalculatePerplexity(samples),
727+
TraditionalLocalHostedAgentModel traditionalModel => traditionalModel.Model.CalculatePerplexity(samples),
592728
_ => null
593729
};
730+
}
594731

595732
private static double? GetEstimatedResidentModelMegabytes(IHostedAgentModel model)
596733
{

src/BitNetSharp.App/Program.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,21 @@
2222
{
2323
var reportDirectory = ParseOption(args, "--output=");
2424
var commitHash = ParseOption(args, "--commit=");
25+
var perplexitySamplePercentRaw = ParseOption(args, "--perplexity-sample-percent=");
26+
var perplexitySamplePercent = 10d;
27+
if (!string.IsNullOrWhiteSpace(perplexitySamplePercentRaw)
28+
&& double.TryParse(perplexitySamplePercentRaw, System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out var parsedPercent)
29+
&& parsedPercent > 0d
30+
&& parsedPercent <= 100d)
31+
{
32+
perplexitySamplePercent = parsedPercent;
33+
}
34+
2535
var outputPath = await HostedAgentBenchmarkReportRunner.RunAsync(
2636
HostedAgentBenchmarkOptions.Parse(args, verbosity),
2737
reportDirectory,
28-
commitHash);
38+
commitHash,
39+
perplexitySamplePercent);
2940
Console.WriteLine($"Saved benchmark comparison report to {outputPath}");
3041

3142
// Force exit: BenchmarkDotNet and the Microsoft.Agents.AI hosting framework

0 commit comments

Comments
 (0)