3 changes: 2 additions & 1 deletion deploy/config/sim-epp-kvcache-config.yaml
@@ -15,7 +15,8 @@ plugins:
blockSize: 16 # must match vLLM block size if not default (16)
hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods
tokenizersPoolConfig:
tokenizersCacheDir: "/cache/tokenizers"
hf:
tokenizersCacheDir: "/cache/tokenizers"
kvBlockIndexConfig:
enableMetrics: false # enable kv-block index metrics (prometheus)
metricsLoggingInterval: 6000000000 # log kv-block metrics as well (1m in nanoseconds)
8 changes: 5 additions & 3 deletions docs/architecture.md
@@ -330,7 +330,8 @@ plugins:
blockSize: 64
hashSeed: "12345"
tokenizersPoolConfig:
huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
hf:
huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
kvBlockIndexConfig:
enableMetrics: true
```
@@ -359,8 +360,9 @@ plugins:
enableMetrics: true
tokenizersPoolConfig:
workersCount: 8
huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
tokenizersCacheDir: /tmp/tokenizers
hf:
huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
tokenizersCacheDir: /tmp/tokenizers
```

---
2 changes: 1 addition & 1 deletion go.mod
@@ -10,7 +10,7 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jellydator/ttlcache/v3 v3.4.0
github.com/llm-d/llm-d-kv-cache-manager v0.3.2
github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2
github.com/onsi/ginkgo/v2 v2.27.2
github.com/onsi/gomega v1.38.2
github.com/openai/openai-go v1.12.0
4 changes: 2 additions & 2 deletions go.sum
@@ -181,8 +181,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/llm-d/llm-d-kv-cache-manager v0.3.2 h1:omSTXtuII3ol37CaoI9h+2VxE0m8EoeVOor+CkQh99I=
github.com/llm-d/llm-d-kv-cache-manager v0.3.2/go.mod h1:q6u7LnzMxNcHHb5/LRdHNNeZzzGMSENFSP1NGfsJEmA=
github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2 h1:l2Sm8W6SRg4TAme4RsndwZ++5+4aQvDI4vnf8TKrhww=
github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2/go.mod h1:ZlK7MCuz5D/weLeHyNKEmVF/eJZDyYn3XyRowTihq9o=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo=
105 changes: 46 additions & 59 deletions pkg/plugins/scorer/precise_prefix_cache.go
@@ -41,14 +41,23 @@ var _ framework.Scorer = &PrecisePrefixCacheScorer{}
// a new instance of the PrefixCacheTrackingPlugin.
func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
handle plugins.Handle) (plugins.Plugin, error) {

indexerConfig, err := kvcache.NewDefaultConfig()
if err != nil {
return nil, fmt.Errorf("failed to initialize indexer config: %w", err)
}

parameters := PrecisePrefixCachePluginConfig{
IndexerConfig: kvcache.NewDefaultConfig(),
IndexerConfig: indexerConfig,
KVEventsConfig: kvevents.DefaultConfig(),
}

// read hugging face token from environment variable if set
if token := os.Getenv("HF_TOKEN"); token != "" {
parameters.IndexerConfig.TokenizersPoolConfig.HuggingFaceToken = token
if token := os.Getenv("HF_TOKEN"); token != "" &&
parameters.IndexerConfig != nil &&
parameters.IndexerConfig.TokenizersPoolConfig != nil &&
parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig != nil {
parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token
}
Comment on lines +56 to 61
Collaborator

These lines look like internal implementation details of the kvcache manager code.
Can we move this to the kvcache repo?
I'd expect the call indexerConfig, err := kvcache.NewDefaultConfig() in L45 to initialize that internally.

Member

The kvcache library defines a configuration with default values. This code piece is that of a user, permitting configuration through the env var HF_TOKEN. It is arguable that this is a special env var that is widely accepted, but in general the kvcache library attempts to contain all configuration in the referenced structure, leaving customized UX to the users.

Collaborator

Right, it's a user-defined env var (with the user's token); HF_TOKEN as an env var is very acceptable.

The question I was trying to answer is why we read that env var here and not in the kvcache code.
This part does not look natural: the scorer factory function writes to an internal config of the kvcache indexer parameters.
I was expecting the os.Getenv call to live inside kvcache.NewDefaultConfig().

Member Author

Since this discussion is applicable to the current setup, should we discuss it separately?

Member Author

I agree that the kv cache manager library can do this, since the user can still opt out of the default env-var injection by not using the NewDefaultConfig function if they want different defaults.
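
A minimal sketch of what that could look like inside the library, assuming the config field names shown in this PR; the type definitions below are placeholders, not the library's real ones:

```go
package kvcache

import "os"

// Placeholder definitions; the real types live in llm-d-kv-cache-manager.
type HFTokenizerConfig struct{ HuggingFaceToken string }
type TokenizersPoolConfig struct{ HFTokenizerConfig *HFTokenizerConfig }
type Config struct{ TokenizersPoolConfig *TokenizersPoolConfig }

// NewDefaultConfig sketches how the library itself could read HF_TOKEN, so
// callers such as the scorer factory no longer mutate internal config fields.
func NewDefaultConfig() (*Config, error) {
	cfg := &Config{
		TokenizersPoolConfig: &TokenizersPoolConfig{
			HFTokenizerConfig: &HFTokenizerConfig{},
		},
	}

	// Users who want different defaults can skip NewDefaultConfig entirely
	// and build the config themselves, opting out of this injection.
	if token := os.Getenv("HF_TOKEN"); token != "" {
		cfg.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token
	}

	return cfg, nil
}
```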

Member

Let's track this in a new issue; I think this should not be a blocker.

Collaborator

Not a blocker 👍


if rawParameters != nil {
@@ -93,9 +102,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
}

return &PrecisePrefixCacheScorer{
typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
chatTemplateRenderer: chatTemplateRenderer,
typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
}, nil
}

@@ -105,9 +113,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
// state, and the `kvevents.Pool` to subscribe to KV-cache events
// to keep the internal KV-cache index state up-to-date.
type PrecisePrefixCacheScorer struct {
typedName plugins.TypedName
kvCacheIndexer *kvcache.Indexer
chatTemplateRenderer *preprocessing.ChatTemplatingProcessor
typedName plugins.TypedName
kvCacheIndexer *kvcache.Indexer
}

// TypedName returns the typed name of the plugin.
@@ -125,26 +132,20 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor
// The returned scores are normalized to a range of 0-1.
func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
logger := log.FromContext(ctx).WithName(s.typedName.String())
debugLogger := logger.V(logutil.DEBUG)

if request == nil {
logger.V(logutil.DEBUG).Info("Request is nil, skipping scoring")
return nil
}

// Extract the flattened prompt from the request
prompt, err := s.extractPrompt(ctx, request)
if err != nil {
logger.Error(err, "Failed to extract prompt from request")
debugLogger.Info("Request is nil, skipping scoring")
return nil
}

scores, err := s.kvCacheIndexer.GetPodScores(ctx, prompt, request.TargetModel, nil)
// Compute pod scores for the request
scores, err := s.getScores(ctx, request)
if err != nil {
logger.Error(err, "Failed to get pod scores")
logger.Error(err, "Failed to extract scores from request")
return nil
}

logger.V(logutil.DEBUG).Info("Got pod scores", "scores", scores)
debugLogger.Info("Got pod scores", "scores", scores)

podToKey := func(pod types.Pod) (string, bool) {
metricsPod := pod.GetPod()
@@ -161,20 +162,23 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat
// extractPrompt extracts the flattened prompt from the request.
// For chat completions, it renders the messages using the model's chat template.
// For regular completions, it uses the prompt directly.
func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *types.LLMRequest) (string, error) {
traceLogger := log.FromContext(ctx).V(logutil.TRACE).WithName(s.typedName.String())
func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) {
logger := log.FromContext(ctx).WithName(s.typedName.String())
debugLogger := logger.V(logutil.DEBUG)
traceLogger := logger.V(logutil.TRACE)

debugLogger.Info("Getting scores",
"target_model", request.TargetModel,
"has_chat_completions", request.Body != nil && request.Body.ChatCompletions != nil,
"has_completions", request.Body != nil && request.Body.Completions != nil)

// The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions.
// If an unexpected dual payload slips through (parser regression/new client), log it and use chat semantics.
if request.Body != nil && request.Body.ChatCompletions != nil {
if request.Body.Completions != nil {
traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions")
}
traceLogger.Info("Processing chat completion request",
"messages_count", len(request.Body.ChatCompletions.Messages),
"target_model", request.TargetModel)

// Create render request
renderReq := &preprocessing.RenderJinjaTemplateRequest{
Conversations: make([]preprocessing.ChatMessage, 0),
Tools: request.Body.ChatCompletions.Tools,
@@ -194,47 +198,30 @@ func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *t
})
}

// Fetch the chat template from the model
fetchReq := preprocessing.FetchChatTemplateRequest{
Model: request.TargetModel,
}

chatTemplate, chatTemplateKWArgs, err := s.chatTemplateRenderer.FetchChatTemplate(ctx, fetchReq)
if err != nil {
return "", fmt.Errorf("failed to fetch chat template: %w", err)
}

traceLogger.Info("Chat template fetched",
"model", request.TargetModel,
"templateLength", len(chatTemplate),
"hasKwargs", len(chatTemplateKWArgs) > 0)

// Set the fetched template in the render request
renderReq.ChatTemplate = chatTemplate
renderReq.ChatTemplateKWArgs = chatTemplateKWArgs
traceLogger.Info("Processing chat completion request",
"messages_count", len(renderReq.Conversations),
"tools_count", len(renderReq.Tools),
"documents_count", len(renderReq.Documents),
"target_model", request.TargetModel)

// Render the template to get flattened prompt
resp, err := s.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq)
scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, nil)
if err != nil {
return "", fmt.Errorf("failed to render chat template: %w", err)
return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err)
}

if len(resp.RenderedChats) == 0 {
return "", errors.New("no rendered chat returned from template rendering")
}

prompt := resp.RenderedChats[0]
traceLogger.Info("Chat template rendered successfully",
"promptLength", len(prompt))
return prompt, nil
return scores, nil
}

// For regular completions, use the prompt directly
if request.Body != nil && request.Body.Completions != nil {
prompt := request.Body.Completions.Prompt
traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt))
return prompt, nil

scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil)
if err != nil {
return nil, fmt.Errorf("failed to get pod scores for completions: %w", err)
}
return scores, nil
}

return "", errors.New("no valid prompt found in request")
return nil, errors.New("no valid input found in request")
}
16 changes: 10 additions & 6 deletions pkg/plugins/scorer/utils.go
@@ -1,6 +1,10 @@
package scorer

import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
import (
"math"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// podToKey is a function type that converts a Pod to a string key.
// It returns the key and a boolean indicating success.
@@ -11,7 +15,7 @@ type podToKeyFunc func(pod types.Pod) (string, bool)
// a pod to a key, and a map of scores indexed by those keys. It returns a map
// of pods to their normalized scores.
func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc,
scores map[string]int) map[types.Pod]float64 {
scores map[string]float64) map[types.Pod]float64 {
scoredPods := make(map[types.Pod]float64)
minScore, maxScore := getMinMax(scores)

@@ -27,7 +31,7 @@ func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc
continue
}

scoredPods[pod] = float64(score-minScore) / float64(maxScore-minScore)
scoredPods[pod] = (score - minScore) / (maxScore - minScore)
} else {
scoredPods[pod] = 0.0
}
@@ -36,9 +40,9 @@ func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc
return scoredPods
}

func getMinMax(scores map[string]int) (int, int) {
minScore := int(^uint(0) >> 1) // max int
maxScore := -1
func getMinMax(scores map[string]float64) (float64, float64) {
minScore := math.MaxFloat64
maxScore := math.Inf(-1)

for _, score := range scores {
if score < minScore {
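For reference, a minimal standalone sketch of the min-max normalization these helpers now perform over float64 scores; the pod names, values, and the tie handling when all scores are equal are illustrative assumptions, not the scorer's exact behavior:

```go
package main

import (
	"fmt"
	"math"
)

// normalize min-max scales raw pod scores into [0, 1], mirroring
// indexedScoresToNormalizedScoredPods over a plain map for illustration.
func normalize(scores map[string]float64) map[string]float64 {
	minScore, maxScore := math.MaxFloat64, math.Inf(-1)
	for _, s := range scores {
		if s < minScore {
			minScore = s
		}
		if s > maxScore {
			maxScore = s
		}
	}

	normalized := make(map[string]float64, len(scores))
	for pod, s := range scores {
		if maxScore == minScore {
			// All pods scored the same; avoid division by zero (assumption).
			normalized[pod] = 1.0
			continue
		}
		normalized[pod] = (s - minScore) / (maxScore - minScore)
	}
	return normalized
}

func main() {
	raw := map[string]float64{"pod-a": 2, "pod-b": 5, "pod-c": 8}
	fmt.Println(normalize(raw)) // map[pod-a:0 pod-b:0.5 pod-c:1]
}
```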
10 changes: 8 additions & 2 deletions test/e2e/e2e_test.go
@@ -341,7 +341,12 @@ func runCompletion(prompt string, theModel openai.CompletionNewParamsModel) (str
Model: theModel,
}

resp, err := openaiclient.Completions.New(testConfig.Context, completionParams, option.WithResponseInto(&httpResp))
ginkgo.By(fmt.Sprintf("Sending Completion Request: (port %s) %#v", port, completionParams))

resp, err := openaiclient.Completions.New(testConfig.Context, completionParams, option.WithResponseInto(&httpResp), option.WithRequestTimeout(readyTimeout))

ginkgo.By(fmt.Sprintf("Verifying Completion Response: %#v", resp))

gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
gomega.Expect(resp.Choices).Should(gomega.HaveLen(1))
gomega.Expect(resp.Choices[0].FinishReason).Should(gomega.Equal(openai.CompletionChoiceFinishReasonStop))
@@ -445,7 +450,8 @@ plugins:
blockSize: 16 # must match vLLM block size if not default (16)
hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods
tokenizersPoolConfig:
tokenizersCacheDir: "/cache/tokenizers"
hf:
tokenizersCacheDir: "/cache/tokenizers"
kvBlockIndexConfig:
enableMetrics: false # enable kv-block index metrics (prometheus)
metricsLoggingInterval: 6000000000 # log kv-block metrics as well (1m in nanoseconds)
15 changes: 11 additions & 4 deletions test/e2e/utils_test.go
@@ -46,6 +46,8 @@ func scaleDeployment(objects []string, increment int) {

// getModelServerPods Returns the list of Prefill and Decode vLLM pods separately
func getModelServerPods(podLabels, prefillLabels, decodeLabels map[string]string) ([]string, []string) {
ginkgo.By("Getting Model server pods")

pods := getPods(podLabels)

prefillValidator, err := apilabels.ValidatedSelectorFromSet(prefillLabels)
@@ -98,17 +100,22 @@ func getPods(labels map[string]string) []corev1.Pod {
}

func podsInDeploymentsReady(objects []string) {
var deployment appsv1.Deployment
helper := func(deploymentName string) bool {
isDeploymentReady := func(deploymentName string) bool {
var deployment appsv1.Deployment
err := testConfig.K8sClient.Get(testConfig.Context, types.NamespacedName{Namespace: nsName, Name: deploymentName}, &deployment)
if err != nil || deployment.Spec.Replicas == nil { // avoid dereferencing Spec.Replicas when the Get fails
ginkgo.By(fmt.Sprintf("Waiting for deployment %q to be ready (err: %v)", deploymentName, err))
return false
}
ginkgo.By(fmt.Sprintf("Waiting for deployment %q to be ready: replicas=%d, status=%#v", deploymentName, *deployment.Spec.Replicas, deployment.Status))
return *deployment.Spec.Replicas == deployment.Status.Replicas &&
deployment.Status.Replicas == deployment.Status.ReadyReplicas
}

for _, kindAndName := range objects {
split := strings.Split(kindAndName, "/")
if strings.ToLower(split[0]) == deploymentKind {
ginkgo.By(fmt.Sprintf("Waiting for pods of %s to be ready", split[1]))
gomega.Eventually(helper, readyTimeout, interval).WithArguments(split[1]).Should(gomega.BeTrue())
gomega.Eventually(isDeploymentReady).
WithArguments(split[1]).
WithPolling(interval).
WithTimeout(readyTimeout).
Should(gomega.BeTrue())
}
}
}