@@ -34,6 +34,7 @@ import {
   PrefillChunkSizeSmallerThanImageError,
   CannotFindImageEmbedError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
 
 type ImageURL = ChatCompletionContentPartImage.ImageURL;
 
@@ -128,6 +129,8 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  private seqIdToPrefix: Map<number, number[]>;
+  private nextSequenceId: number;
 
   constructor(
     tvm: tvmjs.Instance,
@@ -173,6 +176,8 @@ export class LLMChatPipeline {
     log.info("token_postproc_method: ", this.token_postproc_method);
     log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);
 
+    this.seqIdToPrefix = new Map<number, number[]>();
+    this.nextSequenceId = 0;
     this.device = this.tvm.webgpu();
 
     // 1. Create VM and get the core functions
@@ -344,7 +349,12 @@
    * Reset KV Cache
    */
   resetKVCache() {
-    this.fclearKVCaches(this.kvCache);
+    // Check whether to keep prefixes in the KV cache
+    if (this.seqIdToPrefix.size === 0) {
+      this.fclearKVCaches(this.kvCache);
+    } else {
+      this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
     this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
     if (this.slidingWindowSize != -1) {
       this.fKVCacheEnableSlidingWindowForSeq(
@@ -483,6 +493,15 @@
     await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
   }
 
+  matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
+    for (let i = 0; i < prefixTokens.length; i++) {
+      if (inputTokens[i] !== prefixTokens[i]) {
+        return i;
+      }
+    }
+    return prefixTokens.length;
+  }
+
   /**
    * Generate the first token given input prompt
    */
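A quick illustration of what matchPrefix computes (the token IDs below are made up for the example): it returns the length of the longest shared leading run between the incoming tokens and a cached prefix, which later serves as the fork position in the KV cache.

// matchPrefix([1, 2, 3, 4], [1, 2, 9]) === 2   (diverges at index 2)
// matchPrefix([1, 2, 3, 4], [1, 2])    === 2   (entire cached prefix matches)
// matchPrefix([5, 6, 7],    [1, 2])    === 0   (no shared prefix)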
@@ -491,11 +510,17 @@
     msgRole: Role, // either user or tool
     inp_role_str?: string,
     genConfig?: GenerationConfig,
+    seqID = 0,
   ): Promise<void> {
-    if (msgRole !== Role.user && msgRole !== Role.tool) {
-      throw new MessageOrderError(
-        "The last message should be from `user` or `tool`.",
-      );
+    if (seqID === 0) {
+      if (msgRole !== Role.user && msgRole !== Role.tool) {
+        throw new MessageOrderError(
+          "The last message should be from `user` or `tool`.",
+        );
+      }
+    } else {
+      // Set the input as system prompt during prefix prefilling
+      this.conversation.override_system_message = inp;
     }
     if (this.resetStatsPerPrefill) {
       this.resetRuntimeStats();
@@ -583,11 +608,13 @@
     }
 
     // 0. Get inputData from conversation
-    if (conversation.isTextCompletion) {
-      conversation.prompt = inp;
-    } else {
-      conversation.appendMessage(msgRole, inp, inp_role_str);
-      conversation.appendReplyHeader(Role.assistant);
+    if (seqID === 0) {
+      if (conversation.isTextCompletion) {
+        conversation.prompt = inp;
+      } else {
+        conversation.appendMessage(msgRole, inp, inp_role_str);
+        conversation.appendReplyHeader(Role.assistant);
+      }
     }
     const retGetInputData = this.getInputData();
     const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
@@ -610,11 +637,68 @@
       throw new CannotFindImageEmbedError();
     }
 
+    let maxMatchedLen = -1;
+    let matchedSeqId = -1;
+
+    // Prefix matching and forking
+    const inputTokens = inputData.flat() as number[];
+    for (const [id, prefixTokens] of this.seqIdToPrefix) {
+      const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
+      if (matchedLen > maxMatchedLen) {
+        maxMatchedLen = matchedLen;
+        matchedSeqId = id;
+      }
+    }
+
+    // If a match is found, fork the sequence
+    if (matchedSeqId !== -1 && maxMatchedLen > 0) {
+      console.log(
+        "Forking sequence",
+        matchedSeqId,
+        "at position",
+        maxMatchedLen,
+      );
+      if (seqID === 0) {
+        this.fKVCacheRemoveSequence!(
+          this.kvCache,
+          new tvmjs.Scalar(seqID, "int64"),
+        );
+      }
+      this.tvm.beginScope();
+      this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
+        this.kvCache,
+        new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
+        new tvmjs.Scalar(seqID, "int64"), // fork_child_id
+        new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
+      );
+      this.tvm.endScope();
+    } else if (seqID !== 0) {
+      // If no match is found, add the new sequence to the KV cache
+      console.log("Adding new sequence to KV cache: ", seqID);
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
+    }
+
+    // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
+    if (seqID !== 0) {
+      this.seqIdToPrefix.set(seqID, inputTokens);
+    }
+
     // 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
-    const retGetChunks = getChunkedPrefillInputData(
-      inputData,
-      this.prefillChunkSize,
-    );
+    let retGetChunks;
+    if (maxMatchedLen === -1) {
+      retGetChunks = getChunkedPrefillInputData(
+        inputData,
+        this.prefillChunkSize,
+      );
+    } else {
+      // If a matched prefix exists, only forward the remaining tokens
+      retGetChunks = getChunkedPrefillInputData(
+        inputData.map((arr) =>
+          Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
+        ),
+        this.prefillChunkSize,
+      );
+    }
     const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
     const chunkLens: Array<number> = retGetChunks[1];
 
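A small sketch of the slicing behavior above, with made-up token IDs (in the real pipeline they come from the tokenizer): when a cached prefix covers the first maxMatchedLen tokens, only the remainder is chunked and forwarded.

// Illustrative only, not part of this diff:
const inputTokens = [101, 7, 42, 9, 13, 55, 88]; // full prompt tokens
const maxMatchedLen = 5; // longest match against a cached prefix
const remaining = inputTokens.slice(maxMatchedLen); // [55, 88] -- only these are prefilled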
@@ -626,7 +710,7 @@
       const chunkLen = chunkLens[i];
       const prevFilledLen = this.filledKVCacheLength;
       logits = this.tvm.detachFromCurrentScope(
-        await this.embedAndForward(chunk, chunkLen),
+        await this.embedAndForward(chunk, chunkLen, seqID),
       );
       if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
         throw new Error(
@@ -651,6 +735,41 @@
     this.processNextToken(nextToken, genConfig);
   }
 
+  async prefillConvSequence(
+    messages: ChatCompletionMessageParam[],
+    inp_role_str?: string,
+    genConfig?: GenerationConfig,
+  ): Promise<void> {
+    for (const message of messages) {
+      this.nextSequenceId = this.nextSequenceId + 1;
+      const newSeqId = this.nextSequenceId;
+      // Call the regular prefillStep with the new seqID
+      if (typeof message.content === "string") {
+        // Support long system prompt
+        if (message.role === "system") {
+          await this.prefillStep(
+            message.content,
+            Role.tool,
+            inp_role_str,
+            genConfig,
+            newSeqId,
+          );
+        } else {
+          throw Error(
+            "Invalid role in prefix message: " +
+              message.role +
+              ", expected 'system'.",
+          );
+        }
+      } else {
+        throw Error(
+          "Invalid content in prefix message, does not support image input.",
+        );
+      }
+    }
+    this.conversation.reset();
+  }
+
   async decodeStep(genConfig?: GenerationConfig): Promise<void> {
     if (this.stopTriggered) {
       throw Error("Cannot run decode when stopped");
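A minimal usage sketch for the new prefillConvSequence path, assuming direct access to a constructed and loaded LLMChatPipeline instance; the names pipeline, longSystemPrompt, and userMessage are illustrative and not part of this diff. The long system prompt is prefilled once into its own KV-cache sequence, and later prefills that begin with the same tokens reuse it through the matching-and-forking logic above.

// Illustrative only -- assumes `pipeline` is an LLMChatPipeline with weights loaded.
const prefixMessages: ChatCompletionMessageParam[] = [
  { role: "system", content: longSystemPrompt },
];
await pipeline.prefillConvSequence(prefixMessages); // cache the prefix under a fresh sequence ID
await pipeline.prefillStep(userMessage, Role.user); // reuses the cached prefix when it matches
while (!pipeline.stopped()) {
  await pipeline.decodeStep();
}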
@@ -869,13 +988,15 @@
    *
    * @param inputData data to embed and forward
    * @param inputDataLen length of this inputData, should smaller than prefill chunk size.
+   * @param seqID sequence ID of the input data in KV cache for prefix caching
    * @returns The logits returned by this forward as tvmjs.NDArray on GPU.
    *
    * @note Precondition: inputData's data length is smaller than prefill chunk size
    */
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
+    seqID = 0,
   ): Promise<tvmjs.NDArray> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
@@ -913,7 +1034,8 @@
 
     // 3. Forward the concatenated embeddings
     const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
-    const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+    // set seqIdsTuple to be childID
+    const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
     this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
     let retValue;
     if (inputDataLen > 1) {