diff --git a/deploy/config/sim-epp-kvcache-config.yaml b/deploy/config/sim-epp-kvcache-config.yaml index 1c58b299..566e9243 100644 --- a/deploy/config/sim-epp-kvcache-config.yaml +++ b/deploy/config/sim-epp-kvcache-config.yaml @@ -15,7 +15,8 @@ plugins: blockSize: 16 # must match vLLM block size if not default (16) hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods tokenizersPoolConfig: - tokenizersCacheDir: "/cache/tokenizers" + hf: + tokenizersCacheDir: "/cache/tokenizers" kvBlockIndexConfig: enableMetrics: false # enable kv-block index metrics (prometheus) metricsLoggingInterval: 6000000000 # log kv-block metrics as well (1m in nanoseconds) diff --git a/docs/architecture.md b/docs/architecture.md index 14992c5e..51da234d 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -330,7 +330,8 @@ plugins: blockSize: 64 hashSeed: "12345" tokenizersPoolConfig: - huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable + hf: + huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable kvBlockIndexConfig: enableMetrics: true ``` @@ -359,8 +360,9 @@ plugins: enableMetrics: true tokenizersPoolConfig: workersCount: 8 - huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable - tokenizersCacheDir: /tmp/tokenizers + hf: + huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable + tokenizersCacheDir: /tmp/tokenizers ``` --- diff --git a/go.mod b/go.mod index 445d863a..9ed37c18 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jellydator/ttlcache/v3 v3.4.0 - github.com/llm-d/llm-d-kv-cache-manager v0.3.2 + github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2 github.com/onsi/ginkgo/v2 v2.27.2 github.com/onsi/gomega v1.38.2 github.com/openai/openai-go v1.12.0 diff --git a/go.sum b/go.sum index d517af9e..8dcb843b 100644 --- a/go.sum +++ b/go.sum @@ -181,8 +181,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache-manager v0.3.2 h1:omSTXtuII3ol37CaoI9h+2VxE0m8EoeVOor+CkQh99I= -github.com/llm-d/llm-d-kv-cache-manager v0.3.2/go.mod h1:q6u7LnzMxNcHHb5/LRdHNNeZzzGMSENFSP1NGfsJEmA= +github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2 h1:l2Sm8W6SRg4TAme4RsndwZ++5+4aQvDI4vnf8TKrhww= +github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2/go.mod h1:ZlK7MCuz5D/weLeHyNKEmVF/eJZDyYn3XyRowTihq9o= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index 091b6f8a..4d35912a 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -41,14 +41,23 @@ var _ framework.Scorer = &PrecisePrefixCacheScorer{} // a new instance of the PrefixCacheTrackingPlugin. 
 func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
+
+	indexerConfig, err := kvcache.NewDefaultConfig()
+	if err != nil {
+		return nil, fmt.Errorf("failed to initialize indexer config: %w", err)
+	}
+
 	parameters := PrecisePrefixCachePluginConfig{
-		IndexerConfig:  kvcache.NewDefaultConfig(),
+		IndexerConfig:  indexerConfig,
 		KVEventsConfig: kvevents.DefaultConfig(),
 	}
 
 	// read hugging face token from environment variable if set
-	if token := os.Getenv("HF_TOKEN"); token != "" {
-		parameters.IndexerConfig.TokenizersPoolConfig.HuggingFaceToken = token
+	if token := os.Getenv("HF_TOKEN"); token != "" &&
+		parameters.IndexerConfig != nil &&
+		parameters.IndexerConfig.TokenizersPoolConfig != nil &&
+		parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig != nil {
+		parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token
 	}
 
 	if rawParameters != nil {
@@ -93,9 +102,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
 	}
 
 	return &PrecisePrefixCacheScorer{
-		typedName:            plugins.TypedName{Type: PrecisePrefixCachePluginType},
-		kvCacheIndexer:       kvCacheIndexer,
-		chatTemplateRenderer: chatTemplateRenderer,
+		typedName:      plugins.TypedName{Type: PrecisePrefixCachePluginType},
+		kvCacheIndexer: kvCacheIndexer,
 	}, nil
 }
 
@@ -105,9 +113,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
 // state, and the `kvevents.Pool` to subscribe to KV-cache events
 // to keep the internal KV-cache index state up-to-date.
 type PrecisePrefixCacheScorer struct {
-	typedName            plugins.TypedName
-	kvCacheIndexer       *kvcache.Indexer
-	chatTemplateRenderer *preprocessing.ChatTemplatingProcessor
+	typedName      plugins.TypedName
+	kvCacheIndexer *kvcache.Indexer
 }
 
 // TypedName returns the typed name of the plugin.
@@ -125,26 +132,20 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor
 // The returned scores are normalized to a range of 0-1.
 func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
 	logger := log.FromContext(ctx).WithName(s.typedName.String())
+	debugLogger := logger.V(logutil.DEBUG)
 
 	if request == nil {
-		logger.V(logutil.DEBUG).Info("Request is nil, skipping scoring")
-		return nil
-	}
-
-	// Extract the flattened prompt from the request
-	prompt, err := s.extractPrompt(ctx, request)
-	if err != nil {
-		logger.Error(err, "Failed to extract prompt from request")
+		debugLogger.Info("Request is nil, skipping scoring")
 		return nil
 	}
 
-	scores, err := s.kvCacheIndexer.GetPodScores(ctx, prompt, request.TargetModel, nil)
+	// Compute per-pod scores from the request
+	scores, err := s.getScores(ctx, request)
 	if err != nil {
-		logger.Error(err, "Failed to get pod scores")
+		logger.Error(err, "Failed to get pod scores for request")
 		return nil
 	}
-
-	logger.V(logutil.DEBUG).Info("Got pod scores", "scores", scores)
+	debugLogger.Info("Got pod scores", "scores", scores)
 
 	podToKey := func(pod types.Pod) (string, bool) {
 		metricsPod := pod.GetPod()
@@ -161,8 +162,14 @@
-// extractPrompt extracts the flattened prompt from the request.
-// For chat completions, it renders the messages using the model's chat template.
-// For regular completions, it uses the prompt directly.
+// getScores queries the KV-cache indexer for per-pod scores for the request.
+// For chat completions, it passes the chat messages so the indexer can render them with the model's chat template.
+// For regular completions, it passes the prompt directly.
-func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *types.LLMRequest) (string, error) { - traceLogger := log.FromContext(ctx).V(logutil.TRACE).WithName(s.typedName.String()) +func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) { + logger := log.FromContext(ctx).WithName(s.typedName.String()) + traceLogger := logger.V(logutil.TRACE) + + traceLogger.Info("Getting scores", + "target_model", request.TargetModel, + "has_chat_completions", request.Body != nil && request.Body.ChatCompletions != nil, + "has_completions", request.Body != nil && request.Body.Completions != nil) // The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions. // If an unexpected dual payload slips through (parser regression/new client), log it and use chat semantics. @@ -170,11 +177,7 @@ func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *t if request.Body.Completions != nil { traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions") } - traceLogger.Info("Processing chat completion request", - "messages_count", len(request.Body.ChatCompletions.Messages), - "target_model", request.TargetModel) - // Create render request renderReq := &preprocessing.RenderJinjaTemplateRequest{ Conversations: make([]preprocessing.ChatMessage, 0), Tools: request.Body.ChatCompletions.Tools, @@ -194,47 +197,30 @@ func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *t }) } - // Fetch the chat template from the model - fetchReq := preprocessing.FetchChatTemplateRequest{ - Model: request.TargetModel, - } - - chatTemplate, chatTemplateKWArgs, err := s.chatTemplateRenderer.FetchChatTemplate(ctx, fetchReq) - if err != nil { - return "", fmt.Errorf("failed to fetch chat template: %w", err) - } - - traceLogger.Info("Chat template fetched", - "model", request.TargetModel, - "templateLength", len(chatTemplate), - "hasKwargs", len(chatTemplateKWArgs) > 0) - - // Set the fetched template in the render request - renderReq.ChatTemplate = chatTemplate - renderReq.ChatTemplateKWArgs = chatTemplateKWArgs + traceLogger.Info("Processing chat completion request", + "messages_count", len(renderReq.Conversations), + "tools_count", len(renderReq.Tools), + "documents_count", len(renderReq.Documents), + "target_model", request.TargetModel) - // Render the template to get flattened prompt - resp, err := s.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq) + scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, nil) if err != nil { - return "", fmt.Errorf("failed to render chat template: %w", err) + return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err) } - - if len(resp.RenderedChats) == 0 { - return "", errors.New("no rendered chat returned from template rendering") - } - - prompt := resp.RenderedChats[0] - traceLogger.Info("Chat template rendered successfully", - "promptLength", len(prompt)) - return prompt, nil + return scores, nil } // For regular completions, use the prompt directly if request.Body != nil && request.Body.Completions != nil { prompt := request.Body.Completions.Prompt traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt)) - return prompt, nil + + scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil) + if err != nil { + return nil, fmt.Errorf("failed to get pod scores for completions: %w", 
err) + } + return scores, nil } - return "", errors.New("no valid prompt found in request") + return nil, errors.New("no valid input found in request") } diff --git a/pkg/plugins/scorer/utils.go b/pkg/plugins/scorer/utils.go index 74d03eef..31a721b7 100644 --- a/pkg/plugins/scorer/utils.go +++ b/pkg/plugins/scorer/utils.go @@ -1,6 +1,10 @@ package scorer -import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +import ( + "math" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) // podToKey is a function type that converts a Pod to a string key. // It returns the key and a boolean indicating success. @@ -11,7 +15,7 @@ type podToKeyFunc func(pod types.Pod) (string, bool) // a pod to a key, and a map of scores indexed by those keys. It returns a map // of pods to their normalized scores. func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, - scores map[string]int) map[types.Pod]float64 { + scores map[string]float64) map[types.Pod]float64 { scoredPods := make(map[types.Pod]float64) minScore, maxScore := getMinMax(scores) @@ -27,7 +31,7 @@ func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc continue } - scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore) + scoredPods[pod] = (score - minScore) / (maxScore - minScore) } else { scoredPods[pod] = 0.0 } @@ -36,9 +40,9 @@ func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc return scoredPods } -func getMinMax(scores map[string]int) (int, int) { - minScore := int(^uint(0) >> 1) // max int - maxScore := -1 +func getMinMax(scores map[string]float64) (float64, float64) { + minScore := math.MaxFloat64 + maxScore := math.Inf(-1) for _, score := range scores { if score < minScore { diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 02eb2d0d..4fcf3ccc 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -341,7 +341,12 @@ func runCompletion(prompt string, theModel openai.CompletionNewParamsModel) (str Model: theModel, } - resp, err := openaiclient.Completions.New(testConfig.Context, completionParams, option.WithResponseInto(&httpResp)) + ginkgo.By(fmt.Sprintf("Sending Completion Request: (port %s) %#v", port, completionParams)) + + resp, err := openaiclient.Completions.New(testConfig.Context, completionParams, option.WithResponseInto(&httpResp), option.WithRequestTimeout(readyTimeout)) + + ginkgo.By(fmt.Sprintf("Verifying Completion Response: %#v", resp)) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) gomega.Expect(resp.Choices).Should(gomega.HaveLen(1)) gomega.Expect(resp.Choices[0].FinishReason).Should(gomega.Equal(openai.CompletionChoiceFinishReasonStop)) @@ -445,7 +450,8 @@ plugins: blockSize: 16 # must match vLLM block size if not default (16) hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods tokenizersPoolConfig: - tokenizersCacheDir: "/cache/tokenizers" + hf: + tokenizersCacheDir: "/cache/tokenizers" kvBlockIndexConfig: enableMetrics: false # enable kv-block index metrics (prometheus) metricsLoggingInterval: 6000000000 # log kv-block metrics as well (1m in nanoseconds) diff --git a/test/e2e/utils_test.go b/test/e2e/utils_test.go index 756f2925..edfe32e4 100644 --- a/test/e2e/utils_test.go +++ b/test/e2e/utils_test.go @@ -46,6 +46,8 @@ func scaleDeployment(objects []string, increment int) { // getModelServerPods Returns the list of Prefill and Decode vLLM pods separately func getModelServerPods(podLabels, prefillLabels, decodeLabels map[string]string) ([]string, 
[]string) {
+	ginkgo.By("Getting model server pods")
+
 	pods := getPods(podLabels)
 
 	prefillValidator, err := apilabels.ValidatedSelectorFromSet(prefillLabels)
@@ -98,17 +100,28 @@ func getPods(labels map[string]string) []corev1.Pod {
 }
 
 func podsInDeploymentsReady(objects []string) {
-	var deployment appsv1.Deployment
-	helper := func(deploymentName string) bool {
+	isDeploymentReady := func(deploymentName string) bool {
+		var deployment appsv1.Deployment
 		err := testConfig.K8sClient.Get(testConfig.Context, types.NamespacedName{Namespace: nsName, Name: deploymentName}, &deployment)
+		// deployment is the zero value when Get fails, so guard the
+		// *deployment.Spec.Replicas dereference before logging.
+		replicas := "<nil>"
+		if deployment.Spec.Replicas != nil {
+			replicas = fmt.Sprintf("%d", *deployment.Spec.Replicas)
+		}
+		ginkgo.By(fmt.Sprintf("Waiting for deployment %q to be ready (err: %v): replicas=%s, status=%#v", deploymentName, err, replicas, deployment.Status))
 		return err == nil && *deployment.Spec.Replicas == deployment.Status.Replicas &&
 			deployment.Status.Replicas == deployment.Status.ReadyReplicas
 	}
+
 	for _, kindAndName := range objects {
 		split := strings.Split(kindAndName, "/")
 		if strings.ToLower(split[0]) == deploymentKind {
-			ginkgo.By(fmt.Sprintf("Waiting for pods of %s to be ready", split[1]))
-			gomega.Eventually(helper, readyTimeout, interval).WithArguments(split[1]).Should(gomega.BeTrue())
+			gomega.Eventually(isDeploymentReady).
+				WithArguments(split[1]).
+				WithPolling(interval).
+				WithTimeout(readyTimeout).
+				Should(gomega.BeTrue())
 		}
 	}
 }
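
Note on the v0.4.0 scoring API used in precise_prefix_cache.go: GetPodScores now accepts a chat render request and a raw prompt as separate arguments, and getScores supplies exactly one of them depending on the request type. The sketch below mirrors the two call shapes from the diff; the wrapper function is hypothetical, and the import paths are assumptions based on the package names (kvcache, preprocessing) referenced by the scorer.

package scorer

import (
	"context"

	// Assumed import paths; precise_prefix_cache.go refers to these
	// packages as `kvcache` and `preprocessing`.
	"github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache"
	preprocessing "github.com/llm-d/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions"
)

// podScores mirrors the two GetPodScores call shapes used in this diff:
// exactly one of renderReq or prompt should be set.
func podScores(ctx context.Context, indexer *kvcache.Indexer,
	renderReq *preprocessing.RenderJinjaTemplateRequest,
	prompt, model string) (map[string]float64, error) {
	if renderReq != nil {
		// Chat completions: pass the render request and an empty prompt;
		// the indexer renders the chat template itself.
		return indexer.GetPodScores(ctx, renderReq, "", model, nil)
	}
	// Plain completions: no render request; the prompt is used directly.
	return indexer.GetPodScores(ctx, nil, prompt, model, nil)
}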
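
The utils.go change switches score normalization from int to float64; the math is plain min-max scaling. Below is a self-contained, runnable sketch of the same calculation with made-up pod scores; the maxScore > minScore guard for the all-equal case is our addition, since the hunk does not show how the surrounding function handles it.

package main

import (
	"fmt"
	"math"
)

// minMax returns the smallest and largest values in scores,
// mirroring getMinMax in pkg/plugins/scorer/utils.go.
func minMax(scores map[string]float64) (float64, float64) {
	minScore, maxScore := math.MaxFloat64, math.Inf(-1)
	for _, s := range scores {
		if s < minScore {
			minScore = s
		}
		if s > maxScore {
			maxScore = s
		}
	}
	return minScore, maxScore
}

func main() {
	// Hypothetical per-pod KV-cache hit scores.
	scores := map[string]float64{"pod-a": 0.25, "pod-b": 1.5, "pod-c": 3.0}
	minScore, maxScore := minMax(scores)
	for pod, s := range scores {
		normalized := 0.0
		if maxScore > minScore { // guard the degenerate all-equal case
			normalized = (s - minScore) / (maxScore - minScore)
		}
		fmt.Printf("%s: %.3f\n", pod, normalized) // pod-b prints 0.455
	}
}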