Skip to content

Commit 0858af2

Browse files
committed
fix(trtllm): fix segfault when canceling request
When a request is cancelled, the `tensorrt_llm::executor::Result` contains `outputTokenIds` with size 1, but `outputTokenIds[0]` has size 0. This causes `as_generation_step` to segfault. Check the size of `outputTokenIds` and `logProbs` before attempting to access the inner vector. The check on `finishReasons` can be skipped because it has only one dimension and the minimum beam size is 1. Because cxx has not added `Option` support yet, include two boolean flags to denote whether each value is valid. Change the log level to debug when a request is cancelled.
1 parent cc4b584 commit 0858af2

File tree

3 files changed

+49
-13
lines changed

3 files changed

+49
-13
lines changed

backends/trtllm/csrc/ffi.hpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,24 @@ namespace huggingface::tgi::backends::trtllm {
5555
const auto reqId = r.getRequestId();
5656
if (!r.hasError()) [[likely]] {
5757
const auto result = r.getResult();
58-
const auto logits = result.logProbs.value()[0];
58+
std::optional<uint32_t> token_id = std::nullopt;
59+
if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) {
60+
token_id = static_cast<uint32_t>(result.outputTokenIds[0][0]);
61+
}
62+
63+
std::optional<float> log_prob = std::nullopt;
64+
if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) {
65+
log_prob = result.logProbs.value()[0].back();
66+
}
67+
5968
return generation_step_t{
6069
reqId,
61-
static_cast<uint32_t>(result.outputTokenIds[0][0]),
62-
logits.back(),
70+
token_id.value_or(0),
71+
log_prob.value_or(0.0),
6372
result.isFinal,
6473
as_finish_reason_t(result.finishReasons[0]),
74+
token_id.has_value(),
75+
log_prob.has_value(),
6576
false,
6677
std::string()
6778
};
@@ -72,6 +83,8 @@ namespace huggingface::tgi::backends::trtllm {
7283
0.0,
7384
true,
7485
finish_reason_t::kNOT_FINISHED,
86+
false,
87+
false,
7588
true,
7689
std::move(r.getErrorMsg())
7790
};

backends/trtllm/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ mod ffi {
4444
log_prob: f32,
4545
is_final: bool,
4646
finish_reason: FinishReason,
47+
token_id_valid: bool,
48+
log_prob_valid: bool,
4749
has_error: bool,
4850
error_msg: String,
4951
}

backends/trtllm/src/looper.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
4949
type Error = InferError;
5050

5151
fn try_from(step: &'step GenerationStep) -> Result<Self, Self::Error> {
52-
if !step.has_error {
53-
Ok(Self {
54-
id: step.token_id,
55-
log_prob: step.log_prob,
56-
is_final: step.is_final,
57-
finish_reason: step.finish_reason,
58-
})
59-
} else {
60-
Err(GenerationError(step.error_msg.clone()))
52+
if step.has_error {
53+
return Err(GenerationError(step.error_msg.clone()));
6154
}
55+
56+
if !step.token_id_valid {
57+
return Err(GenerationError(
58+
"GenerationStep contains no token_id".to_string(),
59+
));
60+
}
61+
62+
if !step.log_prob_valid {
63+
return Err(GenerationError(
64+
"GenerationStep contains no log_prob".to_string(),
65+
));
66+
}
67+
68+
Ok(Self {
69+
id: step.token_id,
70+
log_prob: step.log_prob,
71+
is_final: step.is_final,
72+
finish_reason: step.finish_reason,
73+
})
6274
}
6375
}
6476

@@ -151,7 +163,16 @@ fn executor_status_looper(
151163
let _ = in_flights.remove(&step.request_id);
152164
}
153165
} else {
154-
warn!("Untracked request {}", step.request_id,);
166+
match step.finish_reason {
167+
FinishReason::Cancelled => {
168+
// The client has canceled the request, so this should not generate a
169+
// warning.
170+
debug!("Cancelled request {}", step.request_id);
171+
}
172+
_ => {
173+
warn!("Untracked request {}", step.request_id);
174+
}
175+
}
155176
}
156177
}
157178
}

0 commit comments

Comments
 (0)