1+ import os
2+ import sentencepiece as spm
3+ import tiktoken
4+ from tiktoken .load import load_tiktoken_bpe
5+ from pathlib import Path
6+ from typing import Dict
7+
8+ class TokenizerInterface :
9+ def __init__ (self , model_path ):
10+ self .model_path = model_path
11+
12+ def encode (self , text ):
13+ raise NotImplementedError ("This method should be overridden by subclasses." )
14+
15+ def decode (self , tokens ):
16+ raise NotImplementedError ("This method should be overridden by subclasses." )
17+
18+ def bos_id (self ):
19+ raise NotImplementedError ("This method should be overridden by subclasses." )
20+
21+ def eos_id (self ):
22+ raise NotImplementedError ("This method should be overridden by subclasses." )
23+
24+ class SentencePieceWrapper (TokenizerInterface ):
25+ def __init__ (self , model_path ):
26+ super ().__init__ (model_path )
27+ self .processor = spm .SentencePieceProcessor (str (model_path ))
28+
29+ def encode (self , text ):
30+ return self .processor .EncodeAsIds (text )
31+
32+ def decode (self , tokens ):
33+ return self .processor .DecodeIds (tokens )
34+
35+ def bos_id (self ):
36+ return self .processor .bos_id ()
37+
38+ def eos_id (self ):
39+ return self .processor .eos_id ()
40+
41+ class TiktokenWrapper (TokenizerInterface ):
42+ """
43+ Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
44+ """
45+
46+ special_tokens : Dict [str , int ]
47+
48+ num_reserved_special_tokens = 256
49+
50+ pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501
51+
52+ def __init__ (self , model_path ):
53+ super ().__init__ (model_path )
54+ assert os .path .isfile (model_path ), str (model_path )
55+ mergeable_ranks = load_tiktoken_bpe (str (model_path ))
56+ num_base_tokens = len (mergeable_ranks )
57+ special_tokens = [
58+ "<|begin_of_text|>" ,
59+ "<|end_of_text|>" ,
60+ "<|reserved_special_token_0|>" ,
61+ "<|reserved_special_token_1|>" ,
62+ "<|reserved_special_token_2|>" ,
63+ "<|reserved_special_token_3|>" ,
64+ "<|start_header_id|>" ,
65+ "<|end_header_id|>" ,
66+ "<|reserved_special_token_4|>" ,
67+ "<|eot_id|>" , # end of turn
68+ ] + [
69+ f"<|reserved_special_token_{ i } |>"
70+ for i in range (5 , self .num_reserved_special_tokens - 5 )
71+ ]
72+ self .special_tokens = {
73+ token : num_base_tokens + i for i , token in enumerate (special_tokens )
74+ }
75+ self .model = tiktoken .Encoding (
76+ name = Path (model_path ).name ,
77+ pat_str = self .pat_str ,
78+ mergeable_ranks = mergeable_ranks ,
79+ special_tokens = self .special_tokens ,
80+ )
81+ # BOS / EOS token IDs
82+ self ._bos_id : int = self .special_tokens ["<|begin_of_text|>" ]
83+ self ._eos_id : int = self .special_tokens ["<|end_of_text|>" ]
84+
85+ def encode (self , text ):
86+ return self .model .encode (text )
87+
88+ def decode (self , tokens ):
89+ return self .model .decode (tokens )
90+
91+ def bos_id (self ):
92+ return self ._bos_id
93+
94+ def eos_id (self ):
95+ return self ._eos_id
96+
97+ def get_tokenizer (tokenizer_model_path , model_name ):
98+ """
99+ Factory function to get the appropriate tokenizer based on the model name.
100+
101+ Args:
102+ - tokenizer_model_path (str): The file path to the tokenizer model.
103+ - model_name (str): The name of the model, used to determine the tokenizer type.
104+
105+ Returns:
106+ - TokenizerInterface: An instance of a tokenizer.
107+ """
108+ if "Llama-3" in str (model_name ):
109+ return TiktokenWrapper (tokenizer_model_path )
110+ else :
111+ return SentencePieceWrapper (tokenizer_model_path )
0 commit comments