Skip to content

Commit 79de1c2

Browse files
committed
feat(trtllm): add stop sequence support
Support per-request stop sequences.
1 parent 4e0c82f commit 79de1c2

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

backends/trtllm/src/looper.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ struct GenerationContext {
3535
tokens: Vec<u32>,
3636
start: Option<Instant>,
3737
queued: Instant,
38+
39+
/// output_buffer stores the output for detecting stop sequences
40+
output_buffer: Option<String>,
3841
}
3942

4043
#[derive(Debug, Copy, Clone)]
@@ -191,11 +194,39 @@ fn executor_status_looper(
191194
fn post_process_decoded_token(
192195
tokenizer: &Tokenizer,
193196
ctx: &mut GenerationContext,
194-
decoded_token: DecodedToken,
197+
mut decoded_token: DecodedToken,
195198
) -> InferResult<InferStreamResponse> {
196199
match tokenizer.decode(&[decoded_token.id], false) {
197200
Ok(text) => {
198201
let is_special = tokenizer.get_added_vocabulary().is_special_token(&text);
202+
203+
if let Some(buf) = ctx.output_buffer.as_mut() {
204+
if buf.len() + text.len() > buf.capacity() {
205+
let mut start = buf.len() + text.len() - buf.capacity();
206+
while start <= buf.len() && !buf.is_char_boundary(start) {
207+
start += 1;
208+
}
209+
buf.drain(..start);
210+
}
211+
buf.push_str(&text);
212+
213+
for stop_seq in &ctx.request.stopping_parameters.stop_sequences {
214+
let start = if 1 + buf.len() > text.len() + stop_seq.len() {
215+
let mut start = 1 + buf.len() - text.len() - stop_seq.len();
216+
while start > 0 && !buf.is_char_boundary(start) {
217+
start -= 1;
218+
}
219+
start
220+
} else {
221+
0
222+
};
223+
if buf[start..].contains(stop_seq) {
224+
decoded_token.is_final = true;
225+
decoded_token.finish_reason = FinishReason::StopWords;
226+
}
227+
}
228+
}
229+
199230
let token = Token {
200231
id: decoded_token.id,
201232
text,
@@ -344,12 +375,20 @@ impl Backend for TensorRtLlmBackendV2 {
344375

345376
// Send the context to the executor for scheduling
346377
let queued = Instant::now();
378+
let output_buffer = request
379+
.stopping_parameters
380+
.stop_sequences
381+
.iter()
382+
.map(|x| x.len())
383+
.max()
384+
.map(|m| String::with_capacity(m + 32)); // TODO: is this number enough?
347385
match self.0.send(GenerationContext {
348386
request,
349387
streamer,
350388
tokens: Vec::with_capacity(256),
351389
start: None,
352390
queued,
391+
output_buffer,
353392
}) {
354393
Ok(_) => Ok(UnboundedReceiverStream::new(receiver)),
355394
Err(_) => Err(GenerationError(

0 commit comments

Comments
 (0)