From 7cba319ad9b59c79d775c0bb477dbecfd7386689 Mon Sep 17 00:00:00 2001
From: joshistoast
Date: Thu, 30 Oct 2025 02:32:10 -0600
Subject: [PATCH 1/5] feat(prompts): add abstract syntax tree (AST) builder for prompts

---
 .../frontend/web/src/common/util/promptAST.ts | 268 ++++++++++++++++++
 1 file changed, 268 insertions(+)
 create mode 100644 invokeai/frontend/web/src/common/util/promptAST.ts

diff --git a/invokeai/frontend/web/src/common/util/promptAST.ts b/invokeai/frontend/web/src/common/util/promptAST.ts
new file mode 100644
index 00000000000..861c3eacea8
--- /dev/null
+++ b/invokeai/frontend/web/src/common/util/promptAST.ts
@@ -0,0 +1,268 @@
+/**
+ * Either a run of '+'/'-' characters ('+', '--', etc.) or a numeric weight such as 1.2 or 0.8.
+ */
+export type Attention = string | number;
+
+type Word = string;
+
+type Punct = string;
+
+type Whitespace = string;
+
+type Embedding = string;
+
+export type Token =
+  | { type: 'word'; value: Word }
+  | { type: 'whitespace'; value: Whitespace }
+  | { type: 'punct'; value: Punct }
+  | { type: 'lparen' }
+  | { type: 'rparen' }
+  | { type: 'weight'; value: Attention }
+  | { type: 'lembed' }
+  | { type: 'rembed' };
+
+export type ASTNode =
+  | { type: 'word'; text: Word; attention?: Attention }
+  | { type: 'group'; children: ASTNode[]; attention?: Attention }
+  | { type: 'embedding'; value: Embedding }
+  | { type: 'whitespace'; value: Whitespace }
+  | { type: 'punct'; value: Punct };
+
+/**
+ * Convert a prompt string into a flat list of tokens.
+ * @param prompt string
+ * @returns Token[]
+ */
+export function tokenize(prompt: string): Token[] {
+  if (!prompt) {
+    return [];
+  }
+
+  let i = 0;
+  const tokens: Token[] = [];
+
+  while (i < prompt.length) {
+    const char = prompt[i];
+    if (!char) {
+      break;
+    }
+
+    // Whitespace (including newlines)
+    if (/\s/.test(char)) {
+      tokens.push({ type: 'whitespace', value: char });
+      i++;
+      continue;
+    }
+
+    // Parentheses
+    if (char === '(') {
+      tokens.push({ type: 'lparen' });
+      i++;
+      continue;
+    }
+
+    if (char === ')') {
+      // Look ahead for weight like ')1.1' or ')-0.9' or ')+' or ')-'
+      const weightMatch = prompt.slice(i + 1).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
+      if (weightMatch && weightMatch[0]) {
+        let weight: Attention = weightMatch[0];
+        if (!isNaN(Number(weight))) {
+          weight = Number(weight);
+        }
+        tokens.push({ type: 'rparen' });
+        tokens.push({ type: 'weight', value: weight });
+        i += 1 + weightMatch[0].length;
+        continue;
+      }
+      tokens.push({ type: 'rparen' });
+      i++;
+      continue;
+    }
+
+    // Handle punctuation (comma, period, etc.)
+    if (/[,.]/.test(char)) {
+      tokens.push({ type: 'punct', value: char });
+      i++;
+      continue;
+    }
+
+    // Read a word (letters, digits, underscores)
+    if (/[a-zA-Z0-9_]/.test(char)) {
+      let j = i;
+      while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) {
+        j++;
+      }
+      const word = prompt.slice(i, j);
+      tokens.push({ type: 'word', value: word });
+
+      // Check for weight immediately after word (e.g., "Lorem+", "consectetur-")
+      const weightMatch = prompt.slice(j).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
+      if (weightMatch && weightMatch[0]) {
+        tokens.push({ type: 'weight', value: weightMatch[0] });
+        i = j + weightMatch[0].length;
+      } else {
+        i = j;
+      }
+      continue;
+    }
+
+    // Embeddings
+    if (char === '<') {
+      tokens.push({ type: 'lembed' });
+      i++;
+      continue;
+    }
+
+    if (char === '>') {
+      tokens.push({ type: 'rembed' });
+      i++;
+      continue;
+    }
+
+    // Any other single character punctuation
+    if (!/\s/.test(char)) {
+      tokens.push({ type: 'punct', value: char });
+    }
+
+    i++;
+  }
+
+  return tokens;
+}
+
+/**
+ * Convert tokens into an AST.
+ * @param tokens Token[]
+ * @returns ASTNode[]
+ */
+export function parseTokens(tokens: Token[]): ASTNode[] {
+  let pos = 0;
+
+  function peek(): Token | undefined {
+    return tokens[pos];
+  }
+
+  function consume(): Token | undefined {
+    return tokens[pos++];
+  }
+
+  function parseGroup(): ASTNode[] {
+    const nodes: ASTNode[] = [];
+
+    while (pos < tokens.length) {
+      const token = peek();
+      if (!token || token.type === 'rparen') {
+        break;
+      }
+
+      switch (token.type) {
+        case 'whitespace': {
+          const wsToken = consume() as Token & { type: 'whitespace' };
+          nodes.push({ type: 'whitespace', value: wsToken.value });
+          break;
+        }
+        case 'lparen': {
+          consume();
+          const groupChildren = parseGroup();
+
+          let attention: Attention | undefined;
+          if (peek()?.type === 'rparen') {
+            consume(); // consume ')'
+            if (peek()?.type === 'weight') {
+              attention = (consume() as Token & { type: 'weight' }).value;
+            }
+          }
+
+          nodes.push({ type: 'group', children: groupChildren, attention });
+          break;
+        }
+        case 'lembed': {
+          consume(); // consume '<'
+          let embedValue = '';
+          while (peek() && peek()!.type !== 'rembed') {
+            const embedToken = consume()!;
+            // A weight token can appear inside an embedding name (e.g. the
+            // '-' in <easy-negative>), so stringify it rather than drop it.
+            embedValue +=
+              embedToken.type === 'word' ||
+              embedToken.type === 'punct' ||
+              embedToken.type === 'whitespace' ||
+              embedToken.type === 'weight'
+                ? String(embedToken.value)
+                : '';
+          }
+          if (peek()?.type === 'rembed') {
+            consume(); // consume '>'
+          }
+          nodes.push({ type: 'embedding', value: embedValue.trim() });
+          break;
+        }
+        case 'word': {
+          const wordToken = consume() as Token & { type: 'word' };
+          let attention: Attention | undefined;
+
+          // Check for immediate weight after word
+          if (peek()?.type === 'weight') {
+            attention = (consume() as Token & { type: 'weight' }).value;
+          }
+
+          nodes.push({ type: 'word', text: wordToken.value, attention });
+          break;
+        }
+        case 'punct': {
+          const punctToken = consume() as Token & { type: 'punct' };
+          nodes.push({ type: 'punct', value: punctToken.value });
+          break;
+        }
+        default: {
+          consume();
+        }
+      }
+    }
+
+    return nodes;
+  }
+
+  return parseGroup();
+}
+
+/**
+ * Convert an AST back into a prompt string.
+ * @param ast ASTNode[]
+ * @returns string
+ */
+export function serialize(ast: ASTNode[]): string {
+  let prompt = '';
+
+  for (const node of ast) {
+    switch (node.type) {
+      case 'punct':
+      case 'whitespace': {
+        prompt += node.value;
+        break;
+      }
+      case 'word': {
+        prompt += node.text;
+        if (node.attention) {
+          prompt += String(node.attention);
+        }
+        break;
+      }
+      case 'group': {
+        prompt += '(';
+        prompt += serialize(node.children);
+        prompt += ')';
+        if (node.attention) {
+          prompt += String(node.attention);
+        }
+        break;
+      }
+      case 'embedding': {
+        prompt += `<${node.value}>`;
+        break;
+      }
+    }
+  }
+
+  return prompt;
+}

From df95e1fed08145405c37f4bbb52bc7fea0eaba5f Mon Sep 17 00:00:00 2001
From: joshistoast
Date: Sat, 1 Nov 2025 17:35:10 -0600
Subject: [PATCH 2/5] fix(prompts): add escaped parens to AST

---
 .../frontend/web/src/common/util/promptAST.ts | 25 +++++++++++++++++--
 1 file changed, 23 insertions(+), 2 deletions(-)

diff --git a/invokeai/frontend/web/src/common/util/promptAST.ts b/invokeai/frontend/web/src/common/util/promptAST.ts
index 861c3eacea8..698e5adf5dc 100644
--- a/invokeai/frontend/web/src/common/util/promptAST.ts
+++ b/invokeai/frontend/web/src/common/util/promptAST.ts
@@ -19,14 +19,16 @@ export type Token =
   | { type: 'rparen' }
   | { type: 'weight'; value: Attention }
   | { type: 'lembed' }
-  | { type: 'rembed' };
+  | { type: 'rembed' }
+  | { type: 'escaped_paren'; value: '(' | ')' };
 
 export type ASTNode =
   | { type: 'word'; text: Word; attention?: Attention }
   | { type: 'group'; children: ASTNode[]; attention?: Attention }
   | { type: 'embedding'; value: Embedding }
   | { type: 'whitespace'; value: Whitespace }
-  | { type: 'punct'; value: Punct };
+  | { type: 'punct'; value: Punct }
+  | { type: 'escaped_paren'; value: '(' | ')' };
 
 /**
  * Convert a prompt string into a flat list of tokens.
@@ -54,6 +56,16 @@ export function tokenize(prompt: string): Token[] {
       continue;
     }
 
+    // Escaped parentheses (e.g., \( or \))
+    if (char === '\\' && i + 1 < prompt.length) {
+      const nextChar = prompt[i + 1];
+      if (nextChar === '(' || nextChar === ')') {
+        tokens.push({ type: 'escaped_paren', value: nextChar });
+        i += 2;
+        continue;
+      }
+    }
+
     // Parentheses
     if (char === '(') {
       tokens.push({ type: 'lparen' });
@@ -214,6 +226,11 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
           nodes.push({ type: 'punct', value: punctToken.value });
           break;
         }
+        case 'escaped_paren': {
+          const escapedToken = consume() as Token & { type: 'escaped_paren' };
+          nodes.push({ type: 'escaped_paren', value: escapedToken.value });
+          break;
+        }
         default: {
           consume();
         }
@@ -241,6 +258,10 @@ export function parseTokens(tokens: Token[]): ASTNode[] {
         prompt += node.value;
         break;
       }
+      case 'escaped_paren': {
+        prompt += `\\${node.value}`;
+        break;
+      }
       case 'word': {
         prompt += node.text;
         if (node.attention) {

From a5694aa699bf7ba7bd365b09bc2e2ef8d883375c Mon Sep 17 00:00:00 2001
From: joshistoast
Date: Sat, 1 Nov 2025 17:36:31 -0600
Subject: [PATCH 3/5] test(prompts): add AST tests

---
 .../web/src/common/util/promptAST.test.ts | 270 ++++++++++++++++++
 1 file changed, 270 insertions(+)
 create mode 100644 invokeai/frontend/web/src/common/util/promptAST.test.ts

diff --git a/invokeai/frontend/web/src/common/util/promptAST.test.ts b/invokeai/frontend/web/src/common/util/promptAST.test.ts
new file mode 100644
index 00000000000..d958b3915a0
--- /dev/null
+++ b/invokeai/frontend/web/src/common/util/promptAST.test.ts
@@ -0,0 +1,270 @@
+import { describe, expect, it } from 'vitest';
+import { parseTokens, serialize, tokenize } from './promptAST';
+
+describe('promptAST', () => {
+  describe('tokenize', () => {
+    it('should tokenize basic text', () => {
+      const tokens = tokenize('a cat');
+      expect(tokens).toEqual([
+        { type: 'word', value: 'a' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'word', value: 'cat' },
+      ]);
+    });
+
+    it('should tokenize groups with parentheses', () => {
+      const tokens = tokenize('(a cat)');
+      expect(tokens).toEqual([
+        { type: 'lparen' },
+        { type: 'word', value: 'a' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'word', value: 'cat' },
+        { type: 'rparen' },
+      ]);
+    });
+
+    it('should tokenize escaped parentheses', () => {
+      const tokens = tokenize('\\(medium\\)');
+      expect(tokens).toEqual([
+        { type: 'escaped_paren', value: '(' },
+        { type: 'word', value: 'medium' },
+        { type: 'escaped_paren', value: ')' },
+      ]);
+    });
+
+    it('should tokenize mixed escaped and unescaped parentheses', () => {
+      const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
+      expect(tokens).toEqual([
+        { type: 'word', value: 'colored' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'word', value: 'pencil' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'escaped_paren', value: '(' },
+        { type: 'word', value: 'medium' },
+        { type: 'escaped_paren', value: ')' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'lparen' },
+        { type: 'word', value: 'enhanced' },
+        { type: 'rparen' },
+      ]);
+    });
+
+    it('should tokenize groups with weights', () => {
+      const tokens = tokenize('(a cat)1.2');
+      expect(tokens).toEqual([
+        { type: 'lparen' },
+        { type: 'word', value: 'a' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'word', value: 'cat' },
+        { type: 'rparen' },
+        { type: 'weight', value: 1.2 },
+      ]);
+    });
+
+    it('should tokenize words with weights', () => {
+      const tokens = tokenize('cat+');
+      expect(tokens).toEqual([
+        { type: 'word', value: 'cat' },
+        { type: 'weight', value: '+' },
+      ]);
+    });
+
+    it('should tokenize embeddings', () => {
+      const tokens = tokenize('<embedding_name>');
+      expect(tokens).toEqual([
+        { type: 'lembed' },
+        { type: 'word', value: 'embedding_name' },
+        { type: 'rembed' },
+      ]);
+    });
+  });
+
+  describe('parseTokens', () => {
+    it('should parse basic text', () => {
+      const tokens = tokenize('a cat');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([
+        { type: 'word', text: 'a' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'word', text: 'cat' },
+      ]);
+    });
+
+    it('should parse groups', () => {
+      const tokens = tokenize('(a cat)');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([
+        {
+          type: 'group',
+          children: [
+            { type: 'word', text: 'a' },
+            { type: 'whitespace', value: ' ' },
+            { type: 'word', text: 'cat' },
+          ],
+        },
+      ]);
+    });
+
+    it('should parse escaped parentheses', () => {
+      const tokens = tokenize('\\(medium\\)');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([
+        { type: 'escaped_paren', value: '(' },
+        { type: 'word', text: 'medium' },
+        { type: 'escaped_paren', value: ')' },
+      ]);
+    });
+
+    it('should parse mixed escaped and unescaped parentheses', () => {
+      const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([
+        { type: 'word', text: 'colored' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'word', text: 'pencil' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'escaped_paren', value: '(' },
+        { type: 'word', text: 'medium' },
+        { type: 'escaped_paren', value: ')' },
+        { type: 'whitespace', value: ' ' },
+        {
+          type: 'group',
+          children: [{ type: 'word', text: 'enhanced' }],
+        },
+      ]);
+    });
+
+    it('should parse groups with attention', () => {
+      const tokens = tokenize('(a cat)1.2');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([
+        {
+          type: 'group',
+          attention: 1.2,
+          children: [
+            { type: 'word', text: 'a' },
+            { type: 'whitespace', value: ' ' },
+            { type: 'word', text: 'cat' },
+          ],
+        },
+      ]);
+    });
+
+    it('should parse words with attention', () => {
+      const tokens = tokenize('cat+');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+' }]);
+    });
+
+    it('should parse embeddings', () => {
+      const tokens = tokenize('<embedding_name>');
+      const ast = parseTokens(tokens);
+      expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name' }]);
+    });
+  });
+
+  describe('serialize', () => {
+    it('should serialize basic text', () => {
+      const tokens = tokenize('a cat');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('a cat');
+    });
+
+    it('should serialize groups', () => {
+      const tokens = tokenize('(a cat)');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('(a cat)');
+    });
+
+    it('should serialize escaped parentheses', () => {
+      const tokens = tokenize('\\(medium\\)');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('\\(medium\\)');
+    });
+
+    it('should serialize mixed escaped and unescaped parentheses', () => {
+      const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('colored pencil \\(medium\\) (enhanced)');
+    });
+
+    it('should serialize groups with attention', () => {
+      const tokens = tokenize('(a cat)1.2');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('(a cat)1.2');
+    });
+
+    it('should serialize words with attention', () => {
+      const tokens = tokenize('cat+');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('cat+');
+    });
+
+    it('should serialize embeddings', () => {
+      const tokens = tokenize('<embedding_name>');
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe('<embedding_name>');
+    });
+  });
+
+  describe('compel compatibility examples', () => {
+    it('should handle escaped parentheses for literal text', () => {
+      const prompt = 'A bear \\(with razor-sharp teeth\\) in a forest.';
+      const tokens = tokenize(prompt);
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe(prompt);
+    });
+
+    it('should handle unescaped parentheses as grouping syntax', () => {
+      const prompt = 'A bear (with razor-sharp teeth) in a forest.';
+      const tokens = tokenize(prompt);
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe(prompt);
+    });
+
+    it('should handle colored pencil medium example', () => {
+      const prompt = 'colored pencil \\(medium\\)';
+      const tokens = tokenize(prompt);
+      const ast = parseTokens(tokens);
+      const result = serialize(ast);
+      expect(result).toBe(prompt);
+    });
+
+    it('should distinguish between escaped and unescaped in same prompt', () => {
+      const prompt = 'portrait \\(realistic\\) (high quality)1.2';
+      const tokens = tokenize(prompt);
+      const ast = parseTokens(tokens);
+
+      // Should have escaped parens as nodes and a group with attention
+      expect(ast).toEqual([
+        { type: 'word', text: 'portrait' },
+        { type: 'whitespace', value: ' ' },
+        { type: 'escaped_paren', value: '(' },
+        { type: 'word', text: 'realistic' },
+        { type: 'escaped_paren', value: ')' },
+        { type: 'whitespace', value: ' ' },
+        {
+          type: 'group',
+          attention: 1.2,
+          children: [
+            { type: 'word', text: 'high' },
+            { type: 'whitespace', value: ' ' },
+            { type: 'word', text: 'quality' },
+          ],
+        },
+      ]);
+
+      const result = serialize(ast);
+      expect(result).toBe(prompt);
+    });
+  });
+});

From 32e19e533f6dee15cf55244ca6bf626f50ccc40b Mon Sep 17 00:00:00 2001
From: joshistoast
Date: Sat, 1 Nov 2025 17:42:45 -0600
Subject: [PATCH 4/5] fix(prompts): appease the linter

---
 invokeai/frontend/web/src/common/util/promptAST.test.ts | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/invokeai/frontend/web/src/common/util/promptAST.test.ts b/invokeai/frontend/web/src/common/util/promptAST.test.ts
index d958b3915a0..25786d417af 100644
--- a/invokeai/frontend/web/src/common/util/promptAST.test.ts
+++ b/invokeai/frontend/web/src/common/util/promptAST.test.ts
@@ -1,4 +1,5 @@
 import { describe, expect, it } from 'vitest';
+
 import { parseTokens, serialize, tokenize } from './promptAST';
 
 describe('promptAST', () => {
@@ -71,11 +72,7 @@ describe('promptAST', () => {
 
     it('should tokenize embeddings', () => {
       const tokens = tokenize('<embedding_name>');
-      expect(tokens).toEqual([
-        { type: 'lembed' },
-        { type: 'word', value: 'embedding_name' },
-        { type: 'rembed' },
-      ]);
+      expect(tokens).toEqual([{ type: 'lembed' }, { type: 'word', value: 'embedding_name' }, { type: 'rembed' }]);
     });
   });
 

From 2a3053945bba9f70f9ee886cc5a13b1161be56e5 Mon Sep 17 00:00:00 2001
From: joshistoast
Date: Sun, 2 Nov 2025 14:21:20 -0700
Subject: [PATCH 5/5] perf(prompts): break up tokenize function into subroutines

---
 .../frontend/web/src/common/util/promptAST.ts | 223 ++++++++++++------
 1 file changed, 147 insertions(+), 76 deletions(-)

diff --git a/invokeai/frontend/web/src/common/util/promptAST.ts b/invokeai/frontend/web/src/common/util/promptAST.ts
index 698e5adf5dc..ab9df32e064 100644
--- a/invokeai/frontend/web/src/common/util/promptAST.ts
+++ b/invokeai/frontend/web/src/common/util/promptAST.ts
@@ -30,6 +30,11 @@ export type ASTNode =
   | { type: 'punct'; value: Punct }
   | { type: 'escaped_paren'; value: '(' | ')' };
 
+const WEIGHT_PATTERN = /^[+-]?(\d+(\.\d+)?|[+-]+)/;
+const WHITESPACE_PATTERN = /^\s+/;
+const PUNCTUATION_PATTERN = /^[.,]/;
+const OTHER_PATTERN = /\S/; // fallback: any single non-whitespace character
+
 /**
  * Convert a prompt string into a flat list of tokens.
  * @param prompt string
  * @returns Token[]
  */
@@ -40,106 +45,172 @@ export function tokenize(prompt: string): Token[] {
     return [];
   }
 
-  let i = 0;
+  const len = prompt.length;
   const tokens: Token[] = [];
+  let i = 0;
 
-  while (i < prompt.length) {
+  while (i < len) {
     const char = prompt[i];
     if (!char) {
       break;
     }
 
-    // Whitespace (including newlines)
-    if (/\s/.test(char)) {
-      tokens.push({ type: 'whitespace', value: char });
-      i++;
-      continue;
-    }
+    const result =
+      tokenizeWhitespace(char, i) ||
+      tokenizeEscapedParen(prompt, i) ||
+      tokenizeLeftParen(char, i) ||
+      tokenizeRightParen(prompt, i) ||
+      tokenizeEmbedding(char, i) ||
+      tokenizeWord(prompt, i) ||
+      tokenizePunctuation(char, i) ||
+      tokenizeOther(char, i);
+
+    if (result) {
+      if (result.token) {
+        tokens.push(result.token);
+      }
+      if (result.extraToken) {
+        tokens.push(result.extraToken);
+      }
+      i = result.nextIndex;
+    } else {
+      i++;
+    }
+  }
 
-    // Escaped parentheses (e.g., \( or \))
-    if (char === '\\' && i + 1 < prompt.length) {
-      const nextChar = prompt[i + 1];
-      if (nextChar === '(' || nextChar === ')') {
-        tokens.push({ type: 'escaped_paren', value: nextChar });
-        i += 2;
-        continue;
-      }
-    }
-
-    // Parentheses
-    if (char === '(') {
-      tokens.push({ type: 'lparen' });
-      i++;
-      continue;
-    }
+  return tokens;
+}
+
+type TokenizeResult = {
+  token?: Token;
+  extraToken?: Token;
+  nextIndex: number;
+} | null;
+
+function tokenizeWhitespace(char: string, i: number): TokenizeResult {
+  if (WHITESPACE_PATTERN.test(char)) {
+    return {
+      token: { type: 'whitespace', value: char },
+      nextIndex: i + 1,
+    };
+  }
+  return null;
+}
 
-    if (char === ')') {
-      // Look ahead for weight like ')1.1' or ')-0.9' or ')+' or ')-'
-      const weightMatch = prompt.slice(i + 1).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
-      if (weightMatch && weightMatch[0]) {
-        let weight: Attention = weightMatch[0];
-        if (!isNaN(Number(weight))) {
-          weight = Number(weight);
-        }
-        tokens.push({ type: 'rparen' });
-        tokens.push({ type: 'weight', value: weight });
-        i += 1 + weightMatch[0].length;
-        continue;
-      }
-      tokens.push({ type: 'rparen' });
-      i++;
-      continue;
-    }
+function tokenizeEscapedParen(prompt: string, i: number): TokenizeResult {
+  const char = prompt[i];
+  if (char === '\\' && i + 1 < prompt.length) {
+    const nextChar = prompt[i + 1];
+    if (nextChar === '(' || nextChar === ')') {
+      return {
+        token: { type: 'escaped_paren', value: nextChar },
+        nextIndex: i + 2,
+      };
+    }
+  }
+  return null;
+}
 
-    // Handle punctuation (comma, period, etc.)
-    if (/[,.]/.test(char)) {
-      tokens.push({ type: 'punct', value: char });
-      i++;
-      continue;
-    }
+function tokenizeLeftParen(char: string, i: number): TokenizeResult {
+  if (char === '(') {
+    return {
+      token: { type: 'lparen' },
+      nextIndex: i + 1,
+    };
+  }
+  return null;
+}
 
-    // Read a word (letters, digits, underscores)
-    if (/[a-zA-Z0-9_]/.test(char)) {
-      let j = i;
-      while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) {
-        j++;
-      }
-      const word = prompt.slice(i, j);
-      tokens.push({ type: 'word', value: word });
-
-      // Check for weight immediately after word (e.g., "Lorem+", "consectetur-")
-      const weightMatch = prompt.slice(j).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/);
-      if (weightMatch && weightMatch[0]) {
-        tokens.push({ type: 'weight', value: weightMatch[0] });
-        i = j + weightMatch[0].length;
-      } else {
-        i = j;
-      }
-      continue;
-    }
+function tokenizeRightParen(prompt: string, i: number): TokenizeResult {
+  const char = prompt[i];
+  if (char === ')') {
+    // Look ahead for weight like ')1.1' or ')-0.9' or ')+' or ')-'
+    const weightMatch = prompt.slice(i + 1).match(WEIGHT_PATTERN);
+    if (weightMatch && weightMatch[0]) {
+      let weight: Attention = weightMatch[0];
+      if (!isNaN(Number(weight))) {
+        weight = Number(weight);
+      }
+      return {
+        token: { type: 'rparen' },
+        extraToken: { type: 'weight', value: weight },
+        nextIndex: i + 1 + weightMatch[0].length,
+      };
+    }
+    return {
+      token: { type: 'rparen' },
+      nextIndex: i + 1,
+    };
+  }
+  return null;
+}
 
-    // Embeddings
-    if (char === '<') {
-      tokens.push({ type: 'lembed' });
-      i++;
-      continue;
-    }
+function tokenizePunctuation(char: string, i: number): TokenizeResult {
+  if (PUNCTUATION_PATTERN.test(char)) {
+    return {
+      token: { type: 'punct', value: char },
+      nextIndex: i + 1,
+    };
+  }
+  return null;
+}
 
-    if (char === '>') {
-      tokens.push({ type: 'rembed' });
-      i++;
-      continue;
-    }
+function tokenizeWord(prompt: string, i: number): TokenizeResult {
+  const char = prompt[i];
+  if (!char) {
+    return null;
+  }
 
-    // Any other single character punctuation
-    if (!/\s/.test(char)) {
-      tokens.push({ type: 'punct', value: char });
-    }
-
-    i++;
-  }
+  if (/[a-zA-Z0-9_]/.test(char)) {
+    let j = i;
+    while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) {
+      j++;
+    }
+    const word = prompt.slice(i, j);
+
+    // Check for weight immediately after word (e.g., "Lorem+", "consectetur-")
+    const weightMatch = prompt.slice(j).match(WEIGHT_PATTERN);
+    if (weightMatch && weightMatch[0]) {
+      return {
+        token: { type: 'word', value: word },
+        extraToken: { type: 'weight', value: weightMatch[0] },
+        nextIndex: j + weightMatch[0].length,
+      };
+    }
+    return {
+      token: { type: 'word', value: word },
+      nextIndex: j,
+    };
+  }
+  return null;
+}
 
-  return tokens;
-}
+function tokenizeEmbedding(char: string, i: number): TokenizeResult {
+  if (char === '<') {
+    return {
+      token: { type: 'lembed' },
+      nextIndex: i + 1,
+    };
+  }
+  if (char === '>') {
+    return {
+      token: { type: 'rembed' },
+      nextIndex: i + 1,
+    };
+  }
+  return null;
+}
+
+function tokenizeOther(char: string, i: number): TokenizeResult {
+  // Any other single non-whitespace character is treated as punctuation
+  if (OTHER_PATTERN.test(char)) {
+    return {
+      token: { type: 'punct', value: char },
+      nextIndex: i + 1,
+    };
+  }
+  return null;
+}
 
 /**
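
Reviewer note (not part of the patch series): a minimal round-trip sketch of the API these patches introduce, using only the exported tokenize/parseTokens/serialize signatures from PATCH 1. The prompt string and the weight-bumping loop are illustrative assumptions, not code from the PR.

import { parseTokens, serialize, tokenize, type ASTNode } from './promptAST';

// Tokenize and parse a prompt, bump every numerically weighted group by 0.1,
// then serialize the AST back into a prompt string.
const ast: ASTNode[] = parseTokens(tokenize('a (very detailed)1.1 cat'));

for (const node of ast) {
  if (node.type === 'group' && typeof node.attention === 'number') {
    node.attention = Math.round((node.attention + 0.1) * 10) / 10;
  }
}

console.log(serialize(ast)); // -> 'a (very detailed)1.2 cat'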