diff --git a/go.mod b/go.mod index 445d863a..9ed37c18 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jellydator/ttlcache/v3 v3.4.0 - github.com/llm-d/llm-d-kv-cache-manager v0.3.2 + github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2 github.com/onsi/ginkgo/v2 v2.27.2 github.com/onsi/gomega v1.38.2 github.com/openai/openai-go v1.12.0 diff --git a/go.sum b/go.sum index d517af9e..8dcb843b 100644 --- a/go.sum +++ b/go.sum @@ -181,8 +181,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache-manager v0.3.2 h1:omSTXtuII3ol37CaoI9h+2VxE0m8EoeVOor+CkQh99I= -github.com/llm-d/llm-d-kv-cache-manager v0.3.2/go.mod h1:q6u7LnzMxNcHHb5/LRdHNNeZzzGMSENFSP1NGfsJEmA= +github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2 h1:l2Sm8W6SRg4TAme4RsndwZ++5+4aQvDI4vnf8TKrhww= +github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2/go.mod h1:ZlK7MCuz5D/weLeHyNKEmVF/eJZDyYn3XyRowTihq9o= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index 091b6f8a..ee42646f 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -41,14 +41,18 @@ var _ framework.Scorer = &PrecisePrefixCacheScorer{} // a new instance of the PrefixCacheTrackingPlugin. func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { + indexerConfig, err := kvcache.NewDefaultConfig() + if err != nil { + return nil, fmt.Errorf("failed to create default indexer config: %w", err) + } parameters := PrecisePrefixCachePluginConfig{ - IndexerConfig: kvcache.NewDefaultConfig(), + IndexerConfig: indexerConfig, KVEventsConfig: kvevents.DefaultConfig(), } // read hugging face token from environment variable if set if token := os.Getenv("HF_TOKEN"); token != "" { - parameters.IndexerConfig.TokenizersPoolConfig.HuggingFaceToken = token + parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token } if rawParameters != nil { @@ -87,15 +91,9 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex()) pool.Start(ctx) - chatTemplateRenderer := preprocessing.NewChatTemplatingProcessor() - if err := chatTemplateRenderer.Initialize(); err != nil { - return nil, fmt.Errorf("failed to initialize chat templating processor: %w", err) - } - return &PrecisePrefixCacheScorer{ - typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType}, - kvCacheIndexer: kvCacheIndexer, - chatTemplateRenderer: chatTemplateRenderer, + typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType}, + kvCacheIndexer: kvCacheIndexer, }, nil } @@ -105,9 +103,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr // state, and the `kvevents.Pool` to subscribe to KV-cache events // to keep the internal KV-cache index state up-to-date. type PrecisePrefixCacheScorer struct { - typedName plugins.TypedName - kvCacheIndexer *kvcache.Indexer - chatTemplateRenderer *preprocessing.ChatTemplatingProcessor + typedName plugins.TypedName + kvCacheIndexer *kvcache.Indexer } // TypedName returns the typed name of the plugin. @@ -132,13 +129,13 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat } // Extract the flattened prompt from the request - prompt, err := s.extractPrompt(ctx, request) + renderReq, prompt, err := s.extractRequest(ctx, request) if err != nil { logger.Error(err, "Failed to extract prompt from request") return nil } - scores, err := s.kvCacheIndexer.GetPodScores(ctx, prompt, request.TargetModel, nil) + scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, prompt, request.TargetModel, nil) if err != nil { logger.Error(err, "Failed to get pod scores") return nil @@ -158,10 +155,10 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) } -// extractPrompt extracts the flattened prompt from the request. +// extractRequest extracts the flattened prompt from the request. // For chat completions, it renders the messages using the model's chat template. // For regular completions, it uses the prompt directly. -func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *types.LLMRequest) (string, error) { +func (s *PrecisePrefixCacheScorer) extractRequest(ctx context.Context, request *types.LLMRequest) (*preprocessing.RenderJinjaTemplateRequest, string, error) { traceLogger := log.FromContext(ctx).V(logutil.TRACE).WithName(s.typedName.String()) // The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions. @@ -194,47 +191,15 @@ func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *t }) } - // Fetch the chat template from the model - fetchReq := preprocessing.FetchChatTemplateRequest{ - Model: request.TargetModel, - } - - chatTemplate, chatTemplateKWArgs, err := s.chatTemplateRenderer.FetchChatTemplate(ctx, fetchReq) - if err != nil { - return "", fmt.Errorf("failed to fetch chat template: %w", err) - } - - traceLogger.Info("Chat template fetched", - "model", request.TargetModel, - "templateLength", len(chatTemplate), - "hasKwargs", len(chatTemplateKWArgs) > 0) - - // Set the fetched template in the render request - renderReq.ChatTemplate = chatTemplate - renderReq.ChatTemplateKWArgs = chatTemplateKWArgs - - // Render the template to get flattened prompt - resp, err := s.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq) - if err != nil { - return "", fmt.Errorf("failed to render chat template: %w", err) - } - - if len(resp.RenderedChats) == 0 { - return "", errors.New("no rendered chat returned from template rendering") - } - - prompt := resp.RenderedChats[0] - traceLogger.Info("Chat template rendered successfully", - "promptLength", len(prompt)) - return prompt, nil + return renderReq, "", nil } // For regular completions, use the prompt directly if request.Body != nil && request.Body.Completions != nil { prompt := request.Body.Completions.Prompt traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt)) - return prompt, nil + return nil, prompt, nil } - return "", errors.New("no valid prompt found in request") + return nil, "", errors.New("no valid prompt found in request") } diff --git a/pkg/plugins/scorer/utils.go b/pkg/plugins/scorer/utils.go index 74d03eef..533eace4 100644 --- a/pkg/plugins/scorer/utils.go +++ b/pkg/plugins/scorer/utils.go @@ -1,6 +1,10 @@ package scorer -import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +import ( + "math" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) // podToKey is a function type that converts a Pod to a string key. // It returns the key and a boolean indicating success. @@ -11,7 +15,7 @@ type podToKeyFunc func(pod types.Pod) (string, bool) // a pod to a key, and a map of scores indexed by those keys. It returns a map // of pods to their normalized scores. func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, - scores map[string]int) map[types.Pod]float64 { + scores map[string]float64) map[types.Pod]float64 { scoredPods := make(map[types.Pod]float64) minScore, maxScore := getMinMax(scores) @@ -36,9 +40,9 @@ func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc return scoredPods } -func getMinMax(scores map[string]int) (int, int) { - minScore := int(^uint(0) >> 1) // max int - maxScore := -1 +func getMinMax(scores map[string]float64) (float64, float64) { + minScore := math.MaxFloat64 + maxScore := float64(-1) for _, score := range scores { if score < minScore {