Skip to content

Commit 99d5366

Browse files
authored
Merge pull request #53 from beehive-lab/demo-integration
Add support for encoding ordinary text in Qwen3Tokenizer and update Qwen3ChatFormat…
2 parents 788c11e + 02b8541 commit 99d5366

File tree

2 files changed: 41 additions and 14 deletions (+41 / -14 lines changed)

src/main/java/org/beehive/gpullama3/model/format/Qwen3ChatFormat.java

Lines changed: 25 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2,10 +2,7 @@
22

33
import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
44

5-
import java.util.ArrayList;
6-
import java.util.List;
7-
import java.util.Map;
8-
import java.util.Set;
5+
import java.util.*;
96

107
/**
118
* Utility tailored for the Chat Markup Language (ChatML) prompt format.
@@ -42,9 +39,9 @@ public Qwen3ChatFormat(Qwen3Tokenizer tokenizer, ChatTokens chatTokens) {
4239
this.imStart = startHeader;
4340
this.imEnd = endHeader;
4441

45-
fimPrefix = specialTokens.getOrDefault("<|fim_prefix|>", -1);
46-
fimSuffix = specialTokens.getOrDefault("<|fim_suffix|>", -1);
47-
fimMiddle = specialTokens.getOrDefault("<|fim_middle|>", -1);
42+
this.fimPrefix = specialTokens.getOrDefault("<|fim_prefix|>", -1);
43+
this.fimSuffix = specialTokens.getOrDefault("<|fim_suffix|>", -1);
44+
this.fimMiddle = specialTokens.getOrDefault("<|fim_middle|>", -1);
4845
}
4946

5047
public ChatTokens chatTokens() {
@@ -66,7 +63,7 @@ public List<Integer> encodeHeader(Message message) {
6663
default -> null;
6764
};
6865
if (sToken != null) {
69-
Integer token = tokenizer.getSpecialTokens().get("<|User|>");
66+
Integer token = tokenizer.getSpecialTokens().get(sToken);
7067
if (token == null) {
7168
throw new IllegalStateException(String.format("Unknown token '%s'", sToken));
7269
}
@@ -80,19 +77,23 @@ public List<Integer> encodeHeader(Message message) {
8077
} else if (Role.FIM_MIDDLE.equals(message.role())) {
8178
tokens.add(fimMiddle);
8279
} else {
80+
// Add the special token directly, don't try to encode it
8381
tokens.add(imStart);
84-
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
85-
tokens.addAll(this.tokenizer.encodeAsList("\n"));
82+
// Encode the role name as ordinary text (no special tokens in role names)
83+
tokens.addAll(this.tokenizer.encodeOrdinaryAsList(message.role().name()));
84+
tokens.addAll(this.tokenizer.encodeOrdinaryAsList("\n"));
8685
}
8786
return tokens;
8887
}
8988

9089
@Override
9190
public List<Integer> encodeMessage(Message message) {
9291
List<Integer> tokens = this.encodeHeader(message);
93-
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
92+
// Encode message content as ordinary text
93+
tokens.addAll(this.tokenizer.encodeOrdinaryAsList(message.content().strip()));
9494
boolean isFim = Role.FIM_PREFIX.equals(message.role()) || Role.FIM_SUFFIX.equals(message.role()) || Role.FIM_MIDDLE.equals(message.role());
9595
if (imEnd != -1 && !isFim) {
96+
// Add the end token directly
9697
tokens.add(imEnd);
9798
}
9899
return tokens;
@@ -108,9 +109,19 @@ public Set<Integer> getStopTokens() {
108109
if (imEnd == -1 && endOfText == -1) {
109110
throw new IllegalStateException("No stop token is defined.");
110111
}
111-
if (imEnd == -1) {
112-
return Set.of(endOfText);
112+
113+
// Only add valid token IDs (not -1)
114+
Set<Integer> stopTokens = new HashSet<>();
115+
if (imEnd != -1) {
116+
stopTokens.add(imEnd);
117+
}
118+
if (endOfText != -1) {
119+
stopTokens.add(endOfText);
113120
}
114-
return Set.of(imEnd, endOfText, endOfTextFim);
121+
if (endOfTextFim != -1) {
122+
stopTokens.add(endOfTextFim);
123+
}
124+
125+
return stopTokens;
115126
}
116127
}

src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -271,11 +271,27 @@ public List<Integer> encode(String text, Set<String> allowedSpecial) {
271271
}
272272
// @formatter:on
273273

274+
/**
275+
* Encode text as ordinary tokens (no special token handling)
276+
*/
277+
public List<Integer> encodeOrdinaryAsList(String text) {
278+
// First convert to byte-encoded unicode representation
279+
StringBuilder sb = new StringBuilder();
280+
byte[] bytes = text.getBytes(StandardCharsets.UTF_8);
281+
for (byte b : bytes) {
282+
sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b)));
283+
}
284+
// Then encode using BPE
285+
return encodeOrdinary(sb.toString());
286+
}
287+
274288
@Override
275289
public List<Integer> encodeAsList(String text) {
276290
return Arrays.stream(encode(text)).boxed().toList();
277291
}
278292

293+
294+
279295
public String decodeImpl(List<Integer> tokens) {
280296
StringBuilder sb = new StringBuilder();
281297
for (int token : tokens) {

0 commit comments

Comments
 (0)