@@ -204,18 +204,16 @@ def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.version: int = int(_mistral_version_str.split("v")[-1])
 
         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
-        from mistral_common.tokens.tokenizers.tekken import (
-            SpecialTokenPolicy, Tekkenizer)
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
-        if self.is_tekken:
-            # Make sure special tokens will not raise
-            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
-        elif self.is_spm:
-            pass
-        else:
+        self._special_token_policy = (SpecialTokenPolicy.IGNORE
+                                      if self.is_tekken else None)
+        if not (self.is_tekken or self.is_spm):
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
         self._vocab = tokenizer_.vocab()
@@ -430,7 +428,8 @@ def _token_to_id(t: str):
                         return self.tokenizer.unk_id
 
                 ids = [_token_to_id(t) for t in tokens]
-                decoded = self.tokenizer.decode(ids)
+                decoded = self.tokenizer.decode(ids,
+                                                self._special_token_policy)
             else:
                 decoded = "".join(tokens)
         else:
@@ -444,15 +443,17 @@ def _token_to_id(t: str):
                 if token in special_tokens:
                     if regular_tokens:
                         decoded_list.append(
-                            self.tokenizer.decode(regular_tokens))
+                            self.tokenizer.decode(regular_tokens,
+                                                  self._special_token_policy))
                         regular_tokens = []
                     decoded_list.append(token)
                 else:
                     regular_tokens.append(token)
 
             if regular_tokens:
                 decoded_list.append(
-                    self.tokenizer.decode(regular_tokens))  # type: ignore
+                    self.tokenizer.decode(regular_tokens,
+                                          self._special_token_policy))
 
             decoded = ''.join(decoded_list)
 
@@ -470,7 +471,7 @@ def decode(self,
 
         if isinstance(ids, int):
             ids = [ids]
-        return self.tokenizer.decode(ids)
+        return self.tokenizer.decode(ids, self._special_token_policy)
 
     def convert_ids_to_tokens(
         self,
@@ -511,6 +512,9 @@ def convert_ids_to_tokens(
         # See: https://github.com/vllm-project/vllm/pull/8640
         # https://github.com/vllm-project/vllm/pull/9625
         # if underlying tokenizeir is sentencepiece, we just add "�"
-        tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
+        tokens = [
+            self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
+            for id in ids
+        ]
 
         return tokens
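For context outside the diff, here is a minimal standalone sketch of what passing the new `_special_token_policy` to `decode` does for a Tekken tokenizer. The tokenizer file path is a placeholder, and the sketch assumes a mistral_common release whose `Tekkenizer.decode` accepts a `SpecialTokenPolicy` as its second argument, as the calls added above imply.

```python
# Sketch only, not part of this commit. Assumes mistral_common is installed and
# that Tekkenizer.decode accepts a SpecialTokenPolicy as its second argument
# (as the decode(ids, self._special_token_policy) calls above imply).
# "tekken.json" is a placeholder path to a Tekken tokenizer file.
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import Tekkenizer

tok = Tekkenizer.from_file("tekken.json")  # placeholder path

# Encode with a BOS marker so the id list contains a special token.
ids = tok.encode("Hello world", bos=True, eos=False)

# SpecialTokenPolicy.IGNORE drops special tokens silently instead of raising,
# which is the behaviour the wrapper selects when is_tekken is True.
print(tok.decode(ids, SpecialTokenPolicy.IGNORE))
```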