Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jellydator/ttlcache/v3 v3.4.0
github.com/llm-d/llm-d-kv-cache-manager v0.3.2
github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2
github.com/onsi/ginkgo/v2 v2.27.2
github.com/onsi/gomega v1.38.2
github.com/openai/openai-go v1.12.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/llm-d/llm-d-kv-cache-manager v0.3.2 h1:omSTXtuII3ol37CaoI9h+2VxE0m8EoeVOor+CkQh99I=
github.com/llm-d/llm-d-kv-cache-manager v0.3.2/go.mod h1:q6u7LnzMxNcHHb5/LRdHNNeZzzGMSENFSP1NGfsJEmA=
github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2 h1:l2Sm8W6SRg4TAme4RsndwZ++5+4aQvDI4vnf8TKrhww=
github.com/llm-d/llm-d-kv-cache-manager v0.4.0-rc2/go.mod h1:ZlK7MCuz5D/weLeHyNKEmVF/eJZDyYn3XyRowTihq9o=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo=
Expand Down
69 changes: 17 additions & 52 deletions pkg/plugins/scorer/precise_prefix_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,18 @@ var _ framework.Scorer = &PrecisePrefixCacheScorer{}
// a new instance of the PrefixCacheTrackingPlugin.
func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
handle plugins.Handle) (plugins.Plugin, error) {
indexerConfig, err := kvcache.NewDefaultConfig()
if err != nil {
return nil, fmt.Errorf("failed to create default indexer config: %w", err)
}
parameters := PrecisePrefixCachePluginConfig{
IndexerConfig: kvcache.NewDefaultConfig(),
IndexerConfig: indexerConfig,
KVEventsConfig: kvevents.DefaultConfig(),
}

// read hugging face token from environment variable if set
if token := os.Getenv("HF_TOKEN"); token != "" {
parameters.IndexerConfig.TokenizersPoolConfig.HuggingFaceToken = token
parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token
}

if rawParameters != nil {
Expand Down Expand Up @@ -87,15 +91,9 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex())
pool.Start(ctx)

chatTemplateRenderer := preprocessing.NewChatTemplatingProcessor()
if err := chatTemplateRenderer.Initialize(); err != nil {
return nil, fmt.Errorf("failed to initialize chat templating processor: %w", err)
}

return &PrecisePrefixCacheScorer{
typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
chatTemplateRenderer: chatTemplateRenderer,
typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
}, nil
}

Expand All @@ -105,9 +103,8 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
// state, and the `kvevents.Pool` to subscribe to KV-cache events
// to keep the internal KV-cache index state up-to-date.
type PrecisePrefixCacheScorer struct {
typedName plugins.TypedName
kvCacheIndexer *kvcache.Indexer
chatTemplateRenderer *preprocessing.ChatTemplatingProcessor
typedName plugins.TypedName
kvCacheIndexer *kvcache.Indexer
}

// TypedName returns the typed name of the plugin.
Expand All @@ -132,13 +129,13 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat
}

// Extract the render request (chat completions) or the flattened prompt (completions) from the request
prompt, err := s.extractPrompt(ctx, request)
renderReq, prompt, err := s.extractRequest(ctx, request)
if err != nil {
logger.Error(err, "Failed to extract prompt from request")
return nil
}

scores, err := s.kvCacheIndexer.GetPodScores(ctx, prompt, request.TargetModel, nil)
scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, prompt, request.TargetModel, nil)
if err != nil {
logger.Error(err, "Failed to get pod scores")
return nil
Expand All @@ -158,10 +155,10 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat
return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
}

// extractPrompt extracts the flattened prompt from the request.
// extractRequest extracts the scoring input from the request.
// For chat completions, it builds a chat-template render request (rendering is
// deferred to the KV-cache indexer) and returns an empty prompt.
// For regular completions, it returns the prompt directly and a nil render request.
func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *types.LLMRequest) (string, error) {
func (s *PrecisePrefixCacheScorer) extractRequest(ctx context.Context, request *types.LLMRequest) (*preprocessing.RenderJinjaTemplateRequest, string, error) {
traceLogger := log.FromContext(ctx).V(logutil.TRACE).WithName(s.typedName.String())

// The upstream parser guarantees exactly one body is populated, but we defensively prioritize chat completions.
Expand Down Expand Up @@ -194,47 +191,15 @@ func (s *PrecisePrefixCacheScorer) extractPrompt(ctx context.Context, request *t
})
}

// Fetch the chat template from the model
fetchReq := preprocessing.FetchChatTemplateRequest{
Model: request.TargetModel,
}

chatTemplate, chatTemplateKWArgs, err := s.chatTemplateRenderer.FetchChatTemplate(ctx, fetchReq)
if err != nil {
return "", fmt.Errorf("failed to fetch chat template: %w", err)
}

traceLogger.Info("Chat template fetched",
"model", request.TargetModel,
"templateLength", len(chatTemplate),
"hasKwargs", len(chatTemplateKWArgs) > 0)

// Set the fetched template in the render request
renderReq.ChatTemplate = chatTemplate
renderReq.ChatTemplateKWArgs = chatTemplateKWArgs

// Render the template to get flattened prompt
resp, err := s.chatTemplateRenderer.RenderChatTemplate(ctx, renderReq)
if err != nil {
return "", fmt.Errorf("failed to render chat template: %w", err)
}

if len(resp.RenderedChats) == 0 {
return "", errors.New("no rendered chat returned from template rendering")
}

prompt := resp.RenderedChats[0]
traceLogger.Info("Chat template rendered successfully",
"promptLength", len(prompt))
return prompt, nil
return renderReq, "", nil
}

// For regular completions, use the prompt directly
if request.Body != nil && request.Body.Completions != nil {
prompt := request.Body.Completions.Prompt
traceLogger.Info("Using completion prompt directly", "promptLength", len(prompt))
return prompt, nil
return nil, prompt, nil
}

return "", errors.New("no valid prompt found in request")
return nil, "", errors.New("no valid prompt found in request")
}
14 changes: 9 additions & 5 deletions pkg/plugins/scorer/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package scorer

import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
import (
"math"

"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// podToKey is a function type that converts a Pod to a string key.
// It returns the key and a boolean indicating success.
Expand All @@ -11,7 +15,7 @@ type podToKeyFunc func(pod types.Pod) (string, bool)
// a pod to a key, and a map of scores indexed by those keys. It returns a map
// of pods to their normalized scores.
func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc,
scores map[string]int) map[types.Pod]float64 {
scores map[string]float64) map[types.Pod]float64 {
scoredPods := make(map[types.Pod]float64)
minScore, maxScore := getMinMax(scores)

Expand All @@ -36,9 +40,9 @@ func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc
return scoredPods
}

func getMinMax(scores map[string]int) (int, int) {
minScore := int(^uint(0) >> 1) // max int
maxScore := -1
func getMinMax(scores map[string]float64) (float64, float64) {
minScore := math.MaxFloat64
maxScore := float64(-1)

for _, score := range scores {
if score < minScore {
Expand Down
Loading