2 changes: 1 addition & 1 deletion Makefile
@@ -85,7 +85,7 @@ format: ## Format Go source files
 test: $(GINKGO) download-tokenizer download-zmq ## Run tests
 	@printf "\033[33;1m==== Running tests ====\033[0m\n"
 ifdef GINKGO_FOCUS
-	CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)"
+	CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)"
 else
 	CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r
 endif
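This one-line fix matters because `test` already declares `$(GINKGO)` as a prerequisite: the `GINKGO_FOCUS` branch previously shelled out to a bare `ginkgo`, picking up whatever version happens to be on PATH (or failing outright if none is installed), while the else-branch used the pinned `$(GINKGO)` binary. Both branches now run the same tool.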
67 changes: 40 additions & 27 deletions pkg/llm-d-inference-sim/simulator.go
@@ -570,8 +570,9 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
 // usageData - usage (tokens statistics) for this response
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
+// numCompletionOptions - number of choices to return in the response.
 func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
+	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, numCompletionOptions *int) openaiserverapi.CompletionResponse {
 	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
 		time.Now().Unix(), modelName, usageData)
 
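The new `numCompletionOptions` parameter carries the OpenAI-style `n` request field, the number of choices to generate for a single request. A minimal client-side sketch of the behavior this enables; the host, port, and model name are assumptions for illustration, while `n` and the `/v1/chat/completions` path follow the standard OpenAI API:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical request against a locally running simulator instance.
	body, _ := json.Marshal(map[string]any{
		"model": "my-model",
		"messages": []map[string]string{
			{"role": "user", "content": "Hello"},
		},
		"n": 3, // ask for three choices in a single response
	})
	resp, err := http.Post("http://localhost:8000/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var parsed struct {
		Choices []struct {
			Index int `json:"index"`
		} `json:"choices"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		panic(err)
	}
	fmt.Println("choices returned:", len(parsed.Choices)) // expect 3
}
```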
@@ -588,8 +589,6 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
 		baseResp.KVParams.TPSize = 1
 	}
 
-	baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason)
-
 	respText := strings.Join(respTokens, "")
 	if isChatCompletion {
 		baseResp.Object = chatCompletionObject
@@ -601,41 +600,51 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
 			message.Content = openaiserverapi.Content{Raw: respText}
 		}
 
-		choice := openaiserverapi.CreateChatRespChoice(baseChoice, message)
-
-		// Generate logprobs if requested
-		if logprobs != nil && toolCalls == nil {
-			if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 {
-				choice.Logprobs = logprobsData
-			} else {
-				// Set to nil if generation failed or content is empty
-				choice.Logprobs = nil
-			}
-		} else {
-			// Explicitly ensure logprobs is nil when not requested
-			choice.Logprobs = nil
-		}
+		// Generate numCompletionOptions choices in the response.
+		choices := []openaiserverapi.ChatRespChoice{}
+		for i := range *numCompletionOptions {
+			baseChoice := openaiserverapi.CreateBaseResponseChoice(i, finishReason)
+			choice := openaiserverapi.CreateChatRespChoice(baseChoice, message)
+			// Generate logprobs if requested
+			if logprobs != nil && toolCalls == nil {
+				if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 {
+					choice.Logprobs = logprobsData
+				} else {
+					// Set to nil if generation failed or content is empty
+					choice.Logprobs = nil
+				}
+			} else {
+				// Explicitly ensure logprobs is nil when not requested
+				choice.Logprobs = nil
+			}
+			choices = append(choices, choice)
+		}
 
-		return openaiserverapi.CreateChatCompletionResponse(baseResp, []openaiserverapi.ChatRespChoice{choice})
+		return openaiserverapi.CreateChatCompletionResponse(baseResp, choices)
 	}
 
-	choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText)
-
-	// Generate logprobs if requested for text completion
-	if logprobs != nil && *logprobs > 0 {
-		if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 {
-			choice.Logprobs = logprobsData
-		} else {
-			// Set to nil if generation failed or tokens is empty
-			choice.Logprobs = nil
-		}
-	} else {
-		// Explicitly ensure logprobs is nil when not requested
-		choice.Logprobs = nil
-	}
+	// Generate numCompletionOptions choices in the response.
+	choices := []openaiserverapi.TextRespChoice{}
+	for i := range *numCompletionOptions {
+		baseChoice := openaiserverapi.CreateBaseResponseChoice(i, finishReason)
+		choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText)
+		// Generate logprobs if requested for text completion
+		if logprobs != nil && *logprobs > 0 {
+			if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 {
+				choice.Logprobs = logprobsData
+			} else {
+				// Set to nil if generation failed or tokens is empty
+				choice.Logprobs = nil
+			}
+		} else {
+			// Explicitly ensure logprobs is nil when not requested
+			choice.Logprobs = nil
+		}
+		choices = append(choices, choice)
+	}
 
 	baseResp.Object = textCompletionObject
-	return openaiserverapi.CreateTextCompletionResponse(baseResp, []openaiserverapi.TextRespChoice{choice})
+	return openaiserverapi.CreateTextCompletionResponse(baseResp, choices)
 }
 
 // sendResponse sends response for completion API, supports both completions (text and chat)
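Both branches now share the same shape: allocate a slice, loop `n` times, stamp each choice with its index, and append. A distilled, self-contained version of that pattern (types simplified here; `for i := range n` over an int requires Go 1.22+):

```go
package main

import "fmt"

type choice struct {
	Index int
	Text  string
}

// buildChoices mirrors the loops above: n choices that differ only in
// their index, since all of them reuse the same simulated tokens.
func buildChoices(n int, text string) []choice {
	choices := []choice{}
	for i := range n { // range-over-int, Go 1.22+
		choices = append(choices, choice{Index: i, Text: text})
	}
	return choices
}

func main() {
	fmt.Println(buildChoices(3, "hello"))
	// Output: [{0 hello} {1 hello} {2 hello}]
}
```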
@@ -655,7 +664,7 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	}
 
 	resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
-		reqCtx.CompletionReq.IsDoRemoteDecode())
+		reqCtx.CompletionReq.IsDoRemoteDecode(), reqCtx.CompletionReq.GetNumCompletionOptions())
 
 	// calculate how long to wait before returning the response, time is based on number of tokens
 	nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
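Note that `createCompletionResponse` dereferences `numCompletionOptions` without a nil check, which only works if `GetNumCompletionOptions` never returns nil. A plausible accessor consistent with that contract; this is a hypothetical sketch, the real implementation lives in the `openaiserverapi` package and its type and field names here are assumptions:

```go
package openaiserverapi // hypothetical sketch, not the repository's actual code

type baseCompletionRequest struct {
	// N mirrors the OpenAI "n" request field; nil when the client omits it.
	N *int `json:"n,omitempty"`
}

// GetNumCompletionOptions returns the requested number of choices,
// defaulting to 1 so that callers can always dereference the result.
func (r *baseCompletionRequest) GetNumCompletionOptions() *int {
	if r.N == nil {
		one := 1
		return &one
	}
	return r.N
}
```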
@@ -668,7 +677,11 @@
 	common.WriteToChannel(s.metrics.reqPrefillTimeChan, time.Since(startPrefill).Seconds(), s.logger, "metrics.reqPrefillTimeChan")
 
 	startDecode := time.Now()
-	for range usageData.CompletionTokens - 1 {
+	// CompletionTokens counts all tokens across all choices in the response.
+	// Each choice gets the same set of tokens from the simulator, so the
+	// per-choice token count is the total divided by the number of choices.
+	actualComplCount := usageData.CompletionTokens / *reqCtx.CompletionReq.GetNumCompletionOptions()
+	for range actualComplCount - 1 {
 		perTokenLatency := s.getInterTokenLatency()
 		time.Sleep(time.Duration(perTokenLatency) * time.Millisecond)
 
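To make the decode-loop accounting concrete: `usageData.CompletionTokens` is the total across all choices, and every choice carries the same tokens, so the per-choice count is the total divided by `n`; the loop then sleeps once per inter-token gap, of which there is one fewer than there are tokens. A toy check with assumed numbers:

```go
package main

import "fmt"

func main() {
	completionTokens := 30 // total usage across all choices (assumed value)
	n := 3                 // choices requested (assumed value)

	perChoice := completionTokens / n // 10 tokens per choice
	interTokenWaits := perChoice - 1  // 9 sleeps: gaps between consecutive tokens

	fmt.Println(perChoice, interTokenWaits) // 10 9
}
```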