Skip to content

Commit acb4cf3

Browse files
committed
feature claude count token
1 parent e3e4cf5 commit acb4cf3

File tree

6 files changed

+413
-33
lines changed

6 files changed

+413
-33
lines changed

src/lib/tokenizer.ts

Lines changed: 339 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,351 @@
1-
import { countTokens } from "gpt-tokenizer/model/gpt-4o"
1+
import type {
2+
ChatCompletionsPayload,
3+
ContentPart,
4+
Message,
5+
Tool,
6+
ToolCall,
7+
} from "~/services/copilot/create-chat-completions"
8+
import type { Model } from "~/services/copilot/get-models"
29

3-
import type { Message } from "~/services/copilot/create-chat-completions"
10+
// Maps each supported tokenizer encoding name to a lazy dynamic import of the
// matching gpt-tokenizer module, so only the encodings actually requested are
// ever loaded into memory.
const ENCODING_MAP = {
  o200k_base: () => import("gpt-tokenizer/encoding/o200k_base"),
  cl100k_base: () => import("gpt-tokenizer/encoding/cl100k_base"),
  p50k_base: () => import("gpt-tokenizer/encoding/p50k_base"),
  p50k_edit: () => import("gpt-tokenizer/encoding/p50k_edit"),
  r50k_base: () => import("gpt-tokenizer/encoding/r50k_base"),
} as const
418

