diff --git a/Makefile b/Makefile
index a71b78c..4cab0dc 100644
--- a/Makefile
+++ b/Makefile
@@ -85,7 +85,7 @@ format: ## Format Go source files
 test: $(GINKGO) download-tokenizer download-zmq ## Run tests
 	@printf "\033[33;1m==== Running tests ====\033[0m\n"
 ifdef GINKGO_FOCUS
-	CGO_ENABLED=1 ginkgo -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)"
+	CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r -- -ginkgo.v -ginkgo.focus="$(GINKGO_FOCUS)"
 else
 	CGO_ENABLED=1 $(GINKGO) -ldflags="$(GO_LDFLAGS)" -v -r
 endif
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index c58f6e9..5a2dd11 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -570,8 +570,9 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
 // usageData - usage (tokens statistics) for this response
 // modelName - display name returned to the client and used in metrics. It is either the first alias
 // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
+// numCompletionOptions - number of choices to return in the response.
 func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
+	finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool, numCompletionOptions *int) openaiserverapi.CompletionResponse {
 	baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+s.random.GenerateUUIDString(),
 		time.Now().Unix(), modelName, usageData)
 
@@ -588,8 +589,6 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
 		baseResp.KVParams.TPSize = 1
 	}
 
-	baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason)
-
 	respText := strings.Join(respTokens, "")
 	if isChatCompletion {
 		baseResp.Object = chatCompletionObject
@@ -601,41 +600,51 @@ func (s *VllmSimulator) createCompletionResponse(logprobs *int, isChatCompletion
 			message.Content = openaiserverapi.Content{Raw: respText}
 		}
 
-		choice := openaiserverapi.CreateChatRespChoice(baseChoice, message)
-
-		// Generate logprobs if requested
-		if logprobs != nil && toolCalls == nil {
-			if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 {
-				choice.Logprobs = logprobsData
+		// Generate numCompletionOptions choices in the response.
+		choices := []openaiserverapi.ChatRespChoice{}
+		for i := range *numCompletionOptions {
+			baseChoice := openaiserverapi.CreateBaseResponseChoice(i, finishReason)
+			choice := openaiserverapi.CreateChatRespChoice(baseChoice, message)
+			// Generate logprobs if requested
+			if logprobs != nil && toolCalls == nil {
+				if logprobsData := common.GenerateChatLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Content) > 0 {
+					choice.Logprobs = logprobsData
+				} else {
+					// Set to nil if generation failed or content is empty
+					choice.Logprobs = nil
+				}
 			} else {
-				// Set to nil if generation failed or content is empty
+				// Explicitly ensure logprobs is nil when not requested
 				choice.Logprobs = nil
 			}
-		} else {
-			// Explicitly ensure logprobs is nil when not requested
-			choice.Logprobs = nil
+			choices = append(choices, choice)
 		}
 
-		return openaiserverapi.CreateChatCompletionResponse(baseResp, []openaiserverapi.ChatRespChoice{choice})
+		return openaiserverapi.CreateChatCompletionResponse(baseResp, choices)
 	}
 
-	choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText)
-
-	// Generate logprobs if requested for text completion
-	if logprobs != nil && *logprobs > 0 {
-		if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 {
-			choice.Logprobs = logprobsData
+	// Generate numCompletionOptions choices in the response.
+	choices := []openaiserverapi.TextRespChoice{}
+	for i := range *numCompletionOptions {
+		baseChoice := openaiserverapi.CreateBaseResponseChoice(i, finishReason)
+		choice := openaiserverapi.CreateTextRespChoice(baseChoice, respText)
+		// Generate logprobs if requested for text completion
+		if logprobs != nil && *logprobs > 0 {
+			if logprobsData := common.GenerateTextLogprobs(respTokens, *logprobs); logprobsData != nil && len(logprobsData.Tokens) > 0 {
+				choice.Logprobs = logprobsData
+			} else {
+				// Set to nil if generation failed or tokens is empty
+				choice.Logprobs = nil
+			}
 		} else {
-			// Set to nil if generation failed or tokens is empty
+			// Explicitly ensure logprobs is nil when not requested
 			choice.Logprobs = nil
 		}
-	} else {
-		// Explicitly ensure logprobs is nil when not requested
-		choice.Logprobs = nil
+		choices = append(choices, choice)
 	}
 
 	baseResp.Object = textCompletionObject
-	return openaiserverapi.CreateTextCompletionResponse(baseResp, []openaiserverapi.TextRespChoice{choice})
+	return openaiserverapi.CreateTextCompletionResponse(baseResp, choices)
 }
 
 // sendResponse sends response for completion API, supports both completions (text and chat)
@@ -655,7 +664,7 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	}
 
 	resp := s.createCompletionResponse(logprobs, reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
-		reqCtx.CompletionReq.IsDoRemoteDecode())
+		reqCtx.CompletionReq.IsDoRemoteDecode(), reqCtx.CompletionReq.GetNumCompletionOptions())
 
 	// calculate how long to wait before returning the response, time is based on number of tokens
 	nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
@@ -668,7 +677,11 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	common.WriteToChannel(s.metrics.reqPrefillTimeChan, time.Since(startPrefill).Seconds(), s.logger, "metrics.reqPrefillTimeChan")
 
 	startDecode := time.Now()
-	for range usageData.CompletionTokens - 1 {
+	// CompletionTokens accounts for all tokens across all choices in the response.
+	// Each choice carries the same set of tokens from the simulator, so the
+	// per-choice token count is the total CompletionTokens divided by the number of choices.
+	actualComplCount := usageData.CompletionTokens / *reqCtx.CompletionReq.GetNumCompletionOptions()
+	for range actualComplCount - 1 {
 		perTokenLatency := s.getInterTokenLatency()
 		time.Sleep(time.Duration(perTokenLatency) * time.Millisecond)
 
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index a08d946..72377d2 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -38,12 +38,15 @@ const invalidMaxTokensErrMsg = "Max completion tokens and max tokens should be p
 var _ = Describe("Simulator", func() {
 	DescribeTable("chat completions streaming",
-		func(mode string) {
+		func(mode string, numCompletionOptions *int) {
 			ctx := context.TODO()
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
 
 			openaiclient, params := getOpenAIClientAndChatParams(client, testModel, testUserMessage, true)
+			if numCompletionOptions != nil {
+				params.N = param.NewOpt(int64(*numCompletionOptions))
+			}
 			stream := openaiclient.Chat.Completions.NewStreaming(ctx, params)
 			defer func() {
 				err := stream.Close()
@@ -53,8 +56,17 @@ var _ = Describe("Simulator", func() {
 			role := ""
 			var chunk openai.ChatCompletionChunk
 			numberOfChunksWithUsage := 0
+			expectedNumChoices := 0
+			if numCompletionOptions == nil {
+				expectedNumChoices = 1
+			} else {
+				expectedNumChoices = *numCompletionOptions
+			}
 			for stream.Next() {
 				chunk = stream.Current()
+				// len(chunk.Choices) >= 0 && len(chunk.Choices) <= expectedNumChoices
+				// We check for >= 0 because choices can be empty if usage is enabled.
+				Expect(len(chunk.Choices)).To(And(BeNumerically(">=", 0), BeNumerically("<=", expectedNumChoices)))
 				for _, choice := range chunk.Choices {
 					if choice.Delta.Role != "" {
 						role = choice.Delta.Role
@@ -83,20 +95,26 @@ var _ = Describe("Simulator", func() {
 			}
 			Expect(role).Should(Equal("assistant"))
 		},
-		func(mode string) string {
-			return "mode: " + mode
+		func(mode string, numCompletionOptions *int) string {
+			if numCompletionOptions == nil {
+				return fmt.Sprintf("mode: %s, NumCompletionOptions (N): %v", mode, numCompletionOptions)
+			}
+			return fmt.Sprintf("mode: %s, NumCompletionOptions (N): %d", mode, *numCompletionOptions)
 		},
-		Entry(nil, common.ModeRandom),
-		Entry(nil, common.ModeEcho),
+		Entry(nil, common.ModeRandom, intPtr(10)),
+		Entry(nil, common.ModeEcho, nil),
 	)
 
 	DescribeTable("text completions streaming",
-		func(mode string) {
+		func(mode string, numCompletionOptions *int) {
 			ctx := context.TODO()
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
 
 			openaiclient, params := getOpenAIClentAndCompletionParams(client, testModel, testUserMessage, true)
+			if numCompletionOptions != nil {
+				params.N = param.NewOpt(int64(*numCompletionOptions))
+			}
 			stream := openaiclient.Completions.NewStreaming(ctx, params)
 			defer func() {
 				err := stream.Close()
@@ -105,8 +123,17 @@ var _ = Describe("Simulator", func() {
 			tokens := []string{}
 			var chunk openai.Completion
 			numberOfChunksWithUsage := 0
+			expectedNumChoices := 0
+			if numCompletionOptions == nil {
+				expectedNumChoices = 1
+			} else {
+				expectedNumChoices = *numCompletionOptions
+			}
 			for stream.Next() {
 				chunk = stream.Current()
+				// len(chunk.Choices) >= 0 && len(chunk.Choices) <= expectedNumChoices
+				// We check for >= 0 because choices can be empty if usage is enabled.
+				Expect(len(chunk.Choices)).To(And(BeNumerically(">=", 0), BeNumerically("<=", expectedNumChoices)))
 				for _, choice := range chunk.Choices {
 					if choice.FinishReason == "" {
 						tokens = append(tokens, choice.Text)
@@ -131,15 +158,18 @@ var _ = Describe("Simulator", func() {
 				Expect(text).Should(Equal(testUserMessage))
 			}
 		},
-		func(mode string) string {
-			return "mode: " + mode
+		func(mode string, numCompletionOptions *int) string {
+			if numCompletionOptions == nil {
+				return fmt.Sprintf("mode: %s, NumCompletionOptions (N): %v", mode, numCompletionOptions)
+			}
+			return fmt.Sprintf("mode: %s, NumCompletionOptions (N): %d", mode, *numCompletionOptions)
 		},
-		Entry(nil, common.ModeRandom),
-		Entry(nil, common.ModeEcho),
+		Entry(nil, common.ModeRandom, intPtr(10)),
+		Entry(nil, common.ModeEcho, nil),
 	)
 
 	DescribeTable("chat completions",
-		func(mode string, maxTokens int, maxCompletionTokens int) {
+		func(mode string, maxTokens int, maxCompletionTokens int, numCompletionOptions *int) {
 			ctx := context.TODO()
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
@@ -156,6 +186,9 @@ var _ = Describe("Simulator", func() {
 				params.MaxCompletionTokens = param.NewOpt(int64(maxCompletionTokens))
 				numTokens = maxCompletionTokens
 			}
+			if numCompletionOptions != nil {
+				params.N = param.NewOpt(int64(*numCompletionOptions))
+			}
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			if err != nil {
 				var openaiError *openai.Error
@@ -174,49 +207,66 @@ var _ = Describe("Simulator", func() {
 			Expect(string(resp.Object)).To(Equal(chatCompletionObject))
 
 			Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens))
-			Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0))
 			Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens))
 
-			msg := resp.Choices[0].Message.Content
-			Expect(msg).ShouldNot(BeEmpty())
-
-			if mode == common.ModeEcho {
-				// in case of echo mode check that the text is returned as-is
-				Expect(msg).Should(Equal(testUserMessage))
+			// At this point we've gotten back a response; verify the expected
+			// behaviour that numCompletionOptions (N) defaults to 1 when
+			// nothing is specified.
+			expectedNumChoices := 0
+			if numCompletionOptions != nil {
+				expectedNumChoices = *numCompletionOptions
 			} else {
-				if numTokens > 0 {
-					tokens := common.Tokenize(msg)
-					Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
+				expectedNumChoices = 1
+			}
+			Expect(len(resp.Choices)).To(Equal(expectedNumChoices))
+			if maxTokens > 0 {
+				// completionTokens > 0 && completionTokens <= expectedNumChoices*maxTokens
+				Expect(resp.Usage.CompletionTokens).To(And(BeNumerically(">", 0), BeNumerically("<=", expectedNumChoices*maxTokens)))
+			}
+			for _, choice := range resp.Choices {
+				msg := choice.Message.Content
+				Expect(msg).ShouldNot(BeEmpty())
+				if mode == common.ModeEcho {
+					// in case of echo mode check that the text is returned as-is
+					Expect(msg).Should(Equal(testUserMessage))
 				} else {
-					// in case of random mode ensure that the returned message could be output of the random text generator
-					Expect(dataset.IsValidText(msg)).To(BeTrue())
+					if numTokens > 0 {
+						tokens := common.Tokenize(msg)
+						Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens))
+					} else {
+						// in case of random mode ensure that the returned message could be output of the random text generator
+						Expect(dataset.IsValidText(msg)).To(BeTrue())
+					}
 				}
 			}
 		},
-		func(mode string, maxTokens int, maxCompletionTokens int) string {
-			return fmt.Sprintf("mode: %s max_tokens: %d max_completion_tokens: %d", mode, maxTokens, maxCompletionTokens)
+		func(mode string, maxTokens int, maxCompletionTokens int, numCompletionOptions *int) string {
+			if numCompletionOptions == nil {
+				return fmt.Sprintf("mode: %s max_tokens: %d max_completion_tokens: %d, N: %v", mode, maxTokens, maxCompletionTokens, numCompletionOptions)
+			}
+			return fmt.Sprintf("mode: %s max_tokens: %d max_completion_tokens: %d, N: %d", mode, maxTokens, maxCompletionTokens, *numCompletionOptions)
 		},
-		Entry(nil, common.ModeRandom, 2, 0),
-		Entry(nil, common.ModeEcho, 2, 0),
-		Entry(nil, common.ModeRandom, 1000, 0),
-		Entry(nil, common.ModeEcho, 1000, 0),
-		Entry(nil, common.ModeRandom, 1000, 2),
-		Entry(nil, common.ModeEcho, 1000, 2),
-		Entry(nil, common.ModeRandom, 0, 2),
-		Entry(nil, common.ModeEcho, 0, 2),
-		Entry(nil, common.ModeRandom, 0, 1000),
-		Entry(nil, common.ModeEcho, 0, 1000),
-		Entry(nil, common.ModeRandom, 0, 0),
-		Entry(nil, common.ModeEcho, 0, 0),
-		Entry(nil, common.ModeRandom, -1, 0),
-		Entry(nil, common.ModeEcho, -1, 0),
-		Entry(nil, common.ModeRandom, 0, -1),
-		Entry(nil, common.ModeEcho, 0, -1),
+		Entry(nil, common.ModeRandom, 2, 0, intPtr(5)),
+		Entry(nil, common.ModeEcho, 2, 0, intPtr(5)),
+		Entry(nil, common.ModeRandom, 1000, 0, intPtr(5)),
+		Entry(nil, common.ModeEcho, 1000, 0, intPtr(10)),
+		Entry(nil, common.ModeRandom, 1000, 2, intPtr(10)),
+		Entry(nil, common.ModeEcho, 1000, 2, intPtr(1)),
+		Entry(nil, common.ModeRandom, 0, 2, intPtr(1)),
+		Entry(nil, common.ModeEcho, 0, 2, nil),
+		Entry(nil, common.ModeRandom, 0, 1000, nil),
+		Entry(nil, common.ModeEcho, 0, 1000, nil),
+		Entry(nil, common.ModeRandom, 0, 0, nil),
+		Entry(nil, common.ModeEcho, 0, 0, nil),
+		Entry(nil, common.ModeRandom, -1, 0, nil),
+		Entry(nil, common.ModeEcho, -1, 0, nil),
+		Entry(nil, common.ModeRandom, 0, -1, nil),
+		Entry(nil, common.ModeEcho, 0, -1, nil),
 	)
 
 	DescribeTable("text completions",
 		// use a function so that httpClient is captured when running
-		func(mode string, maxTokens int) {
+		func(mode string, maxTokens int, numCompletionOptions *int) {
 			ctx := context.TODO()
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
@@ -227,6 +277,9 @@ var _ = Describe("Simulator", func() {
var _ = Describe("Simulator", func() { params.MaxTokens = param.NewOpt(int64(maxTokens)) numTokens = maxTokens } + if numCompletionOptions != nil { + params.N = param.NewOpt(int64(*numCompletionOptions)) + } resp, err := openaiclient.Completions.New(ctx, params) if err != nil { var openaiError *openai.Error @@ -245,36 +298,54 @@ var _ = Describe("Simulator", func() { Expect(string(resp.Object)).To(Equal(textCompletionObject)) Expect(resp.Usage.PromptTokens).To(Equal(userMsgTokens)) - Expect(resp.Usage.CompletionTokens).To(BeNumerically(">", 0)) Expect(resp.Usage.TotalTokens).To(Equal(resp.Usage.PromptTokens + resp.Usage.CompletionTokens)) - text := resp.Choices[0].Text - Expect(text).ShouldNot(BeEmpty()) - - if mode == common.ModeEcho { - // in case of echo mode check that the text is returned as-is - Expect(text).Should(Equal(testUserMessage)) + // At this point we've gotten back a response; test the expected + // behaviour that the default value of numCompletionOptions (N) is + // 1 if nothing is specified. + expectedNumChoices := 0 + if numCompletionOptions != nil { + expectedNumChoices = *numCompletionOptions } else { - if numTokens != 0 { - tokens := common.Tokenize(text) - Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens)) + expectedNumChoices = 1 + } + Expect(len(resp.Choices)).To(Equal(expectedNumChoices)) + if maxTokens > 0 { + // completionTokens > 0 && completionTokens <= expectedNumChoices*maxTokens + Expect(resp.Usage.CompletionTokens).To(And(BeNumerically(">", 0), BeNumerically("<=", expectedNumChoices*maxTokens))) + } + + for _, choice := range resp.Choices { + msg := choice.Text + Expect(msg).ShouldNot(BeEmpty()) + if mode == common.ModeEcho { + // in case of echo mode check that the text is returned as-is + Expect(msg).Should(Equal(testUserMessage)) } else { - // in case of random mode ensure that the returned message could be output of the random text generator - Expect(dataset.IsValidText(text)).To(BeTrue()) + if numTokens > 0 { + tokens := common.Tokenize(msg) + Expect(int64(len(tokens))).Should(BeNumerically("<=", numTokens)) + } else { + // in case of random mode ensure that the returned message could be output of the random text generator + Expect(dataset.IsValidText(msg)).To(BeTrue()) + } } } }, - func(mode string, maxTokens int) string { - return fmt.Sprintf("mode: %s max_tokens: %d", mode, maxTokens) + func(mode string, maxTokens int, numCompletionOptions *int) string { + if numCompletionOptions == nil { + return fmt.Sprintf("mode: %s max_tokens: %d, numCompletionOptions (N): %v", mode, maxTokens, numCompletionOptions) + } + return fmt.Sprintf("mode: %s max_tokens: %d, numCompletionOptions (N): %d", mode, maxTokens, *numCompletionOptions) }, - Entry(nil, common.ModeRandom, 2), - Entry(nil, common.ModeEcho, 2), - Entry(nil, common.ModeRandom, 1000), - Entry(nil, common.ModeEcho, 1000), - Entry(nil, common.ModeRandom, 0), - Entry(nil, common.ModeEcho, 0), - Entry(nil, common.ModeRandom, -1), - Entry(nil, common.ModeEcho, -1), + Entry(nil, common.ModeRandom, 2, intPtr(5)), + Entry(nil, common.ModeEcho, 2, intPtr(5)), + Entry(nil, common.ModeRandom, 1000, intPtr(1)), + Entry(nil, common.ModeEcho, 1000, intPtr(1)), + Entry(nil, common.ModeRandom, 0, nil), + Entry(nil, common.ModeEcho, 0, nil), + Entry(nil, common.ModeRandom, -1, nil), + Entry(nil, common.ModeEcho, -1, nil), ) Context("namespace and pod headers", func() { @@ -690,3 +761,5 @@ var _ = Describe("Simulator", func() { }) }) }) + +func intPtr(a int) *int { return &a } diff --git 
diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go
index 481d185..a441a45 100644
--- a/pkg/llm-d-inference-sim/worker.go
+++ b/pkg/llm-d-inference-sim/worker.go
@@ -103,6 +103,7 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe
 	var err error
 	var toolCalls []openaiserverapi.ToolCall
 	var completionTokens int
+	var numTokensGenerated int
 
 	if reqCtx.IsChatCompletion &&
 		!common.IsToolChoiceNone(req.GetToolChoice()) &&
 		req.GetTools() != nil {
@@ -114,7 +115,18 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe
 		// Either no tool calls were defined, or we randomly chose not to create tool calls,
 		// so we generate a response text.
 		responseTokens, finishReason, err = s.dataset.GetTokens(req, s.config.Mode)
-		completionTokens += len(responseTokens)
+		// To account for max_tokens and max_completion_tokens, the dataset returns a
+		// LengthFinishReason when either of these is exceeded. Check for this reason
+		// first and, based on it, infer the number of tokens actually generated.
+		if finishReason == common.LengthFinishReason {
+			numTokensGenerated = int(*req.GetMaxCompletionTokens())
+		} else {
+			numTokensGenerated = len(responseTokens)
+		}
+		// Each of the returned response choices carries the same set of response
+		// tokens from the simulator. Account for this when calculating
+		// completionTokens.
+		completionTokens += *reqCtx.CompletionReq.GetNumCompletionOptions() * numTokensGenerated
 	}
 	if err != nil {
 		prefix := ""
@@ -160,12 +172,17 @@ func (s *VllmSimulator) processRequestAsync(reqCtx *openaiserverapi.CompletionRe
 		wg.Done()
 	}
 
+	tokensPerChoice := []int{}
+	for range *reqCtx.CompletionReq.GetNumCompletionOptions() {
+		// Currently all choices in the response have the same content,
+		// so each choice gets the same token count.
+		tokensPerChoice = append(tokensPerChoice, numTokensGenerated)
+	}
 	common.WriteToChannel(s.metrics.requestSuccessChan, requestSuccessEvent{
-		promptTokens:     usageData.PromptTokens,
-		generationTokens: usageData.CompletionTokens,
-		// currently only responses with a single choice are supported
-		genTokensPerChoice: []int{usageData.CompletionTokens},
+		promptTokens:       usageData.PromptTokens,
+		generationTokens:   usageData.CompletionTokens,
+		genTokensPerChoice: tokensPerChoice,
 		maxTokens:    reqCtx.CompletionReq.GetMaxCompletionTokens(),
 		finishReason: finishReason},
 		s.logger, "metrics.requestSuccessChan")
diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go
index 80bc4cf..ede564d 100644
--- a/pkg/openai-server-api/request.go
+++ b/pkg/openai-server-api/request.go
@@ -79,6 +79,8 @@ type CompletionRequest interface {
 	ExtractMaxTokens() *int64
 	// GetLogprobs returns nil if no logprobs needed, or pointer to number of logprob options to include
 	GetLogprobs() *int
+	// GetNumCompletionOptions returns the number of completion choices requested (the request's n parameter)
+	GetNumCompletionOptions() *int
 }
 
 // baseCompletionRequest contains base completion request related information
@@ -212,6 +214,11 @@ type ChatCompletionRequest struct {
 
 	// TopLogprobs controls how many alternative tokens to include in the logprobs
 	TopLogprobs *int `json:"top_logprobs,omitempty"`
+
+	// NumCompletionOptions is the number of chat completion choices to generate.
+	//
+	// Optional and defaults to 1.
+	NumCompletionOptions *int `json:"n"`
 }
 
 var _ CompletionRequest = (*ChatCompletionRequest)(nil)
@@ -308,6 +315,16 @@ func (c *ChatCompletionRequest) GetLogprobs() *int {
 	return &defaultVal
 }
 
+func (c *ChatCompletionRequest) GetNumCompletionOptions() *int {
+	// If not specified, default to 1.
+	if c.NumCompletionOptions == nil {
+		n := new(int)
+		*n = 1
+		return n
+	}
+	return c.NumCompletionOptions
+}
+
 // v1/completion
 // TextCompletionRequest defines structure of /completion request
 type TextCompletionRequest struct {
@@ -327,6 +344,11 @@ type TextCompletionRequest struct {
 	// a list of the 5 most likely tokens. The API will always return the logprob
 	// of the sampled token, so there may be up to logprobs+1 elements in the response.
 	Logprobs *int `json:"logprobs,omitempty"`
+
+	// NumCompletionOptions is the number of text completion choices to generate.
+	//
+	// Optional and defaults to 1.
+	NumCompletionOptions *int `json:"n"`
 }
 
 var _ CompletionRequest = (*TextCompletionRequest)(nil)
@@ -366,3 +388,13 @@ func (req *TextCompletionRequest) ExtractMaxTokens() *int64 {
 func (t *TextCompletionRequest) GetLogprobs() *int {
 	return t.Logprobs
 }
+
+func (t *TextCompletionRequest) GetNumCompletionOptions() *int {
+	// If not specified, default to 1.
+	if t.NumCompletionOptions == nil {
+		n := new(int)
+		*n = 1
+		return n
+	}
+	return t.NumCompletionOptions
+}
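
Reviewer note: a minimal sketch of how a client exercises the new `n` parameter, built from the same openai-go calls the tests above already use (params.N with param.NewOpt, Chat.Completions.New). The base URL, port, and model name are placeholders, not values taken from this PR:

package main

import (
	"context"
	"fmt"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
)

func main() {
	ctx := context.Background()
	// Placeholder address; point this at a running simulator instance.
	client := openai.NewClient(option.WithBaseURL("http://localhost:8000/v1"))

	params := openai.ChatCompletionNewParams{
		Model: "placeholder-model",
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("Hello!"),
		},
		// Ask for three choices; the simulator should now return
		// len(resp.Choices) == 3, each with identical content.
		N: param.NewOpt(int64(3)),
	}

	resp, err := client.Chat.Completions.New(ctx, params)
	if err != nil {
		panic(err)
	}
	for _, choice := range resp.Choices {
		fmt.Println(choice.Index, choice.Message.Content)
	}
}

Per the accounting in worker.go above, usage.completion_tokens covers all choices, so a response with n=3 and 8 tokens per choice reports 24, and sendResponse paces decoding on the per-choice share (24 / 3 = 8 tokens).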