22
33import org .beehive .gpullama3 .tokenizer .impl .Qwen3Tokenizer ;
44
5- import java .util .ArrayList ;
6- import java .util .List ;
7- import java .util .Map ;
8- import java .util .Set ;
5+ import java .util .*;
96
107/**
118 * Utility tailored for the Chat Markup Language (ChatML) prompt format.
@@ -42,9 +39,9 @@ public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) {
4239 this .imStart = startHeader ;
4340 this .imEnd = endHeader ;
4441
45- fimPrefix = specialTokens .getOrDefault ("<|fim_prefix|>" , -1 );
46- fimSuffix = specialTokens .getOrDefault ("<|fim_suffix|>" , -1 );
47- fimMiddle = specialTokens .getOrDefault ("<|fim_middle|>" , -1 );
42+ this . fimPrefix = specialTokens .getOrDefault ("<|fim_prefix|>" , -1 );
43+ this . fimSuffix = specialTokens .getOrDefault ("<|fim_suffix|>" , -1 );
44+ this . fimMiddle = specialTokens .getOrDefault ("<|fim_middle|>" , -1 );
4845 }
4946
5047 public ChatTokens chatTokens () {
@@ -66,7 +63,7 @@ public List<Integer> encodeHeader(Message message) {
6663 default -> null ;
6764 };
6865 if (sToken != null ) {
69- Integer token = tokenizer .getSpecialTokens ().get ("<|User|>" );
66+ Integer token = tokenizer .getSpecialTokens ().get (sToken );
7067 if (token == null ) {
7168 throw new IllegalStateException (String .format ("Unknown token '%s'" , sToken ));
7269 }
@@ -80,19 +77,23 @@ public List<Integer> encodeHeader(Message message) {
8077 } else if (Role .FIM_MIDDLE .equals (message .role ())) {
8178 tokens .add (fimMiddle );
8279 } else {
80+ // Add the special token directly, don't try to encode it
8381 tokens .add (imStart );
84- tokens .addAll (this .tokenizer .encodeAsList (message .role ().name ()));
85- tokens .addAll (this .tokenizer .encodeAsList ("\n " ));
82+ // Encode the role name as ordinary text (no special tokens in role names)
83+ tokens .addAll (this .tokenizer .encodeOrdinaryAsList (message .role ().name ()));
84+ tokens .addAll (this .tokenizer .encodeOrdinaryAsList ("\n " ));
8685 }
8786 return tokens ;
8887 }
8988
9089 @ Override
9190 public List <Integer > encodeMessage (Message message ) {
9291 List <Integer > tokens = this .encodeHeader (message );
93- tokens .addAll (this .tokenizer .encodeAsList (message .content ().strip ()));
92+ // Encode message content as ordinary text
93+ tokens .addAll (this .tokenizer .encodeOrdinaryAsList (message .content ().strip ()));
9494 boolean isFim = Role .FIM_PREFIX .equals (message .role ()) || Role .FIM_SUFFIX .equals (message .role ()) || Role .FIM_MIDDLE .equals (message .role ());
9595 if (imEnd != -1 && !isFim ) {
96+ // Add the end token directly
9697 tokens .add (imEnd );
9798 }
9899 return tokens ;
@@ -108,9 +109,19 @@ public Set<Integer> getStopTokens() {
108109 if (imEnd == -1 && endOfText == -1 ) {
109110 throw new IllegalStateException ("No stop token is defined." );
110111 }
111- if (imEnd == -1 ) {
112- return Set .of (endOfText );
112+
113+ // Only add valid token IDs (not -1)
114+ Set <Integer > stopTokens = new HashSet <>();
115+ if (imEnd != -1 ) {
116+ stopTokens .add (imEnd );
117+ }
118+ if (endOfText != -1 ) {
119+ stopTokens .add (endOfText );
113120 }
114- return Set .of (imEnd , endOfText , endOfTextFim );
121+ if (endOfTextFim != -1 ) {
122+ stopTokens .add (endOfTextFim );
123+ }
124+
125+ return stopTokens ;
115126 }
116127}
0 commit comments