
Commit d7f9f3a

feat: add support to logitbias and logprobs (#7283)
* feat: add support to logprobs in results
* feat: add support to logitbias

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent cd7d384 commit d7f9f3a
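The commit threads the OpenAI chat-completion parameters logprobs, top_logprobs, and logit_bias from the HTTP API through the gRPC proto to the llama.cpp backend, and returns per-token logprobs in the response. The following is a minimal client sketch of the request shape, not part of the commit; it assumes a LocalAI instance on http://localhost:8080 serving a model named testmodel.ggml, and that the endpoint path and field names follow the OpenAI chat-completions schema exercised by the tests below.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Request body using the parameters wired up by this commit.
	// logit_bias maps token IDs (as strings) to bias values in [-100, 100].
	body, _ := json.Marshal(map[string]any{
		"model": "testmodel.ggml", // placeholder model name
		"messages": []map[string]string{
			{"role": "user", "content": "Say hello"},
		},
		"logprobs":     true,
		"top_logprobs": 3,
		"logit_bias":   map[string]int{"15043": 1},
	})

	resp, err := http.Post("http://localhost:8080/v1/chat/completions",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out map[string]any
	_ = json.NewDecoder(resp.Body).Decode(&out)
	fmt.Println(out["choices"]) // choices[0].logprobs.content carries per-token logprobs
}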

File tree: 10 files changed, +385 -12 lines

backend/backend.proto
Lines changed: 3 additions & 0 deletions

@@ -156,6 +156,8 @@ message PredictOptions {
   string CorrelationId = 47;
   string Tools = 48; // JSON array of available tools/functions for tool calling
   string ToolChoice = 49; // JSON string or object specifying tool choice behavior
+  int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
+  int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
 }

 // The response message containing the result
@@ -166,6 +168,7 @@ message Reply {
   double timing_prompt_processing = 4;
   double timing_token_generation = 5;
   bytes audio = 6;
+  bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
 }

 message GrammarTrigger {
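The new Reply.logprobs field carries a JSON document that, per the comment above, matches the OpenAI logprobs object. Below is a rough, illustrative sketch of decoding that payload on the Go side; the struct names are placeholders mirroring the public OpenAI layout (token, logprob, bytes, top_logprobs), not LocalAI's actual schema.Logprobs definition.

package main

import (
	"encoding/json"
	"fmt"
)

// Placeholder types mirroring the OpenAI logprobs object shape.
type TopLogProb struct {
	Token   string  `json:"token"`
	LogProb float64 `json:"logprob"`
	Bytes   []int   `json:"bytes,omitempty"` // UTF-8 byte values of the token
}

type LogProbContent struct {
	Token       string       `json:"token"`
	LogProb     float64      `json:"logprob"`
	Bytes       []int        `json:"bytes,omitempty"`
	TopLogProbs []TopLogProb `json:"top_logprobs,omitempty"`
}

type LogProbs struct {
	Content []LogProbContent `json:"content"`
}

func main() {
	// Hypothetical payload in the format the proto comment describes.
	payload := []byte(`{"content":[{"token":"Hello","logprob":-0.12,"bytes":[72,101,108,108,111],"top_logprobs":[{"token":"Hello","logprob":-0.12},{"token":"Hi","logprob":-2.3}]}]}`)

	var lp LogProbs
	if err := json.Unmarshal(payload, &lp); err != nil {
		panic(err)
	}
	fmt.Println(lp.Content[0].Token, lp.Content[0].TopLogProbs[1].Token)
}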

backend/cpp/llama-cpp/grpc-server.cpp
Lines changed: 111 additions & 7 deletions

@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
             SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
         }
     }
+
+    // Extract logprobs and top_logprobs from proto and add to JSON data
+    // Following server.cpp pattern: logprobs maps to n_probs when provided
+    if (predict->logprobs() > 0) {
+        data["logprobs"] = predict->logprobs();
+        // Map logprobs to n_probs (following server.cpp line 369 pattern)
+        // n_probs will be set by params_from_json_cmpl if logprobs is provided
+        data["n_probs"] = predict->logprobs();
+        SRV_INF("Using logprobs: %d\n", predict->logprobs());
+    }
+    if (predict->toplogprobs() > 0) {
+        data["top_logprobs"] = predict->toplogprobs();
+        SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
+    }
+
+    // Extract logit_bias from proto and add to JSON data
+    if (!predict->logitbias().empty()) {
+        try {
+            // Parse logit_bias JSON string from proto
+            json logit_bias_json = json::parse(predict->logitbias());
+            // Add to data - llama.cpp server expects it as an object (map)
+            data["logit_bias"] = logit_bias_json;
+            SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
+        } catch (const json::parse_error& e) {
+            SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
+        }
+    }
+
     data["ignore_eos"] = predict->ignoreeos();
     data["embeddings"] = predict->embeddings();

@@ -568,6 +596,28 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return Status::OK;
     }

+    // Helper function to extract logprobs from JSON response
+    static json extract_logprobs_from_json(const json& res_json) {
+        json logprobs_json = json::object();
+
+        // Check for OAI-compatible format: choices[0].logprobs
+        if (res_json.contains("choices") && res_json["choices"].is_array() &&
+            res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
+            logprobs_json = res_json["choices"][0]["logprobs"];
+        }
+        // Check for non-OAI format: completion_probabilities
+        else if (res_json.contains("completion_probabilities")) {
+            // Convert completion_probabilities to OAI format
+            logprobs_json["content"] = res_json["completion_probabilities"];
+        }
+        // Check for direct logprobs field
+        else if (res_json.contains("logprobs")) {
+            logprobs_json = res_json["logprobs"];
+        }
+
+        return logprobs_json;
+    }
+
     grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
         json data = parse_options(true, request, ctx_server);

@@ -915,6 +965,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
                 reply.set_timing_token_generation(timing_token_generation);
             }

+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(res);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
             writer->Write(reply);
         }
     } else {

@@ -934,6 +991,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
             reply.set_timing_token_generation(timing_token_generation);
         }

+        // Extract and set logprobs if present
+        json logprobs_json = extract_logprobs_from_json(first_res_json);
+        if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+            std::string logprobs_str = logprobs_json.dump();
+            reply.set_logprobs(logprobs_str);
+        }
+
        writer->Write(reply);
    }

@@ -969,6 +1033,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
                 reply.set_timing_token_generation(timing_token_generation);
             }

+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(res);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
             writer->Write(reply);
         }
     } else {

@@ -988,6 +1059,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
            reply.set_timing_token_generation(timing_token_generation);
        }

+        // Extract and set logprobs if present
+        json logprobs_json = extract_logprobs_from_json(res_json);
+        if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+            std::string logprobs_str = logprobs_json.dump();
+            reply.set_logprobs(logprobs_str);
+        }
+
        writer->Write(reply);
    }
 }

@@ -1335,28 +1413,54 @@ class BackendServiceImpl final : public backend::Backend::Service {
         if (all_results.results.size() == 1) {
             // single result
             GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
-            reply->set_message(all_results.results[0]->to_json().value("content", ""));
+            json result_json = all_results.results[0]->to_json();
+            reply->set_message(result_json.value("content", ""));

-            int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
+            int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
             reply->set_tokens(tokens_predicted);
-            int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
+            int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
             reply->set_prompt_tokens(tokens_evaluated);

-            if (all_results.results[0]->to_json().contains("timings")) {
-                double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
+            if (result_json.contains("timings")) {
+                double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
                 reply->set_timing_prompt_processing(timing_prompt_processing);
-                double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
+                double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
                 reply->set_timing_token_generation(timing_token_generation);
             }

+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(result_json);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply->set_logprobs(logprobs_str);
+            }
+
         } else {
             // multiple results (multitask)
             json arr = json::array();
+            json logprobs_arr = json::array();
+            bool has_logprobs = false;
             for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
-                arr.push_back(res->to_json().value("content", ""));
+                json res_json = res->to_json();
+                arr.push_back(res_json.value("content", ""));
+
+                // Extract logprobs for each result
+                json logprobs_json = extract_logprobs_from_json(res_json);
+                if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                    has_logprobs = true;
+                    logprobs_arr.push_back(logprobs_json);
+                } else {
+                    logprobs_arr.push_back(json::object());
+                }
             }
             reply->set_message(arr);
+
+            // Set logprobs if any result has them
+            if (has_logprobs) {
+                std::string logprobs_str = logprobs_arr.dump();
+                reply->set_logprobs(logprobs_str);
+            }
         }
     }
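For illustration only, here is a rough Go analogue of the extract_logprobs_from_json precedence above: prefer choices[0].logprobs from OAI-style responses, otherwise wrap completion_probabilities under a "content" key, otherwise fall back to a top-level logprobs field. The function name and payload below are illustrative, not project code.

package main

import (
	"encoding/json"
	"fmt"
)

func extractLogprobs(res map[string]json.RawMessage) json.RawMessage {
	// 1) OAI-compatible responses: choices[0].logprobs
	if raw, ok := res["choices"]; ok {
		var choices []map[string]json.RawMessage
		if err := json.Unmarshal(raw, &choices); err == nil && len(choices) > 0 {
			if lp, ok := choices[0]["logprobs"]; ok {
				return lp
			}
		}
	}
	// 2) Non-OAI responses: wrap completion_probabilities as {"content": ...}
	if cp, ok := res["completion_probabilities"]; ok {
		wrapped, _ := json.Marshal(map[string]json.RawMessage{"content": cp})
		return wrapped
	}
	// 3) A direct top-level logprobs field
	if lp, ok := res["logprobs"]; ok {
		return lp
	}
	return nil
}

func main() {
	res := map[string]json.RawMessage{
		"completion_probabilities": json.RawMessage(`[{"token":"Hi","logprob":-0.5}]`),
	}
	fmt.Println(string(extractLogprobs(res))) // {"content":[{"token":"Hi","logprob":-0.5}]}
}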

core/backend/llm.go
Lines changed: 36 additions & 1 deletion

@@ -2,6 +2,7 @@ package backend

 import (
 	"context"
+	"encoding/json"
 	"regexp"
 	"slices"
 	"strings"

@@ -24,6 +25,7 @@ type LLMResponse struct {
 	Response    string // should this be []byte?
 	Usage       TokenUsage
 	AudioOutput string
+	Logprobs    *schema.Logprobs // Logprobs from the backend response
 }

 type TokenUsage struct {
@@ -33,7 +35,7 @@
 	TimingTokenGeneration float64
 }

-func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
 	modelFile := c.Model

 	// Check if the modelFile exists, if it doesn't try to load it from the gallery

@@ -78,6 +80,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
 	opts.Audios = audios
 	opts.Tools = tools
 	opts.ToolChoice = toolChoice
+	if logprobs != nil {
+		opts.Logprobs = int32(*logprobs)
+	}
+	if topLogprobs != nil {
+		opts.TopLogprobs = int32(*topLogprobs)
+	}
+	if len(logitBias) > 0 {
+		// Serialize logit_bias map to JSON string for proto
+		logitBiasJSON, err := json.Marshal(logitBias)
+		if err == nil {
+			opts.LogitBias = string(logitBiasJSON)
+		}
+	}

 	tokenUsage := TokenUsage{}

@@ -109,6 +124,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
 	}

 	ss := ""
+	var logprobs *schema.Logprobs

 	var partialRune []byte
 	err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {

@@ -120,6 +136,14 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
 		tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
 		tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

+		// Parse logprobs from reply if present (collect from last chunk that has them)
+		if len(reply.Logprobs) > 0 {
+			var parsedLogprobs schema.Logprobs
+			if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
+				logprobs = &parsedLogprobs
+			}
+		}
+
 		// Process complete runes and accumulate them
 		var completeRunes []byte
 		for len(partialRune) > 0 {

@@ -145,6 +169,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
 	return LLMResponse{
 		Response: ss,
 		Usage:    tokenUsage,
+		Logprobs: logprobs,
 	}, err
 } else {
 	// TODO: Is the chicken bit the only way to get here? is that acceptable?

@@ -167,9 +192,19 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
 		response = c.TemplateConfig.ReplyPrefix + response
 	}

+	// Parse logprobs from reply if present
+	var logprobs *schema.Logprobs
+	if len(reply.Logprobs) > 0 {
+		var parsedLogprobs schema.Logprobs
+		if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
+			logprobs = &parsedLogprobs
+		}
+	}
+
 	return LLMResponse{
 		Response: response,
 		Usage:    tokenUsage,
+		Logprobs: logprobs,
 	}, err
 }
 }

core/backend/options.go
Lines changed: 3 additions & 1 deletion

@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
 		}
 	}

-	return &pb.PredictOptions{
+	pbOpts := &pb.PredictOptions{
 		Temperature: float32(*c.Temperature),
 		TopP:        float32(*c.TopP),
 		NDraft:      c.NDraft,

@@ -249,4 +249,6 @@
 		TailFreeSamplingZ: float32(*c.TFZ),
 		TypicalP:          float32(*c.TypicalP),
 	}
+	// Logprobs and TopLogprobs are set by the caller if provided
+	return pbOpts
 }

core/http/app_test.go
Lines changed: 77 additions & 0 deletions

@@ -816,6 +816,83 @@ var _ = Describe("API test", func() {
 		Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
 	})

+	It("returns logprobs in chat completions when requested", func() {
+		topLogprobsVal := 3
+		response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
+			Model:       "testmodel.ggml",
+			LogProbs:    true,
+			TopLogProbs: topLogprobsVal,
+			Messages:    []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
+		Expect(err).ToNot(HaveOccurred())
+
+		Expect(len(response.Choices)).To(Equal(1))
+		Expect(response.Choices[0].Message).ToNot(BeNil())
+		Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
+
+		// Verify logprobs are present and have correct structure
+		Expect(response.Choices[0].LogProbs).ToNot(BeNil())
+		Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
+
+		Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
+
+		foundatLeastToken := ""
+		foundAtLeastBytes := []byte{}
+		foundAtLeastTopLogprobBytes := []byte{}
+		foundatLeastTopLogprob := ""
+		// Verify logprobs content structure matches OpenAI format
+		for _, logprobContent := range response.Choices[0].LogProbs.Content {
+			// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
+			if len(logprobContent.Bytes) > 0 {
+				foundAtLeastBytes = logprobContent.Bytes
+			}
+			if len(logprobContent.Token) > 0 {
+				foundatLeastToken = logprobContent.Token
+			}
+			Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
+			Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
+
+			// If top_logprobs is requested, verify top_logprobs array respects the limit
+			if len(logprobContent.TopLogProbs) > 0 {
+				// Should respect top_logprobs limit (3 in this test)
+				Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
+				for _, topLogprob := range logprobContent.TopLogProbs {
+					if len(topLogprob.Bytes) > 0 {
+						foundAtLeastTopLogprobBytes = topLogprob.Bytes
+					}
+					if len(topLogprob.Token) > 0 {
+						foundatLeastTopLogprob = topLogprob.Token
+					}
+					Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
+				}
+			}
+		}
+
+		Expect(foundAtLeastBytes).ToNot(BeEmpty())
+		Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
+		Expect(foundatLeastToken).ToNot(BeEmpty())
+		Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
+	})
+
+	It("applies logit_bias to chat completions when requested", func() {
+		// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
+		// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
+		logitBias := map[string]int{
+			"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
+		}
+		response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
+			Model:     "testmodel.ggml",
+			Messages:  []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
+			LogitBias: logitBias,
+		})
+		Expect(err).ToNot(HaveOccurred())
+		Expect(len(response.Choices)).To(Equal(1))
+		Expect(response.Choices[0].Message).ToNot(BeNil())
+		Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
+		// If logit_bias is applied, the response should be generated successfully
+		// We can't easily verify the bias effect without knowing the actual token IDs for the model,
+		// but the fact that the request succeeds confirms the API accepts and processes logit_bias
+	})
+
 	It("returns errors", func() {
 		_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
 		Expect(err).To(HaveOccurred())
