33// See the LICENSE file in the project root for more information.
44
55using System ;
6+ using System . Buffers ;
67using System . Collections . Generic ;
78using System . IO ;
9+ using System . Linq ;
810using System . Runtime . CompilerServices ;
911using System . Text . Json ;
1012using System . Text . Json . Serialization ;
@@ -34,20 +36,21 @@ private set
3436 {
3537 _unknownToken = value ;
3638
37- if ( value is null )
39+ if ( VocabReverse . TryGetValue ( 0 , out string ? v ) )
3840 {
39- if ( VocabReverse . TryGetValue ( 0 , out string ? v ) )
41+ if ( v == value )
4042 {
41- VocabReverse . Remove ( 0 ) ;
42- if ( _vocab . TryGetValue ( v , out int id ) )
43- {
44- _vocab . Remove ( v ) ;
45- }
43+ return ;
4644 }
45+
46+ VocabReverse . Remove ( 0 ) ;
47+ _vocab . Remove ( new StringSpanOrdinalKey ( v ) ) ;
4748 }
48- else
49+
50+
51+ if ( value is not null )
4952 {
50- _vocab [ value ] = 0 ;
53+ _vocab [ new StringSpanOrdinalKey ( value ) ] = 0 ;
5154 VocabReverse [ 0 ] = value ;
5255 }
5356 }
@@ -68,7 +71,6 @@ private set
6871 /// </summary>
6972 public bool FuseUnknownTokens { get ; }
7073
71-
7274 /// <summary>
7375 /// Construct a new Bpe model object to use for text encoding.
7476 /// </summary>
@@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
111113 ContinuingSubwordPrefix = continuingSubwordPrefix ;
112114 EndOfWordSuffix = endOfWordSuffix ;
113115
114- ( Dictionary < string , int > ? vocab1 , Vec < ( string , string ) > merges ) = ReadModelData ( vocabStream , mergesStream ) ;
115- _vocab = vocab1 ?? new Dictionary < string , int > ( ) ;
116- Cache = new Cache < string , Word > ( ) ;
116+ ( Dictionary < StringSpanOrdinalKey , int > ? vocab1 , Vec < ( string , string ) > merges ) = ReadModelData ( vocabStream , mergesStream ) ;
117+ _vocab = vocab1 ?? new Dictionary < StringSpanOrdinalKey , int > ( ) ;
118+ Cache = new StringSpanOrdinalKeyCache < Word > ( ) ;
117119
118120 VocabReverse = new ( ) ;
119121
120- foreach ( KeyValuePair < string , int > kvp in Vocab )
122+ foreach ( KeyValuePair < StringSpanOrdinalKey , int > kvp in _vocab )
121123 {
122- VocabReverse . Add ( kvp . Value , kvp . Key ) ;
124+ VocabReverse . Add ( kvp . Value , kvp . Key . Data ! ) ;
123125 }
124126
125- if ( unknownToken is null && VocabReverse . TryGetValue ( 0 , out string ? unkToken ) )
126- {
127- unknownToken = unkToken ;
128- }
129127
130- UnknownToken = unknownToken ;
128+ UnknownToken = unknownToken ?? ( VocabReverse . TryGetValue ( 0 , out string ? unkToken ) ? unkToken : null ) ;
131129
132130 int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix . Length ;
133131
@@ -197,31 +195,23 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
197195 /// <param name="text">The text to split.</param>
198196 /// <param name="isSpecialToken">Indicate if the token is a special token.</param>
199197 /// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
200- public override void EncodeToIds ( string text , bool isSpecialToken , IList < int > accumulatedIds ) => EncodeToIdsWithCache ( text , accumulatedIds ) ;
198+ public override void EncodeToIds ( ReadOnlySpan < char > text , bool isSpecialToken , IList < int > accumulatedIds ) => EncodeToIdsWithCache ( text , accumulatedIds ) ;
201199
202200 /// <summary>
203201 /// Get the number of tokens that the input text will be encoded to.
204202 /// </summary>
205203 /// <param name="text">The text to encode.</param>
206204 /// <param name="isSpecialToken">Indicate if the token is special token.</param>
207205 /// <returns>The number of tokens that the input text will be encoded to.</returns>
208- public override int CountTokens ( string text , bool isSpecialToken ) => EncodeToIdsWithCache ( text , null ) ;
206+ public override int CountTokens ( ReadOnlySpan < char > text , bool isSpecialToken ) => EncodeToIdsWithCache ( text , null ) ;
209207
210208 /// <summary>
211209 /// Map the token to encoded Id.
212210 /// </summary>
213211 /// <param name="token">The token to map to the Id.</param>
214212 /// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
215213 /// <returns>The mapped Id of the token.</returns>
216- public override int ? MapTokenToId ( string token , bool considerSpecialTokens = true )
217- {
218- if ( _vocab . TryGetValue ( token , out int value ) )
219- {
220- return value ;
221- }
222-
223- return null ;
224- }
214+ public override int ? MapTokenToId ( ReadOnlySpan < char > token , bool considerSpecialTokens = true ) => _vocab . TryGetValue ( token , out int value ) ? value : null ;
225215
226216 /// <summary>
227217 /// Map the encoded Id to the token.
@@ -242,24 +232,27 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
242232 /// <summary>
243233 /// Gets the dictionary mapping tokens to Ids.
244234 /// </summary>
245- public IReadOnlyDictionary < string , int > Vocab => _vocab ;
235+ public IReadOnlyDictionary < string , int > Vocab => _vocabOriginal ??= _vocab . ToDictionary ( kvp => kvp . Key . Data ! , kvp => kvp . Value ) ;
246236
247237 /// Read the given files to extract the vocab and merges
248- internal static ( Dictionary < string , int > ? , Vec < ( string , string ) > ) ReadModelData ( Stream vocab , Stream ? merges )
238+ internal static ( Dictionary < StringSpanOrdinalKey , int > ? , Vec < ( string , string ) > ) ReadModelData ( Stream vocab , Stream ? merges )
249239 {
250- Dictionary < string , int > ? dic = JsonSerializer . Deserialize < Dictionary < string , int > > ( vocab ) as Dictionary < string , int > ;
240+ JsonSerializerOptions options = new ( ) { Converters = { StringSpanOrdinalKeyConverter . Instance } } ;
241+ Dictionary < StringSpanOrdinalKey , int > ? dic = JsonSerializer . Deserialize < Dictionary < StringSpanOrdinalKey , int > > ( vocab , options ) as Dictionary < StringSpanOrdinalKey , int > ;
251242
252243 return ( dic , ConvertMergesToHashmap ( merges ) ) ;
253244 }
254245
255246 /// The vocabulary assigns a number to each token.
256- private readonly Dictionary < string , int > _vocab ;
247+ private readonly Dictionary < StringSpanOrdinalKey , int > _vocab ;
248+
249+ private Dictionary < string , int > ? _vocabOriginal ;
257250
258251 /// Contains the mapping between Pairs and their (rank, newId).
259252 internal Dictionary < Pair < int > , ( int , int ) > Merges { get ; }
260253
261254 /// Contains the cache for optimizing the encoding step.
262- internal Cache < string , Word > ? Cache { get ; }
255+ internal StringSpanOrdinalKeyCache < Word > ? Cache { get ; }
263256
264257 internal static readonly int DefaultCacheCapacity = 10_000 ;
265258
@@ -309,9 +302,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(
309302 return merges ;
310303 }
311304
312- /// Reset the cache.
313- internal void ClearCache ( ) => Cache ? . Clear ( ) ;
314-
315305 private readonly Dictionary < char , string > _charToString = new Dictionary < char , string > ( ) ;
316306
317307 [ MethodImpl ( MethodImplOptions . AggressiveInlining ) ]
@@ -327,38 +317,68 @@ internal string CharToString(char c)
327317 return s ;
328318 }
329319
330- internal Word MergeWord ( string w )
320+ internal Word MergeWord ( ReadOnlySpan < char > w )
331321 {
332322 Word word = Word . WithCapacity ( w . Length ) ;
333323 ( int Id , int Len ) ? unk = null ;
334324 int i = 0 ;
335325
326+ Span < char > buffer = stackalloc char [ 256 ] ;
327+ scoped ReadOnlySpan < char > s ;
328+
336329 while ( i < w . Length )
337330 {
338331 int length ;
339- string s ;
340332
341333 if ( Char . IsHighSurrogate ( w [ i ] ) && i < w . Length - 1 && Char . IsLowSurrogate ( w [ i + 1 ] ) )
342334 {
343335 length = 2 ;
344- s = w . Substring ( i , length ) ;
336+ s = w . Slice ( i , 2 ) ;
345337 }
346338 else
347339 {
348340 length = 1 ;
349- s = CharToString ( w [ i ] ) ;
341+ s = w . Slice ( i , 1 ) ;
350342 }
351343
352344 // Add the `continuing_subword_prefix` if relevant
353345 if ( i > 0 && ContinuingSubwordPrefix is not null )
354346 {
355- s = $ "{ ContinuingSubwordPrefix } { s } ";
347+ if ( ContinuingSubwordPrefix . Length + s . Length <= buffer . Length )
348+ {
349+ ContinuingSubwordPrefix . AsSpan ( ) . CopyTo ( buffer ) ;
350+ s . CopyTo ( buffer . Slice ( ContinuingSubwordPrefix . Length ) ) ;
351+ s = buffer . Slice ( 0 , ContinuingSubwordPrefix . Length + s . Length ) ;
352+ }
353+ else
354+ {
355+ #if NETCOREAPP
356+ s = $ "{ ContinuingSubwordPrefix } { s } ". AsSpan ( ) ;
357+ #else
358+ string s1 = s . Length == 1 ? CharToString ( s [ 0 ] ) : s . ToString ( ) ;
359+ s = $ "{ ContinuingSubwordPrefix } { s1 } ". AsSpan ( ) ;
360+ #endif
361+ }
356362 }
357363
358364 // Add the `end_of_word_suffix` if relevant
359365 if ( i + length >= w . Length && EndOfWordSuffix is not null )
360366 {
361- s = $ "{ s } { EndOfWordSuffix } ";
367+ if ( s . Length + EndOfWordSuffix . Length <= buffer . Length )
368+ {
369+ s . CopyTo ( buffer ) ;
370+ EndOfWordSuffix . AsSpan ( ) . CopyTo ( buffer . Slice ( s . Length ) ) ;
371+ s = buffer . Slice ( 0 , s . Length + EndOfWordSuffix . Length ) ;
372+ }
373+ else
374+ {
375+ #if NETCOREAPP
376+ s = $ "{ s } { EndOfWordSuffix } ". AsSpan ( ) ;
377+ #else
378+ string s1 = s . Length == 1 ? CharToString ( s [ 0 ] ) : s . ToString ( ) ;
379+ s = $ "{ s1 } { EndOfWordSuffix } ". AsSpan ( ) ;
380+ #endif
381+ }
362382 }
363383
364384 if ( _vocab . TryGetValue ( s , out int id ) )
@@ -419,17 +439,17 @@ internal List<Token> EncodeWithCache(string text)
419439 Word word ;
420440 if ( Cache is not null )
421441 {
422- if ( Cache . TryGet ( text , out word ) )
442+ if ( Cache . TryGetValue ( text , out word ) )
423443 {
424444 return WordToTokens ( ref word ) ;
425445 }
426446
427- word = MergeWord ( text ) ;
447+ word = MergeWord ( text . AsSpan ( ) ) ;
428448 Cache . Set ( text , word ) ;
429449 }
430450 else
431451 {
432- word = MergeWord ( text ) ;
452+ word = MergeWord ( text . AsSpan ( ) ) ;
433453 }
434454
435455 return WordToTokens ( ref word ) ;
@@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
445465 return word . SymbolsCount ;
446466 }
447467
448- internal int EncodeToIdsWithCache ( string text , IList < int > ? accumulatedIds )
468+ internal int EncodeToIdsWithCache ( ReadOnlySpan < char > text , IList < int > ? accumulatedIds )
449469 {
450470 Word word ;
451471
452472 if ( Cache is not null )
453473 {
454- if ( Cache . TryGet ( text , out Word hit ) )
474+ if ( Cache . TryGetValue ( text , out Word hit ) )
455475 {
456476 return WordToIds ( ref hit , accumulatedIds ) ;
457477 }
458478
459479 word = MergeWord ( text ) ;
460- Cache . Set ( text , word ) ;
480+ Cache . Set ( text . ToString ( ) , word ) ;
461481 }
462482 else
463483 {
0 commit comments