Skip to content

Commit f9134da

Browse files
committed
feat: support logit bias in chat request
1 parent e7329fe commit f9134da

File tree

18 files changed

+363
-10
lines changed

18 files changed

+363
-10
lines changed

backends/client/src/v3/client.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
77
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
88
use pb::generate::v3::*;
99
use std::cmp::min;
10+
use std::collections::HashMap;
1011
use std::time::Duration;
1112
use tonic::transport::{Channel, Uri};
1213
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
181182
watermark: true,
182183
grammar: String::new(),
183184
grammar_type: GrammarType::None as i32,
185+
logit_bias: HashMap::new(),
184186
}),
185187
stopping_parameters: Some(StoppingCriteriaParameters {
186188
max_new_tokens,

backends/client/src/v3/sharded_client.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{ClientError, Result};
55
use crate::v3::{Chunk, InfoResponse, Input};
66
use async_trait::async_trait;
77
use futures::future::join_all;
8+
use std::collections::HashMap;
89
use tonic::transport::Uri;
910
use tracing::instrument;
1011
use v3::client::{DecodeTimings, PrefillTimings};
@@ -244,6 +245,7 @@ impl Health for ShardedClient {
244245
watermark: false,
245246
grammar: String::new(),
246247
grammar_type: GrammarType::None as i32,
248+
logit_bias: HashMap::new(),
247249
}),
248250
stopping_parameters: Some(StoppingCriteriaParameters {
249251
max_new_tokens: 1,

backends/v3/src/client/grpc_client.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use grpc_metadata::InjectTelemetryContext;
77
use pb::generate::v3::text_generation_service_client::TextGenerationServiceClient;
88
use pb::generate::v3::*;
99
use std::cmp::min;
10+
use std::collections::HashMap;
1011
use std::time::Duration;
1112
use tonic::transport::{Channel, Uri};
1213
use tracing::instrument;
@@ -181,6 +182,7 @@ impl Client {
181182
watermark: true,
182183
grammar: String::new(),
183184
grammar_type: GrammarType::None as i32,
185+
logit_bias: HashMap::new(),
184186
}),
185187
stopping_parameters: Some(StoppingCriteriaParameters {
186188
max_new_tokens,

backends/v3/src/client/sharded_client.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::client::{
1010
use crate::client::{Chunk, InfoResponse, Input};
1111
use async_trait::async_trait;
1212
use futures::future::join_all;
13+
use std::collections::HashMap;
1314
use tonic::transport::Uri;
1415
use tracing::instrument;
1516

@@ -232,6 +233,7 @@ impl Health for ShardedClient {
232233
watermark: false,
233234
grammar: String::new(),
234235
grammar_type: GrammarType::None as i32,
236+
logit_bias: HashMap::new(),
235237
}),
236238
stopping_parameters: Some(StoppingCriteriaParameters {
237239
max_new_tokens: 1,

backends/v3/src/queue.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::client::{
55
};
66
use nohash_hasher::{BuildNoHashHasher, IntMap};
77
use std::cmp::max;
8+
use std::collections::HashMap;
89
use std::collections::VecDeque;
910
use text_generation_router::infer::InferError;
1011
use text_generation_router::infer::InferStreamResponse;
@@ -522,6 +523,14 @@ impl From<ValidParameters> for NextTokenChooserParameters {
522523
watermark: value.watermark,
523524
grammar,
524525
grammar_type: grammar_type.into(),
526+
logit_bias: value
527+
.logit_bias
528+
.map(|bias| {
529+
bias.into_iter()
530+
.map(|(token, bias)| (token.to_string(), bias as i32))
531+
.collect::<HashMap<String, i32>>()
532+
})
533+
.unwrap_or_default(),
525534
}
526535
}
527536
}

benchmark/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub async fn run(
4747
watermark,
4848
grammar: String::new(),
4949
grammar_type: GrammarType::None as i32,
50+
logit_bias: std::collections::HashMap::new(),
5051
};
5152

5253
// Initialize terminal properties

clients/python/text_generation/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from pydantic import BaseModel, field_validator, ConfigDict
3-
from typing import Optional, List, Union, Any
3+
from typing import Optional, List, Union, Any, Dict
44

55
from text_generation.errors import ValidationError
66

@@ -137,7 +137,7 @@ class ChatRequest(BaseModel):
137137
# decreasing the model's likelihood to repeat the same line verbatim.
138138
frequency_penalty: Optional[float] = None
139139
# Bias values for token selection
140-
logit_bias: Optional[List[float]] = None
140+
logit_bias: Optional[Dict[str, int]] = None
141141
# Whether to return log probabilities
142142
logprobs: Optional[bool] = None
143143
# Number of most likely tokens to return at each position

docs/openapi.json

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -995,12 +995,12 @@
995995
"nullable": true
996996
},
997997
"logit_bias": {
998-
"type": "array",
999-
"items": {
1000-
"type": "number",
1001-
"format": "float"
998+
"type": "object",
999+
"description": "Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
1000+
"additionalProperties": {
1001+
"type": "integer",
1002+
"format": "int32"
10021003
},
1003-
"description": "UNUSED\nModify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens\n(specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically,\nthe bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model,\nbut values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should\nresult in a ban or exclusive selection of the relevant token.",
10041004
"nullable": true
10051005
},
10061006
"logprobs": {
@@ -1589,6 +1589,17 @@
15891589
"default": "null",
15901590
"nullable": true
15911591
},
1592+
"logit_bias": {
1593+
"type": "object",
1594+
"description": "Modify the likelihood of specified tokens appearing in the completion.\nAccepts a hash map that maps token strings to an associated bias value.",
1595+
"default": "null",
1596+
"additionalProperties": {
1597+
"type": "integer",
1598+
"format": "int32"
1599+
},
1600+
"example": "{\"1923\": 100, \"1924\": -100}",
1601+
"nullable": true
1602+
},
15921603
"max_new_tokens": {
15931604
"type": "integer",
15941605
"format": "int32",
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
"choices": [
3+
{
4+
"finish_reason": "length",
5+
"index": 0,
6+
"logprobs": null,
7+
"message": {
8+
"content": "Hello! How can I help you today?",
9+
"name": null,
10+
"role": "assistant",
11+
"tool_calls": null
12+
},
13+
"usage": null
14+
}
15+
],
16+
"created": 1745337495,
17+
"id": "",
18+
"model": "Qwen/Qwen2-VL-2B-Instruct",
19+
"object": "chat.completion",
20+
"system_fingerprint": "3.2.3-dev0-native",
21+
"usage": {
22+
"completion_tokens": 10,
23+
"prompt_tokens": 21,
24+
"total_tokens": 31
25+
}
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
"choices": [
3+
{
4+
"finish_reason": "length",
5+
"index": 0,
6+
"logprobs": null,
7+
"message": {
8+
"content": "¡Hola! ¿Cómo puedo ayudarte?",
9+
"name": null,
10+
"role": "assistant",
11+
"tool_calls": null
12+
},
13+
"usage": null
14+
}
15+
],
16+
"created": 1745337456,
17+
"id": "",
18+
"model": "Qwen/Qwen2-VL-2B-Instruct",
19+
"object": "chat.completion",
20+
"system_fingerprint": "3.2.3-dev0-native",
21+
"usage": {
22+
"completion_tokens": 10,
23+
"prompt_tokens": 21,
24+
"total_tokens": 31
25+
}
26+
}

0 commit comments

Comments (0)