@@ -166,6 +166,34 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, const
             SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
         }
     }
+
+    // Extract logprobs and top_logprobs from proto and add to JSON data
+    // Following server.cpp pattern: logprobs maps to n_probs when provided
+    if (predict->logprobs() > 0) {
+        data["logprobs"] = predict->logprobs();
+        // Map logprobs to n_probs (following server.cpp line 369 pattern)
+        // n_probs will be set by params_from_json_cmpl if logprobs is provided
+        data["n_probs"] = predict->logprobs();
+        SRV_INF("Using logprobs: %d\n", predict->logprobs());
+    }
+    if (predict->toplogprobs() > 0) {
+        data["top_logprobs"] = predict->toplogprobs();
+        SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
+    }
+
+    // Extract logit_bias from proto and add to JSON data
+    if (!predict->logitbias().empty()) {
+        try {
+            // Parse logit_bias JSON string from proto
+            json logit_bias_json = json::parse(predict->logitbias());
+            // Add to data - llama.cpp server expects it as an object (map)
+            data["logit_bias"] = logit_bias_json;
+            SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
+        } catch (const json::parse_error& e) {
+            SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
+        }
+    }
+
     data["ignore_eos"] = predict->ignoreeos();
     data["embeddings"] = predict->embeddings();
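
For context, a minimal standalone sketch of the `logit_bias` wire format this branch accepts: a JSON object mapping token IDs to bias values, passed as a string on the proto field. The token IDs and bias values below are invented for illustration; only the parse-and-forward step mirrors the code above.

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

int main() {
    // Hypothetical proto payload: token-id -> bias. A bias of -100
    // effectively bans a token; positive values make it more likely.
    std::string logitbias = R"({"15043": -100.0, "198": 2.5})";

    json data = json::object();
    try {
        // Same parse-and-forward step as in parse_options() above
        data["logit_bias"] = json::parse(logitbias);
    } catch (const json::parse_error& e) {
        std::cerr << "Failed to parse logit_bias JSON: " << e.what() << "\n";
    }
    std::cout << data.dump(2) << std::endl;
}
```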
@@ -568,6 +596,28 @@ class BackendServiceImpl final : public backend::Backend::Service {
        return Status::OK;
    }

+    // Helper function to extract logprobs from JSON response
+    static json extract_logprobs_from_json(const json& res_json) {
+        json logprobs_json = json::object();
+
+        // Check for OAI-compatible format: choices[0].logprobs
+        if (res_json.contains("choices") && res_json["choices"].is_array() &&
+            res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
+            logprobs_json = res_json["choices"][0]["logprobs"];
+        }
+        // Check for non-OAI format: completion_probabilities
+        else if (res_json.contains("completion_probabilities")) {
+            // Convert completion_probabilities to OAI format
+            logprobs_json["content"] = res_json["completion_probabilities"];
+        }
+        // Check for direct logprobs field
+        else if (res_json.contains("logprobs")) {
+            logprobs_json = res_json["logprobs"];
+        }
+
+        return logprobs_json;
+    }
+
    grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
        json data = parse_options(true, request, ctx_server);

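The helper normalizes three response shapes into one OAI-style object. Below is a quick self-contained check of that dispatch order, using hand-built sample responses rather than real server output — a sketch assuming nlohmann::json, with the helper copied locally under a shortened name.

```cpp
#include <nlohmann/json.hpp>
#include <cassert>

using json = nlohmann::json;

// Local copy of the dispatch logic above, for illustration only.
static json extract_logprobs(const json& res_json) {
    json out = json::object();
    if (res_json.contains("choices") && res_json["choices"].is_array() &&
        res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
        out = res_json["choices"][0]["logprobs"];              // OAI format
    } else if (res_json.contains("completion_probabilities")) {
        out["content"] = res_json["completion_probabilities"]; // legacy /completion
    } else if (res_json.contains("logprobs")) {
        out = res_json["logprobs"];                            // bare field
    }
    return out;
}

int main() {
    json oai, legacy;
    oai["choices"][0]["logprobs"]["content"] = json::array();
    legacy["completion_probabilities"] = json::array();

    assert(extract_logprobs(oai).contains("content"));
    assert(extract_logprobs(legacy).contains("content"));
    assert(extract_logprobs(json::object()).empty()); // nothing to extract
    return 0;
}
```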
@@ -915,6 +965,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
                    reply.set_timing_token_generation(timing_token_generation);
                }

+                // Extract and set logprobs if present
+                json logprobs_json = extract_logprobs_from_json(res);
+                if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                    std::string logprobs_str = logprobs_json.dump();
+                    reply.set_logprobs(logprobs_str);
+                }
+
                writer->Write(reply);
            }
        } else {
@@ -934,6 +991,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
                reply.set_timing_token_generation(timing_token_generation);
            }

+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(first_res_json);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
            writer->Write(reply);
        }

@@ -969,6 +1033,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
                    reply.set_timing_token_generation(timing_token_generation);
                }

+                // Extract and set logprobs if present
+                json logprobs_json = extract_logprobs_from_json(res);
+                if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                    std::string logprobs_str = logprobs_json.dump();
+                    reply.set_logprobs(logprobs_str);
+                }
+
                writer->Write(reply);
            }
        } else {
@@ -988,6 +1059,13 @@ class BackendServiceImpl final : public backend::Backend::Service {
                reply.set_timing_token_generation(timing_token_generation);
            }

+            // Extract and set logprobs if present
+            json logprobs_json = extract_logprobs_from_json(res_json);
+            if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                std::string logprobs_str = logprobs_json.dump();
+                reply.set_logprobs(logprobs_str);
+            }
+
            writer->Write(reply);
        }
    }
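
The same extract-and-set block now appears in four streaming and non-streaming branches. If this pattern grows further, one way to fold the repetition would be a small helper next to extract_logprobs_from_json — a hypothetical consolidation, not part of this diff, assuming backend::Reply exposes set_logprobs(std::string) as used above:

```cpp
// Hypothetical helper; relies only on extract_logprobs_from_json()
// and reply.set_logprobs() exactly as the branches above use them.
static void set_logprobs_if_present(const json& res_json, backend::Reply& reply) {
    json logprobs_json = extract_logprobs_from_json(res_json);
    if (!logprobs_json.empty() && !logprobs_json.is_null()) {
        reply.set_logprobs(logprobs_json.dump());
    }
}
```

Each branch would then reduce to a single `set_logprobs_if_present(res_json, reply);` call.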
@@ -1335,28 +1413,54 @@ class BackendServiceImpl final : public backend::Backend::Service {
            if (all_results.results.size() == 1) {
                // single result
                GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr);
-               reply->set_message(all_results.results[0]->to_json().value("content", ""));
+               json result_json = all_results.results[0]->to_json();
+               reply->set_message(result_json.value("content", ""));

-               int32_t tokens_predicted = all_results.results[0]->to_json().value("tokens_predicted", 0);
+               int32_t tokens_predicted = result_json.value("tokens_predicted", 0);
                reply->set_tokens(tokens_predicted);
-               int32_t tokens_evaluated = all_results.results[0]->to_json().value("tokens_evaluated", 0);
+               int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0);
                reply->set_prompt_tokens(tokens_evaluated);

-               if (all_results.results[0]->to_json().contains("timings")) {
-                   double timing_prompt_processing = all_results.results[0]->to_json().at("timings").value("prompt_ms", 0.0);
+               if (result_json.contains("timings")) {
+                   double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0);
                    reply->set_timing_prompt_processing(timing_prompt_processing);
-                   double timing_token_generation = all_results.results[0]->to_json().at("timings").value("predicted_ms", 0.0);
+                   double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0);
                    reply->set_timing_token_generation(timing_token_generation);
                }

+               // Extract and set logprobs if present
+               json logprobs_json = extract_logprobs_from_json(result_json);
+               if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                   std::string logprobs_str = logprobs_json.dump();
+                   reply->set_logprobs(logprobs_str);
+               }
+
            } else {
                // multiple results (multitask)
                json arr = json::array();
+               json logprobs_arr = json::array();
+               bool has_logprobs = false;
                for (auto& res : all_results.results) {
                    GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
-                   arr.push_back(res->to_json().value("content", ""));
+                   json res_json = res->to_json();
+                   arr.push_back(res_json.value("content", ""));
+
+                   // Extract logprobs for each result
+                   json logprobs_json = extract_logprobs_from_json(res_json);
+                   if (!logprobs_json.empty() && !logprobs_json.is_null()) {
+                       has_logprobs = true;
+                       logprobs_arr.push_back(logprobs_json);
+                   } else {
+                       logprobs_arr.push_back(json::object());
+                   }
                }
                reply->set_message(arr);
+
+               // Set logprobs if any result has them
+               if (has_logprobs) {
+                   std::string logprobs_str = logprobs_arr.dump();
+                   reply->set_logprobs(logprobs_str);
+               }
            }
        }
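
In the multitask branch the reply thus carries one logprobs entry per result, aligned index-for-index with the messages array, with an empty object standing in where a result had none. A sketch of the serialized shape `reply->set_logprobs()` would receive — the token string and logprob value are invented for illustration:

```cpp
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::json;

int main() {
    // Mirrors the aggregation above: one entry per result.
    json logprobs_arr = json::array();

    json first;                                // a result that had logprobs
    first["content"][0]["token"] = "Hello";    // illustrative token
    first["content"][0]["logprob"] = -0.12;    // illustrative value
    logprobs_arr.push_back(first);

    logprobs_arr.push_back(json::object());    // a result without logprobs

    // Prints: [{"content":[{"logprob":-0.12,"token":"Hello"}]},{}]
    std::cout << logprobs_arr.dump() << std::endl;
}
```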