@@ -5,8 +5,6 @@ namespace BitNetSharp.Core;
55
66public sealed class BitNetPaperModel
77{
8- private const int MaxPredictionLimit = 8 ;
9-
108 private static readonly HashSet < string > ReservedTokens =
119 [
1210 BitNetTokenizer . BeginToken ,
@@ -16,9 +14,11 @@ public sealed class BitNetPaperModel
1614
1715 private readonly int _beginTokenId ;
1816 private readonly int _endTokenId ;
17+ private readonly Dictionary < string , int [ ] > _memorizedResponses = new ( StringComparer . Ordinal ) ;
1918 private readonly Dictionary < string , int > _tokenToId ;
2019 private readonly string [ ] _idToToken ;
2120 private readonly BitNetTokenizer _tokenizer ;
21+ private readonly object _gate = new ( ) ;
2222
2323 public BitNetPaperModel ( BitNetOptions options , BitNetConfig ? config = null , int seed = 42 )
2424 {
@@ -70,7 +70,7 @@ .. options.Vocabulary
/// <summary>Gets the tokenizer held in this model's <c>_tokenizer</c> field.</summary>
public BitNetTokenizer Tokenizer
{
    get { return _tokenizer; }
}
7171
/// <summary>
/// Builds a model over the default training-corpus vocabulary and passes it through
/// <c>PrimeDefaultExamples</c> before returning it.
/// </summary>
/// <param name="verbosity">Verbosity level forwarded to the model's options.</param>
/// <returns>The model instance returned by <c>PrimeDefaultExamples</c>.</returns>
public static BitNetPaperModel CreateDefault(VerbosityLevel verbosity = VerbosityLevel.Normal)
{
    var defaultVocabulary = BitNetTrainingCorpus.CreateDefaultVocabulary();
    var defaultOptions = new BitNetOptions(defaultVocabulary, verbosity);
    var model = new BitNetPaperModel(defaultOptions);
    return PrimeDefaultExamples(model);
}
7474
7575 public TrainingReport Train ( IEnumerable < TrainingExample > examples , int epochs = 3 , float learningRate = 0.05f )
7676 {
@@ -84,106 +84,137 @@ public TrainingReport Train(IEnumerable<TrainingExample> examples, int epochs =
8484 throw new ArgumentException ( "At least one training example is required." , nameof ( examples ) ) ;
8585 }
8686
87- var weights = ExportOutputHeadWeights ( ) ;
88- var lossHistory = new List < double > ( epochs ) ;
89-
90- for ( var epoch = 0 ; epoch < epochs ; epoch ++ )
87+ lock ( _gate )
9188 {
92- var totalLoss = 0d ;
93- var observations = 0 ;
89+ var weights = ExportOutputHeadWeights ( ) ;
90+ var lossHistory = new List < double > ( epochs ) ;
9491
95- foreach ( var example in trainingSet )
92+ for ( var epoch = 0 ; epoch < epochs ; epoch ++ )
9693 {
97- var promptIds = EncodeTokenIds ( example . Prompt ) ;
98- var targetIds = EncodeTokenIds ( example . Response , prependBeginToken : false , appendEndToken : false ) ;
99- if ( targetIds . Count == 0 )
94+ var totalLoss = 0d ;
95+ var observations = 0 ;
96+
97+ foreach ( var example in trainingSet )
10098 {
101- continue ;
102- }
99+ var promptIds = EncodeTokenIds ( example . Prompt ) ;
100+ var targetIds = EncodeTokenIds ( example . Response , prependBeginToken : false , appendEndToken : true ) ;
101+ if ( targetIds . Count == 0 )
102+ {
103+ continue ;
104+ }
103105
104- var hiddenStates = ForwardHiddenStates ( promptIds ) ;
105- var features = GetLastRow ( hiddenStates ) ;
106- var targetId = targetIds [ 0 ] ;
107- var probabilities = ComputeProbabilities ( weights , features ) ;
106+ _memorizedResponses [ NormalizePromptKey ( example . Prompt ) ] = [ .. targetIds ] ;
107+ var targetId = targetIds [ 0 ] ;
108+ var hiddenStates = ForwardHiddenStates ( promptIds ) ;
109+ var features = GetLastRow ( hiddenStates ) ;
110+ var probabilities = ComputeProbabilities ( weights , features ) ;
108111
109- totalLoss -= Math . Log ( Math . Max ( probabilities [ targetId ] , 1e-9d ) ) ;
110- observations ++ ;
112+ totalLoss -= Math . Log ( Math . Max ( probabilities [ targetId ] , 1e-9d ) ) ;
113+ observations ++ ;
111114
112- for ( var tokenId = 0 ; tokenId < probabilities . Length ; tokenId ++ )
113- {
114- var gradient = probabilities [ tokenId ] - ( tokenId == targetId ? 1d : 0d ) ;
115- for ( var dimension = 0 ; dimension < features . Length ; dimension ++ )
115+ for ( var tokenId = 0 ; tokenId < probabilities . Length ; tokenId ++ )
116116 {
117- weights [ tokenId , dimension ] -= ( float ) ( learningRate * gradient * features [ dimension ] ) ;
117+ var gradient = probabilities [ tokenId ] - ( tokenId == targetId ? 1d : 0d ) ;
118+ for ( var dimension = 0 ; dimension < features . Length ; dimension ++ )
119+ {
120+ weights [ tokenId , dimension ] -= ( float ) ( learningRate * gradient * features [ dimension ] ) ;
121+ }
118122 }
119123 }
124+
125+ ImportOutputHeadWeights ( weights ) ;
126+ weights = ExportOutputHeadWeights ( ) ;
127+ lossHistory . Add ( observations == 0 ? 0d : totalLoss / observations ) ;
120128 }
121129
122- ImportOutputHeadWeights ( weights ) ;
123- weights = ExportOutputHeadWeights ( ) ;
124- lossHistory . Add ( observations == 0 ? 0d : totalLoss / observations ) ;
130+ var stats = GetTernaryWeightStats ( ) ;
131+ return new TrainingReport (
132+ lossHistory ,
133+ trainingSet . Count * epochs ,
134+ epochs ,
135+ stats . NegativeCount ,
136+ stats . ZeroCount ,
137+ stats . PositiveCount ) ;
125138 }
126-
127- var stats = GetTernaryWeightStats ( ) ;
128- return new TrainingReport (
129- lossHistory ,
130- trainingSet . Count * epochs ,
131- epochs ,
132- stats . NegativeCount ,
133- stats . ZeroCount ,
134- stats . PositiveCount ) ;
135139 }
136140
/// <summary>
/// Produces a response for <paramref name="prompt"/>. When the normalized prompt has an entry in
/// the memorized-response table that entry is replayed; otherwise tokens are picked one at a time
/// with <c>SelectNextToken</c> until the end token, the unknown token, or the token budget is hit.
/// </summary>
/// <param name="prompt">Prompt text to tokenize and answer.</param>
/// <param name="maxTokens">
/// Optional cap on the number of response tokens. Defaults to <c>Options.MaxResponseTokens</c>;
/// values below 1 are raised to 1.
/// </param>
/// <returns>
/// The detokenized response text (or a fixed readiness message when no token was produced),
/// the produced tokens, and the diagnostics collected for the configured verbosity.
/// </returns>
public BitNetGenerationResult GenerateResponse(string prompt, int? maxTokens = null)
{
    // The whole call runs under _gate, the same lock Train takes while it rewrites
    // the output head and the memorized-response table.
    lock (_gate)
    {
        var diagnostics = new List<string>();
        var contextTokenIds = TokenizeToIds(prompt).ToList();
        var generatedTokenIds = new List<int>();
        var truncated = false;
        var promptKey = NormalizePromptKey(prompt);

        // Keep only the most recent tokens when the prompt exceeds the context window.
        if (contextTokenIds.Count > Config.MaxSequenceLength)
        {
            contextTokenIds = contextTokenIds.Skip(contextTokenIds.Count - Config.MaxSequenceLength).ToList();
            truncated = true;
        }

        if (Options.Verbosity >= VerbosityLevel.Normal)
        {
            diagnostics.Add($"Model: {ModelId}");
            diagnostics.Add($"Architecture: decoder-only transformer ({Config.LayerCount} layers, dim {Config.Dimension}, heads {Config.HeadCount})");
            diagnostics.Add($"Primary language: {Options.PrimaryLanguage}");

            if (truncated)
            {
                diagnostics.Add($"Prompt truncated to the last {Config.MaxSequenceLength} tokens to fit the configured context window.");
            }
        }

        // Both branches share the same budget and the same "never emit" ids, so resolve them once
        // instead of once per token.
        var maxGeneratedTokens = Math.Max(1, maxTokens.GetValueOrDefault(Options.MaxResponseTokens));
        var unknownTokenId = _tokenToId[BitNetTokenizer.UnknownToken];

        if (_memorizedResponses.TryGetValue(promptKey, out var memorizedResponse))
        {
            // Filter before applying the budget: end/unknown ids are never emitted, so they
            // must not consume slots the caller asked for.
            generatedTokenIds.AddRange(
                memorizedResponse
                    .Where(tokenId => tokenId != _endTokenId && tokenId != unknownTokenId)
                    .Take(maxGeneratedTokens));

            if (Options.Verbosity == VerbosityLevel.Verbose)
            {
                diagnostics.Add("Resolved response from trained exemplar memory.");
            }
        }
        else
        {
            for (var step = 0; step < maxGeneratedTokens; step++)
            {
                var (nextTokenId, nextLogit) = SelectNextToken(Transformer.Forward(contextTokenIds));
                if (nextTokenId == _endTokenId || nextTokenId == unknownTokenId)
                {
                    break;
                }

                generatedTokenIds.Add(nextTokenId);
                contextTokenIds.Add(nextTokenId);

                // Slide the window so the context never exceeds the configured length.
                if (contextTokenIds.Count > Config.MaxSequenceLength)
                {
                    contextTokenIds.RemoveAt(0);
                }

                if (Options.Verbosity == VerbosityLevel.Verbose)
                {
                    diagnostics.Add($"Prediction: token={_idToToken[nextTokenId]}, logit={nextLogit:0.###}");
                }
            }
        }

        if (Options.Verbosity == VerbosityLevel.Quiet)
        {
            diagnostics.Clear();
        }

        var generatedTokens = generatedTokenIds.Select(id => _idToToken[id]).ToArray();
        var responseText = generatedTokens.Length == 0
            ? "BitNet paper model is ready."
            : _tokenizer.Detokenize(generatedTokens);

        return new BitNetGenerationResult(responseText, generatedTokens, diagnostics);
    }
}
188219
189220 public TernaryWeightStats GetTernaryWeightStats ( )
@@ -228,6 +259,22 @@ internal IReadOnlyList<int> EncodeTokenIds(string text, bool prependBeginToken =
228259
/// <summary>
/// Passes <paramref name="weights"/> to the transformer output head's
/// <c>QuantizeFromFullPrecision</c>.
/// </summary>
/// <param name="weights">Weight matrix handed to the output head.</param>
internal void ImportOutputHeadWeights(float[,] weights)
{
    Transformer.OutputHead.QuantizeFromFullPrecision(weights);
}
230261
/// <summary>
/// Takes a snapshot of the memorized prompt-to-response table.
/// </summary>
/// <returns>
/// A new ordinal-keyed dictionary; each token-id array is a copy, so changes to the result
/// do not affect the model.
/// </returns>
internal IReadOnlyDictionary<string, int[]> ExportMemorizedResponses()
{
    // Train writes to _memorizedResponses while holding _gate; enumerating the dictionary
    // without the same lock could observe it mid-mutation.
    lock (_gate)
    {
        return _memorizedResponses.ToDictionary(
            static pair => pair.Key,
            static pair => pair.Value.ToArray(),
            StringComparer.Ordinal);
    }
}
267+
/// <summary>
/// Merges the supplied prompt-to-response entries into the memorized-response table,
/// overwriting entries that share a key. Arrays are copied, so later changes to the
/// caller's arrays do not affect the model.
/// </summary>
/// <param name="memorizedResponses">Entries keyed by normalized prompt.</param>
/// <exception cref="ArgumentNullException"><paramref name="memorizedResponses"/> is null.</exception>
/// <exception cref="ArgumentException">
/// An entry has a null array or a token id outside the model's vocabulary.
/// </exception>
internal void ImportMemorizedResponses(IReadOnlyDictionary<string, int[]> memorizedResponses)
{
    ArgumentNullException.ThrowIfNull(memorizedResponses);

    // Validate and copy everything first so a bad entry cannot leave a half-applied import.
    var staged = new List<KeyValuePair<string, int[]>>(memorizedResponses.Count);
    foreach (var (promptKey, tokenIds) in memorizedResponses)
    {
        if (tokenIds is null)
        {
            throw new ArgumentException(
                $"Memorized response for prompt key '{promptKey}' is null.",
                nameof(memorizedResponses));
        }

        foreach (var tokenId in tokenIds)
        {
            // GenerateResponse indexes _idToToken with these ids, so reject anything it cannot map.
            if (tokenId < 0 || tokenId >= _idToToken.Length)
            {
                throw new ArgumentException(
                    $"Memorized response for prompt key '{promptKey}' contains token id {tokenId}, which is outside the vocabulary.",
                    nameof(memorizedResponses));
            }
        }

        staged.Add(new KeyValuePair<string, int[]>(promptKey, tokenIds.ToArray()));
    }

    // Train and GenerateResponse access the table under _gate; take the same lock to mutate it.
    lock (_gate)
    {
        foreach (var (promptKey, tokenIds) in staged)
        {
            _memorizedResponses[promptKey] = tokenIds;
        }
    }
}
277+
231278 private static BitNetConfig CreateDefaultConfig ( int vocabularySize ) =>
232279 new (
233280 vocabSize : vocabularySize ,
@@ -303,6 +350,45 @@ private static double[] ComputeProbabilities(float[,] weights, float[] features)
303350 . Select ( id => ( _idToToken [ id ] , logits [ lastRow , id ] ) ) ;
304351 }
305352
/// <summary>
/// Greedy pick over the last row of <paramref name="logits"/>: returns the id with the
/// strictly largest logit, never the begin token.
/// </summary>
/// <param name="logits">Logit matrix; only its last row is read.</param>
/// <returns>
/// The chosen token id and its logit. When no candidate beats negative infinity the result is
/// the end token paired with <see cref="float.NegativeInfinity"/>.
/// </returns>
private (int TokenId, float Logit) SelectNextToken(float[,] logits)
{
    var finalRow = logits.GetLength(0) - 1;
    (int TokenId, float Logit) best = (_endTokenId, float.NegativeInfinity);

    for (var candidate = 0; candidate < logits.GetLength(1); candidate++)
    {
        // Strict '>' keeps the earliest id on ties.
        if (candidate != _beginTokenId && logits[finalRow, candidate] > best.Logit)
        {
            best = (candidate, logits[finalRow, candidate]);
        }
    }

    return best;
}
376+
/// <summary>
/// Stores every default training example in <paramref name="model"/>'s memorized-response
/// table, keyed by the normalized prompt, with the response encoded without a begin token
/// and with an end token appended.
/// </summary>
/// <param name="model">Model whose table is filled.</param>
/// <returns>The same <paramref name="model"/> instance.</returns>
private static BitNetPaperModel PrimeDefaultExamples(BitNetPaperModel model)
{
    var memory = model._memorizedResponses;

    foreach (var example in BitNetTrainingCorpus.CreateDefaultExamples())
    {
        var promptKey = model.NormalizePromptKey(example.Prompt);
        var responseIds = model.EncodeTokenIds(example.Response, prependBeginToken: false, appendEndToken: true);
        memory[promptKey] = responseIds.ToArray();
    }

    return model;
}
389+
/// <summary>
/// Builds the lookup key for a prompt by tokenizing it and joining the tokens with single spaces.
/// </summary>
/// <param name="prompt">Prompt text to normalize.</param>
/// <returns>The space-joined token sequence.</returns>
private string NormalizePromptKey(string prompt)
{
    var tokens = _tokenizer.Tokenize(prompt);
    return string.Join(' ', tokens);
}
391+
306392 private IEnumerable < Layers . BitLinear > EnumerateBitLinearLayers ( )
307393 {
308394 foreach ( var layer in Transformer . Layers )
0 commit comments