5-
export const getTokenCount = (messages: Array<Message>) => {
6-
const simplifiedMessages = messages.map((message) => {
7-
let content = ""
8-
if (typeof message.content === "string") {
9-
content = message.content
10-
} else if (Array.isArray(message.content)) {
11-
content = message.content
12-
.filter((part) => part.type === "text")
13-
.map((part) => (part as { text: string }).text)
14-
.join("")
19+
// Encoding names we can load directly (the keys of ENCODING_MAP).
type SupportedEncoding = keyof typeof ENCODING_MAP

// Minimal surface needed from a tokenizer module: encode text into token ids.
interface Encoder {
  encode: (text: string) => Array<number>
}

// Cache loaded encoders to avoid repeated dynamic imports. Keyed by the
// requested encoding name — unsupported names are cached too, mapped to the
// fallback encoder (see getEncodeChatFunction).
const encodingCache = new Map<string, Encoder>()
28+
29+
/**
30+
* Calculate tokens for tool calls
31+
*/
32+
const calculateToolCallsTokens = (
33+
toolCalls: Array<ToolCall>,
34+
encoder: Encoder,
35+
constants: ReturnType<typeof getModelConstants>,
36+
): number => {
37+
let tokens = 0
38+
for (const toolCall of toolCalls) {
39+
tokens += constants.funcInit
40+
tokens += encoder.encode(toolCall.id).length
41+
tokens += encoder.encode(toolCall.type).length
42+
tokens += encoder.encode(toolCall.function.name).length
43+
tokens += encoder.encode(toolCall.function.arguments).length
44+
}
45+
tokens += constants.funcEnd
46+
return tokens
47+
}
48+
49+
/**
50+
* Calculate tokens for content parts
51+
*/
52+
const calculateContentPartsTokens = (
53+
contentParts: Array<ContentPart>,
54+
encoder: Encoder,
55+
): number => {
56+
let tokens = 0
57+
for (const part of contentParts) {
58+
if (part.type === "image_url") {
59+
tokens += encoder.encode(part.image_url.url).length + 85
60+
if (part.image_url.detail === "high") {
61+
tokens += 85
62+
}
63+
} else if (part.text) {
64+
tokens += encoder.encode(part.text).length
65+
}
66+
}
67+
return tokens
68+
}
69+
70+
/**
71+
* Calculate tokens for a single message
72+
*/
73+
const calculateMessageTokens = (
74+
message: Message,
75+
encoder: Encoder,
76+
constants: ReturnType<typeof getModelConstants>,
77+
): number => {
78+
const tokensPerMessage = 3
79+
const tokensPerName = 1
80+
let tokens = tokensPerMessage
81+
for (const [key, value] of Object.entries(message)) {
82+
if (typeof value === "string") {
83+
tokens += encoder.encode(value).length
84+
}
85+
if (key === "name") {
86+
tokens += tokensPerName
87+
}
88+
if (key === "tool_calls") {
89+
tokens += calculateToolCallsTokens(
90+
value as Array<ToolCall>,
91+
encoder,
92+
constants,
93+
)
94+
}
95+
if (key === "content" && Array.isArray(value)) {
96+
tokens += calculateContentPartsTokens(
97+
value as Array<ContentPart>,
98+
encoder,
99+
)
100+
}
101+
}
102+
return tokens
103+
}
104+
105+
/**
106+
* Calculate tokens using custom algorithm
107+
*/
108+
const calculateTokens = (
109+
messages: Array<Message>,
110+
encoder: Encoder,
111+
constants: ReturnType<typeof getModelConstants>,
112+
): number => {
113+
if (messages.length === 0) {
114+
return 0
115+
}
116+
let numTokens = 0
117+
for (const message of messages) {
118+
numTokens += calculateMessageTokens(message, encoder, constants)
119+
}
120+
// every reply is primed with <|start|>assistant<|message|>
121+
numTokens += 3
122+
return numTokens
123+
}
124+
125+
/**
126+
* Get the corresponding encoder module based on encoding type
127+
*/
128+
const getEncodeChatFunction = async (encoding: string): Promise<Encoder> => {
129+
if (encodingCache.has(encoding)) {
130+
const cached = encodingCache.get(encoding)
131+
if (cached) {
132+
return cached
133+
}
134+
}
135+
136+
const supportedEncoding = encoding as SupportedEncoding
137+
if (!(supportedEncoding in ENCODING_MAP)) {
138+
const fallbackModule = (await ENCODING_MAP.o200k_base()) as Encoder
139+
encodingCache.set(encoding, fallbackModule)
140+
return fallbackModule
141+
}
142+
143+
const encodingModule = (await ENCODING_MAP[supportedEncoding]()) as Encoder
144+
encodingCache.set(encoding, encodingModule)
145+
return encodingModule
146+
}
147+
148+
/**
149+
* Get tokenizer type from model information
150+
*/
151+
export const getTokenizerFromModel = (model: Model): string => {
152+
return model.capabilities.tokenizer || "o200k_base"
153+
}
154+
155+
/**
156+
* Get model-specific constants for token calculation
157+
*/
158+
const getModelConstants = (model: Model) => {
159+
return model.id === "gpt-3.5-turbo" || model.id === "gpt-4" ?
160+
{
161+
funcInit: 10,
162+
propInit: 3,
163+
propKey: 3,
164+
enumInit: -3,
165+
enumItem: 3,
166+
funcEnd: 12,
167+
}
168+
: {
169+
funcInit: 7,
170+
propInit: 3,
171+
propKey: 3,
172+
enumInit: -3,
173+
enumItem: 3,
174+
funcEnd: 12,
175+
}
176+
}
177+
178+
/**
179+
* Calculate tokens for a single parameter
180+
*/
181+
const calculateParameterTokens = (
182+
key: string,
183+
prop: unknown,
184+
context: {
185+
encoder: Encoder
186+
constants: ReturnType<typeof getModelConstants>
187+
},
188+
): number => {
189+
const { encoder, constants } = context
190+
let tokens = constants.propKey
191+
192+
// Early return if prop is not an object
193+
if (typeof prop !== "object" || prop === null) {
194+
return tokens
195+
}
196+
197+
// Type assertion for parameter properties
198+
const param = prop as {
199+
type?: string
200+
description?: string
201+
enum?: Array<unknown>
202+
[key: string]: unknown
203+
}
204+
205+
const paramName = key
206+
const paramType = param.type || "string"
207+
let paramDesc = param.description || ""
208+
209+
// Handle enum values
210+
if (param.enum && Array.isArray(param.enum)) {
211+
tokens += constants.enumInit
212+
for (const item of param.enum) {
213+
tokens += constants.enumItem
214+
tokens += encoder.encode(String(item)).length
15215
}
16-
return { ...message, content }
17-
})
216+
}
217+
218+
// Clean up description
219+
if (paramDesc.endsWith(".")) {
220+
paramDesc = paramDesc.slice(0, -1)
221+
}
222+
223+
// Encode the main parameter line
224+
const line = `${paramName}:${paramType}:${paramDesc}`
225+
tokens += encoder.encode(line).length
18226

19-
let inputMessages = simplifiedMessages.filter((message) => {
20-
return message.role !== "tool"
21-
})
22-
let outputMessages: typeof simplifiedMessages = []
227+
// Handle additional properties (excluding standard ones)
228+
const excludedKeys = new Set(["type", "description", "enum"])
229+
for (const propertyName of Object.keys(param)) {
230+
if (!excludedKeys.has(propertyName)) {
231+
const propertyValue = param[propertyName]
232+
const propertyText =
233+
typeof propertyValue === "string" ? propertyValue : (
234+
JSON.stringify(propertyValue)
235+
)
236+
tokens += encoder.encode(`${propertyName}:${propertyText}`).length
237+
}
238+
}
23239

24-
const lastMessage = simplifiedMessages.at(-1)
240+
return tokens
241+
}
25242

26-
if (lastMessage?.role === "assistant") {
27-
inputMessages = simplifiedMessages.slice(0, -1)
28-
outputMessages = [lastMessage]
243+
/**
244+
* Calculate tokens for function parameters
245+
*/
246+
const calculateParametersTokens = (
247+
parameters: unknown,
248+
encoder: Encoder,
249+
constants: ReturnType<typeof getModelConstants>,
250+
): number => {
251+
if (!parameters || typeof parameters !== "object") {
252+
return 0
29253
}
30254

31-
// @ts-expect-error TS can't infer from arr.filter()
32-
const inputTokens = countTokens(inputMessages)
33-
// @ts-expect-error TS can't infer from arr.filter()
34-
const outputTokens = countTokens(outputMessages)
255+
const params = parameters as Record<string, unknown>
256+
let tokens = 0
35257

258+
for (const [key, value] of Object.entries(params)) {
259+
if (key === "properties") {
260+
const properties = value as Record<string, unknown>
261+
if (Object.keys(properties).length > 0) {
262+
tokens += constants.propInit
263+
for (const propKey of Object.keys(properties)) {
264+
tokens += calculateParameterTokens(propKey, properties[propKey], {
265+
encoder,
266+
constants,
267+
})
268+
}
269+
}
270+
} else {
271+
const paramText =
272+
typeof value === "string" ? value : JSON.stringify(value)
273+
tokens += encoder.encode(`${key}:${paramText}`).length
274+
}
275+
}
276+
277+
return tokens
278+
}
279+
280+
/**
281+
* Calculate tokens for a single tool
282+
*/
283+
const calculateToolTokens = (
284+
tool: Tool,
285+
encoder: Encoder,
286+
constants: ReturnType<typeof getModelConstants>,
287+
): number => {
288+
let tokens = constants.funcInit
289+
const func = tool.function
290+
const fName = func.name
291+
let fDesc = func.description || ""
292+
if (fDesc.endsWith(".")) {
293+
fDesc = fDesc.slice(0, -1)
294+
}
295+
const line = fName + ":" + fDesc
296+
tokens += encoder.encode(line).length
297+
if (
298+
typeof func.parameters === "object" // eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
299+
&& func.parameters !== null
300+
) {
301+
tokens += calculateParametersTokens(func.parameters, encoder, constants)
302+
}
303+
return tokens
304+
}
305+
306+
/**
307+
* Calculate token count for tools based on model
308+
*/
309+
export const numTokensForTools = (
310+
tools: Array<Tool>,
311+
encoder: Encoder,
312+
constants: ReturnType<typeof getModelConstants>,
313+
): number => {
314+
let funcTokenCount = 0
315+
for (const tool of tools) {
316+
funcTokenCount += calculateToolTokens(tool, encoder, constants)
317+
}
318+
funcTokenCount += constants.funcEnd
319+
return funcTokenCount
320+
}
321+
322+
/**
323+
* Calculate the token count of messages, supporting multiple GPT encoders
324+
*/
325+
export const getTokenCount = async (
326+
payload: ChatCompletionsPayload,
327+
model: Model,
328+
): Promise<{ input: number; output: number }> => {
329+
// Get tokenizer string
330+
const tokenizer = getTokenizerFromModel(model)
331+
332+
// Get corresponding encoder module
333+
const encoder = await getEncodeChatFunction(tokenizer)
334+
335+
const simplifiedMessages = payload.messages
336+
const inputMessages = simplifiedMessages.filter(
337+
(msg) => msg.role !== "assistant",
338+
)
339+
const outputMessages = simplifiedMessages.filter(
340+
(msg) => msg.role === "assistant",
341+
)
342+
343+
const constants = getModelConstants(model)
344+
let inputTokens = calculateTokens(inputMessages, encoder, constants)
345+
if (payload.tools && payload.tools.length > 0) {
346+
inputTokens += numTokensForTools(payload.tools, encoder, constants)
347+
}
348+
const outputTokens = calculateTokens(outputMessages, encoder, constants)
36349
return {
37350
input: inputTokens,
38351
output: outputTokens,

0 commit comments

Comments
 (0)