@@ -21,6 +21,15 @@ def bos_id(self):
def eos_id(self):
    """Return the end-of-sequence token id.

    Interface stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
2323
def id_to_piece(self, token_id):
    """Map a single token id to its surface string (piece).

    Interface stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
26+
def piece_to_id(self, token_str):
    """Map a single piece string back to its token id.

    Interface stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
29+
def is_special_token(self, token_id):
    """Return True if *token_id* denotes a special (non-text) token.

    Interface stub: concrete tokenizer wrappers must override this.
    """
    raise NotImplementedError("This method should be overridden by subclasses.")
32+
2433class SentencePieceWrapper (TokenizerInterface ):
2534 def __init__ (self , model_path ):
2635 super ().__init__ (model_path )
@@ -38,6 +47,17 @@ def bos_id(self):
def eos_id(self):
    """Return the end-of-sequence id reported by the SentencePiece processor."""
    proc = self.processor
    return proc.eos_id()
4049
def id_to_piece(self, token_id):
    """Decode one token id to its text piece.

    SentencePiece marks word boundaries with "▁"; convert that marker back
    to an ordinary space so the result reads as plain text.
    """
    piece = self.processor.id_to_piece(token_id)
    return piece.replace("▁", " ")
52+
def piece_to_id(self, token_str):
    """Encode one piece string to its token id.

    Plain spaces are translated to the SentencePiece word-boundary marker
    "▁" before the vocabulary lookup (inverse of id_to_piece).
    """
    normalized = token_str.replace(" ", "▁")
    return self.processor.piece_to_id(normalized)
55+
def is_special_token(self, token_id):
    """Return whether *token_id* is a special (non-text) SentencePiece token.

    A token counts as special when the processor classifies it as a control
    piece, an unknown piece, or an unused piece.
    """
    # PEP 8: parenthesized expression instead of backslash continuations;
    # hoist the repeated self.processor attribute lookup.
    proc = self.processor
    return (
        proc.IsControl(token_id)
        or proc.IsUnknown(token_id)
        or proc.IsUnused(token_id)
    )
60+
4161class TiktokenWrapper (TokenizerInterface ):
4262 """
4363 Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -53,7 +73,7 @@ def __init__(self, model_path):
5373 super ().__init__ (model_path )
5474 assert os .path .isfile (model_path ), str (model_path )
5575 mergeable_ranks = load_tiktoken_bpe (str (model_path ))
56- num_base_tokens = len (mergeable_ranks )
76+ self . num_base_tokens = len (mergeable_ranks )
5777 special_tokens = [
5878 "<|begin_of_text|>" ,
5979 "<|end_of_text|>" ,
@@ -70,7 +90,7 @@ def __init__(self, model_path):
7090 for i in range (5 , self .num_reserved_special_tokens - 5 )
7191 ]
7292 self .special_tokens = {
73- token : num_base_tokens + i for i , token in enumerate (special_tokens )
93+ token : self . num_base_tokens + i for i , token in enumerate (special_tokens )
7494 }
7595 self .model = tiktoken .Encoding (
7696 name = Path (model_path ).name ,
@@ -94,6 +114,15 @@ def bos_id(self):
def eos_id(self):
    """Return the cached end-of-sequence token id."""
    return self._eos_id
96116
def id_to_piece(self, token_id):
    """Decode a single token id into its string form via the tiktoken model."""
    singleton = [token_id]
    return self.model.decode(singleton)
119+
def piece_to_id(self, token_str):
    """Look up the token id for a single piece string.

    NOTE(review): tiktoken's encode_single_token appears to require an exact
    vocabulary match — confirm behavior for unknown pieces.
    """
    return self.model.encode_single_token(token_str)
122+
def is_special_token(self, token_id):
    """Return whether *token_id* falls in the special-token id range.

    Special tokens are assigned ids directly after the base vocabulary, so
    the valid range is [num_base_tokens, num_base_tokens + len(special_tokens)).
    """
    # Idiomatic chained comparison instead of `x >= a and x < b`; the
    # repeated self.num_base_tokens lookup is hoisted into a local.
    base = self.num_base_tokens
    return base <= token_id < base + len(self.special_tokens)
125+
97126def get_tokenizer (tokenizer_model_path , model_name ):
98127 """
99128 Factory function to get the appropriate tokenizer based on the model name.
0 commit comments