@@ -112,7 +112,7 @@ private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTo
112112 /// <param name="cacheSize">The size of the cache to use.</param>
113113 /// <param name="normalizer">To normalize the text before tokenization</param>
114114 /// <returns>The tokenizer</returns>
115- public static Tokenizer CreateByModelName (
115+ public static Tokenizer CreateTokenizerForModel (
116116 string modelName ,
117117 Stream vocabStream ,
118118 IReadOnlyDictionary < string , int > ? extraSpecialTokens = null ,
@@ -124,7 +124,7 @@ public static Tokenizer CreateByModelName(
124124 throw new ArgumentNullException ( nameof ( modelName ) ) ;
125125 }
126126
127- ( Dictionary < string , int > SpecialTokens , Regex Regex ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
127+ ( Dictionary < string , int > SpecialTokens , Regex Regex , string _ ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
128128
129129 if ( extraSpecialTokens is not null )
130130 {
@@ -150,7 +150,7 @@ public static Tokenizer CreateByModelName(
150150 /// <param name="normalizer">To normalize the text before tokenization</param>
151151 /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
152152 /// <returns>The tokenizer</returns>
153- public static async Task < Tokenizer > CreateByModelNameAsync (
153+ public static async Task < Tokenizer > CreateTokenizerForModelAsync (
154154 string modelName ,
155155 Stream vocabStream ,
156156 IReadOnlyDictionary < string , int > ? extraSpecialTokens = null ,
@@ -163,7 +163,7 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
163163 throw new ArgumentNullException ( nameof ( modelName ) ) ;
164164 }
165165
166- ( Dictionary < string , int > SpecialTokens , Regex Regex ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
166+ ( Dictionary < string , int > SpecialTokens , Regex Regex , string _ ) tiktokenConfiguration = GetTiktokenConfigurations ( modelName ) ;
167167
168168 if ( extraSpecialTokens is not null )
169169 {
@@ -738,31 +738,30 @@ private static ModelEncoding GetModelEncoding(string modelName)
738738 return encoder ;
739739 }
740740
/// <summary>
/// Gets the tiktoken configuration associated with a model name: its built-in special
/// tokens, its pre-tokenization regex, and the URL of its BPE rank (vocab) file.
/// </summary>
/// <param name="modelName">The model name to resolve to an encoder configuration.</param>
/// <returns>A tuple of special tokens, pre-tokenization regex, and vocab file URL.</returns>
/// <exception cref="NotSupportedException">Thrown when the model maps to an unsupported encoder.</exception>
internal static (Dictionary<string, int> SpecialTokens, Regex Regex, string Url) GetTiktokenConfigurations(string modelName)
{
    ModelEncoding modelEncoding = GetModelEncoding(modelName);

    switch (modelEncoding)
    {
        case ModelEncoding.Cl100kBase:
            return (new Dictionary<string, int>
                { { EndOfText, 100257 }, { FimPrefix, 100258 }, { FimMiddle, 100259 }, { FimSuffix, 100260 }, { EndOfPrompt, 100276 } }, Cl100kBaseRegex(), Cl100kBaseVocabUrl);

        case ModelEncoding.P50kBase:
            return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex(), P50RanksUrl);

        case ModelEncoding.P50kEdit:
            return (new Dictionary<string, int>
                { { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } }, P50kBaseRegex(), P50RanksUrl);

        case ModelEncoding.R50kBase:
            // R50k shares the P50k pre-tokenization regex but uses its own rank file.
            return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex(), R50RanksUrl);

        case ModelEncoding.GPT2:
            return (new Dictionary<string, int> { { EndOfText, 50256 } }, P50kBaseRegex(), GPT2Url);

        default:
            throw new NotSupportedException($"The model '{modelName}' is not supported.");
    }
}
@@ -775,22 +774,64 @@ internal static (Dictionary<string, int> SpecialTokens, Regex Regex) GetTiktoken
/// <summary>
/// Create tokenizer based on model name
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
public static Task<Tokenizer> CreateTokenizerForModelAsync(
    string modelName,
    IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
    Normalizer? normalizer = null,
    CancellationToken cancellationToken = default)
{
    try
    {
        // Resolve the encoder eagerly; any synchronous failure (e.g. an unsupported
        // model name) is surfaced through the returned task instead of being thrown
        // directly to the caller.
        return CreateByEncoderNameAsync(GetModelEncoding(modelName), extraSpecialTokens, normalizer, cancellationToken);
    }
    catch (Exception ex)
    {
        return Task.FromException<Tokenizer>(ex);
    }
}
793792
/// <summary>
/// Create tokenizer based on model name
/// </summary>
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <returns>The tokenizer</returns>
public static Tokenizer CreateTokenizerForModel(
    string modelName,
    IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
    Normalizer? normalizer = null)
{
    if (string.IsNullOrEmpty(modelName))
    {
        throw new ArgumentNullException(nameof(modelName));
    }

    (Dictionary<string, int> SpecialTokens, Regex Regex, string Url) tiktokenConfiguration = GetTiktokenConfigurations(modelName);

    // Merge caller-supplied special tokens into the model's built-in set.
    // Note: Add throws on a key that collides with a built-in special token,
    // matching the behavior of the other factory overloads.
    if (extraSpecialTokens is not null)
    {
        foreach (KeyValuePair<string, int> extraSpecialToken in extraSpecialTokens)
        {
            tiktokenConfiguration.SpecialTokens.Add(extraSpecialToken.Key, extraSpecialToken.Value);
        }
    }

    // The BPE rank file is fetched and parsed at most once per vocab URL; models
    // sharing a vocab reuse the cached encoder/vocab/decoder tables.
    if (!_tiktokenCache.TryGetValue(tiktokenConfiguration.Url,
            out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<StringSpanOrdinalKey, int> vocab, Dictionary<int, ReadOnlyMemory<byte>> decoder) cache))
    {
        using Stream stream = Helpers.GetStream(_httpClient, tiktokenConfiguration.Url);

        // With useAsync: false the load completes synchronously, so blocking on the
        // returned task here cannot deadlock.
        cache = LoadTikTokenBpeAsync(stream, useAsync: false).GetAwaiter().GetResult();

        _tiktokenCache.TryAdd(tiktokenConfiguration.Url, cache);
    }

    return new Tokenizer(
        new Tiktoken(cache.encoder, cache.decoder, cache.vocab, tiktokenConfiguration.SpecialTokens, LruCache<int[]>.DefaultCacheSize),
        new TikTokenPreTokenizer(tiktokenConfiguration.Regex, tiktokenConfiguration.SpecialTokens),
        normalizer);
}
834+
794835 // Regex patterns based on https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py
795836
796837 private const string Cl100kBaseRegexPattern = /*lang=regex*/ @"'(?i:[sdmt]|re|ve|ll)|(?>[^\r\n\p{L}\p{N}]?)\p{L}+|\p{N}{1,3}| ?(?>[^\s\p{L}\p{N}]+)[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+" ;
@@ -818,15 +859,13 @@ public static Task<Tokenizer> CreateByModelNameAsync(
818859 /// <summary>
819860 /// Create tokenizer based on encoder name and extra special tokens
820861 /// </summary>
821- /// <param name="modelName">Model name</param>
822862 /// <param name="modelEncoding">Encoder label</param>
823863 /// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
824864 /// <param name="normalizer">To normalize the text before tokenization</param>
825865 /// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
826866 /// <returns>The tokenizer</returns>
827867 /// <exception cref="NotSupportedException">Throws if the model name is not supported</exception>
828868 private static Task < Tokenizer > CreateByEncoderNameAsync (
829- string modelName ,
830869 ModelEncoding modelEncoding ,
831870 IReadOnlyDictionary < string , int > ? extraSpecialTokens ,
832871 Normalizer ? normalizer ,
@@ -857,8 +896,7 @@ private static Task<Tokenizer> CreateByEncoderNameAsync(
857896 return CreateTikTokenTokenizerAsync ( P50kBaseRegex ( ) , GPT2Url , specialTokens , extraSpecialTokens , normalizer , cancellationToken ) ;
858897
859898 default :
860- Debug . Assert ( false , $ "Unexpected encoder [{ modelEncoding } ]") ;
861- throw new NotSupportedException ( $ "The model '{ modelName } ' is not supported.") ;
899+ throw new NotSupportedException ( $ "The encoder '{ modelEncoding } ' is not supported.") ;
862900 }
863901 }
864902
@@ -894,7 +932,7 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
894932 {
895933 using ( Stream stream = await Helpers . GetStreamAsync ( _httpClient , mergeableRanksFileUrl , cancellationToken ) . ConfigureAwait ( false ) )
896934 {
897- cache = await Tiktoken . LoadTikTokenBpeAsync ( stream , useAsync : true , cancellationToken ) . ConfigureAwait ( false ) ;
935+ cache = await LoadTikTokenBpeAsync ( stream , useAsync : true , cancellationToken ) . ConfigureAwait ( false ) ;
898936 }
899937
900938 _tiktokenCache . TryAdd ( mergeableRanksFileUrl , cache ) ;
0 commit comments