-
Notifications
You must be signed in to change notification settings - Fork 0
Return generated paper-model responses instead of top-token listings #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,8 +5,6 @@ namespace BitNetSharp.Core; | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public sealed class BitNetPaperModel | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private const int MaxPredictionLimit = 8; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private static readonly HashSet<string> ReservedTokens = | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BitNetTokenizer.BeginToken, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -16,9 +14,11 @@ public sealed class BitNetPaperModel | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly int _beginTokenId; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly int _endTokenId; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly Dictionary<string, int[]> _memorizedResponses = new(StringComparer.Ordinal); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly Dictionary<string, int> _tokenToId; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly string[] _idToToken; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly BitNetTokenizer _tokenizer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| private readonly object _gate = new(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public BitNetPaperModel(BitNetOptions options, BitNetConfig? config = null, int seed = 42) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -70,7 +70,7 @@ .. options.Vocabulary | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public BitNetTokenizer Tokenizer => _tokenizer; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public static BitNetPaperModel CreateDefault(VerbosityLevel verbosity = VerbosityLevel.Normal) => | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PrimeDefaultExamples(new(new BitNetOptions(BitNetTrainingCorpus.CreateDefaultVocabulary(), verbosity))); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs = 3, float learningRate = 0.05f) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -84,106 +84,137 @@ public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs = | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| throw new ArgumentException("At least one training example is required.", nameof(examples)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var weights = ExportOutputHeadWeights(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var lossHistory = new List<double>(epochs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var epoch = 0; epoch < epochs; epoch++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lock (_gate) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var totalLoss = 0d; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var observations = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var weights = ExportOutputHeadWeights(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var lossHistory = new List<double>(epochs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| foreach (var example in trainingSet) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var epoch = 0; epoch < epochs; epoch++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var promptIds = EncodeTokenIds(example.Prompt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var targetIds = EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: false); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (targetIds.Count == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var totalLoss = 0d; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var observations = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| foreach (var example in trainingSet) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var promptIds = EncodeTokenIds(example.Prompt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var targetIds = EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: true); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (targetIds.Count == 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| continue; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var hiddenStates = ForwardHiddenStates(promptIds); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var features = GetLastRow(hiddenStates); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var targetId = targetIds[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var probabilities = ComputeProbabilities(weights, features); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _memorizedResponses[NormalizePromptKey(example.Prompt)] = [.. targetIds]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var targetId = targetIds[0]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var hiddenStates = ForwardHiddenStates(promptIds); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var features = GetLastRow(hiddenStates); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var probabilities = ComputeProbabilities(weights, features); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| totalLoss -= Math.Log(Math.Max(probabilities[targetId], 1e-9d)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| observations++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| totalLoss -= Math.Log(Math.Max(probabilities[targetId], 1e-9d)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| observations++; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var tokenId = 0; tokenId < probabilities.Length; tokenId++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var gradient = probabilities[tokenId] - (tokenId == targetId ? 1d : 0d); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var dimension = 0; dimension < features.Length; dimension++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var tokenId = 0; tokenId < probabilities.Length; tokenId++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights[tokenId, dimension] -= (float)(learningRate * gradient * features[dimension]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var gradient = probabilities[tokenId] - (tokenId == targetId ? 1d : 0d); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var dimension = 0; dimension < features.Length; dimension++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights[tokenId, dimension] -= (float)(learningRate * gradient * features[dimension]); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ImportOutputHeadWeights(weights); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights = ExportOutputHeadWeights(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lossHistory.Add(observations == 0 ? 0d : totalLoss / observations); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ImportOutputHeadWeights(weights); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| weights = ExportOutputHeadWeights(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lossHistory.Add(observations == 0 ? 0d : totalLoss / observations); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var stats = GetTernaryWeightStats(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return new TrainingReport( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lossHistory, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trainingSet.Count * epochs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| epochs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats.NegativeCount, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats.ZeroCount, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats.PositiveCount); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var stats = GetTernaryWeightStats(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return new TrainingReport( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lossHistory, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trainingSet.Count * epochs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| epochs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats.NegativeCount, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats.ZeroCount, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stats.PositiveCount); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = null) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var diagnostics = new List<string>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var inputTokenIds = TokenizeToIds(prompt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var truncated = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (inputTokenIds.Count > Config.MaxSequenceLength) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| lock (_gate) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inputTokenIds = inputTokenIds.Skip(inputTokenIds.Count - Config.MaxSequenceLength).ToArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| truncated = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var diagnostics = new List<string>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var contextTokenIds = TokenizeToIds(prompt).ToList(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var generatedTokenIds = new List<int>(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var truncated = false; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var promptKey = NormalizePromptKey(prompt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity >= VerbosityLevel.Normal) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Model: {ModelId}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Primary language: {Options.PrimaryLanguage}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (contextTokenIds.Count > Config.MaxSequenceLength) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| contextTokenIds = contextTokenIds.Skip(contextTokenIds.Count - Config.MaxSequenceLength).ToList(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| truncated = true; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (truncated) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity >= VerbosityLevel.Normal) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Model: {ModelId}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Primary language: {Options.PrimaryLanguage}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (truncated) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var logits = Transformer.Forward(inputTokenIds); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var availableTokenCount = _idToToken.Length - ReservedTokens.Count; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var systemPredictionLimit = Math.Min(availableTokenCount, MaxPredictionLimit); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var defaultPredictionCount = Math.Min(Options.MaxResponseTokens, systemPredictionLimit); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var userRequestedCount = maxTokens.GetValueOrDefault(defaultPredictionCount); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var predictionCount = Math.Clamp(userRequestedCount, 1, defaultPredictionCount); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var predictions = RankNextTokens(logits, predictionCount).ToArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (_memorizedResponses.TryGetValue(promptKey, out var memorizedResponse)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generatedTokenIds.AddRange( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| memorizedResponse | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .Take(Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens))) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| .Where(tokenId => tokenId != _endTokenId && tokenId != _tokenToId[BitNetTokenizer.UnknownToken])); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity == VerbosityLevel.Verbose) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| foreach (var prediction in predictions) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity == VerbosityLevel.Verbose) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add("Resolved response from trained exemplar memory."); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Prediction: token={prediction.Token}, logit={prediction.Logit:0.###}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var maxGeneratedTokens = Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (var step = 0; step < maxGeneratedTokens; step++) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var nextToken = SelectNextToken(Transformer.Forward(contextTokenIds)); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (nextToken.TokenId is var tokenId && (tokenId == _endTokenId || tokenId == _tokenToId[BitNetTokenizer.UnknownToken])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+187
to
+188
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In the non-memorized path, generation aborts when the top token is Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generatedTokenIds.Add(nextToken.TokenId); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| contextTokenIds.Add(nextToken.TokenId); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (contextTokenIds.Count > Config.MaxSequenceLength) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| contextTokenIds.RemoveAt(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity == VerbosityLevel.Verbose) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Add($"Prediction: token={_idToToken[nextToken.TokenId]}, logit={nextToken.Logit:0.###}"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity == VerbosityLevel.Quiet) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Clear(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (Options.Verbosity == VerbosityLevel.Quiet) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics.Clear(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var generatedTokens = generatedTokenIds.Select(id => _idToToken[id]).ToArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var responseText = generatedTokens.Length == 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ? "BitNet paper model is ready." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| : _tokenizer.Detokenize(generatedTokens); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| var responseText = $"Top next-token predictions: {string.Join(", ", predictions.Select(prediction => prediction.Token))}."; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return new BitNetGenerationResult( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| responseText, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| predictions.Select(prediction => prediction.Token).ToArray(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| diagnostics); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return new BitNetGenerationResult(responseText, generatedTokens, diagnostics); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| public TernaryWeightStats GetTernaryWeightStats() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -228,6 +259,22 @@ internal IReadOnlyList<int> EncodeTokenIds(string text, bool prependBeginToken = | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| internal void ImportOutputHeadWeights(float[,] weights) => Transformer.OutputHead.QuantizeFromFullPrecision(weights); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses() => | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _memorizedResponses.ToDictionary( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static pair => pair.Key, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| static pair => pair.Value.ToArray(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| StringComparer.Ordinal); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+262
to
+266
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎. |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ArgumentNullException.ThrowIfNull(memorizedResponses); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| foreach (var pair in memorizedResponses) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _memorizedResponses[pair.Key] = pair.Value.ToArray(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+262
to
+274
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses() => | |
| _memorizedResponses.ToDictionary( | |
| static pair => pair.Key, | |
| static pair => pair.Value.ToArray(), | |
| StringComparer.Ordinal); | |
| internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses) | |
| { | |
| ArgumentNullException.ThrowIfNull(memorizedResponses); | |
| foreach (var pair in memorizedResponses) | |
| { | |
| _memorizedResponses[pair.Key] = pair.Value.ToArray(); | |
| internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses() | |
| { | |
| lock (_gate) | |
| { | |
| return _memorizedResponses.ToDictionary( | |
| static pair => pair.Key, | |
| static pair => pair.Value.ToArray(), | |
| StringComparer.Ordinal); | |
| } | |
| } | |
| internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses) | |
| { | |
| ArgumentNullException.ThrowIfNull(memorizedResponses); | |
| lock (_gate) | |
| { | |
| foreach (var pair in memorizedResponses) | |
| { | |
| _memorizedResponses[pair.Key] = pair.Value.ToArray(); | |
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Save() re-clones the memorized response values even though ExportMemorizedResponses() already returns a dictionary with copied arrays. This adds extra allocations during checkpoint save; consider serializing the ExportMemorizedResponses() result directly (or adjust ExportMemorizedResponses to return the serializable type you need) to avoid the redundant ToDictionary()/ToArray() pass.