From 1b9c420c14a8efb24b0014f89e41371703fe39f9 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 11 May 2025 03:27:59 +0800 Subject: [PATCH 01/13] feat(trtllm): add new finish reasons Add new finish reasons introduced in TensorRT-LLM v0.16.0. --- backends/trtllm/csrc/ffi.hpp | 4 ++++ backends/trtllm/src/lib.rs | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index 840614bbcfe..a877df5a5c0 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -42,6 +42,10 @@ namespace huggingface::tgi::backends::trtllm { return finish_reason_t::kEND_ID; case tle::FinishReason::kLENGTH: return finish_reason_t::kLENGTH; + case tle::FinishReason::kTIMED_OUT: + return finish_reason_t::kTIMED_OUT; + case tle::FinishReason::kCANCELLED: + return finish_reason_t::kCANCELLED; default: std::unreachable(); } diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 085072561f1..52e48f913f9 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -24,6 +24,14 @@ mod ffi { /// The request finished because the maximum number of tokens was reached. #[cxx_name = "kLENGTH"] MaxLength = 3u8, + + #[cxx_name = "kTIMED_OUT"] + /// The request finished because it got timed out (via the mAllotedTime parameter) + TimedOut = 4u8, + + #[cxx_name = "kCANCELLED"] + /// The request was cancelled by calling cancelRequest. + Cancelled = 5u8, } /// Struct used as shared type between rust and C++ to represent the result From 592c3c7913c7bdb989b0c89311d73fc4d2635369 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Tue, 13 May 2025 00:05:56 +0800 Subject: [PATCH 02/13] fix: fix prometheus_port CLI short arg conflict The short arg of `prometheus_port` conflicts with `port`. Remove the short arg variant. Fixes https://github.com/huggingface/text-generation-inference/issues/3205 --- backends/llamacpp/src/main.rs | 2 +- backends/trtllm/src/main.rs | 2 +- backends/v2/src/main.rs | 2 +- backends/v3/src/main.rs | 2 +- docs/source/reference/launcher.md | 2 +- launcher/src/main.rs | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/backends/llamacpp/src/main.rs b/backends/llamacpp/src/main.rs index 9ee61ce6e2e..c193c8689b2 100644 --- a/backends/llamacpp/src/main.rs +++ b/backends/llamacpp/src/main.rs @@ -119,7 +119,7 @@ struct Args { #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, /// Enable JSON output format. 
diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 543f8e6e352..81fca0e7e57 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -37,7 +37,7 @@ struct Args { hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, #[clap(long, env, required = true)] tokenizer_name: String, diff --git a/backends/v2/src/main.rs b/backends/v2/src/main.rs index 60b5d52bbe2..a0f3558c0da 100644 --- a/backends/v2/src/main.rs +++ b/backends/v2/src/main.rs @@ -36,7 +36,7 @@ struct Args { hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, diff --git a/backends/v3/src/main.rs b/backends/v3/src/main.rs index 44e63853e04..75a2069124e 100644 --- a/backends/v3/src/main.rs +++ b/backends/v3/src/main.rs @@ -36,7 +36,7 @@ struct Args { hostname: String, #[clap(default_value = "3000", long, short, env)] port: u16, - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, #[clap(default_value = "/tmp/text-generation-server-0", long, env)] master_shard_uds_path: String, diff --git a/docs/source/reference/launcher.md b/docs/source/reference/launcher.md index 5b7321b73a3..f49cbac5a2a 100644 --- a/docs/source/reference/launcher.md +++ b/docs/source/reference/launcher.md @@ -254,7 +254,7 @@ Options: ``` ## PROMETHEUS_PORT ```shell - -p, --prometheus-port + --prometheus-port The Prometheus port to listen on [env: PROMETHEUS_PORT=] diff --git a/launcher/src/main.rs b/launcher/src/main.rs index c727623ce47..f339cbb47e0 100644 --- a/launcher/src/main.rs +++ b/launcher/src/main.rs @@ -774,7 +774,7 @@ struct Args { port: u16, /// The Prometheus port to listen on. - #[clap(default_value = "9000", long, short, env)] + #[clap(default_value = "9000", long, env)] prometheus_port: u16, /// The name of the socket for gRPC communication between the webserver From 4e0c82fef2802cf494b2441ddf488b49edd3c297 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 18 May 2025 02:22:53 +0800 Subject: [PATCH 03/13] fix(trtllm): fix segfault when canceling request When a request is cancelled, the `tensorrt_llm::executor::Result` contains `outputTokenIds` with size 1, but `outputTokenIds[0]` has size 0. This causes `as_generation_step` to segfault. Check the size of `outputTokenIds` and `logProbs` before attempting to access the inner vector. The `finishReasons` can be skipped because it has only one dimension and the minimum beam size is 1. Because cxx have not added Option support yet, include two boolean flags to denote whether the value is valid. Change log level when request is cancelled to debug. 
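
For illustration, the flag-plus-value pairs are meant to be read back as optional values on the Rust side. A minimal sketch, assuming the `GenerationStep` fields added in the diff below (the helper functions themselves are not part of the patch):

    fn token_id(step: &GenerationStep) -> Option<u32> {
        // Valid only when the C++ side saw a non-empty outputTokenIds[0].
        step.token_id_valid.then_some(step.token_id)
    }

    fn log_prob(step: &GenerationStep) -> Option<f32> {
        // Valid only when logProbs was present and non-empty.
        step.log_prob_valid.then_some(step.log_prob)
    }
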
--- backends/trtllm/csrc/ffi.hpp | 19 +++++++++++++--- backends/trtllm/src/lib.rs | 2 ++ backends/trtllm/src/looper.rs | 41 ++++++++++++++++++++++++++--------- 3 files changed, 49 insertions(+), 13 deletions(-) diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index a877df5a5c0..90d1b9d17bd 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -55,13 +55,24 @@ namespace huggingface::tgi::backends::trtllm { const auto reqId = r.getRequestId(); if (!r.hasError()) [[likely]] { const auto result = r.getResult(); - const auto logits = result.logProbs.value()[0]; + std::optional token_id = std::nullopt; + if (!result.outputTokenIds.empty() && !result.outputTokenIds[0].empty()) { + token_id = static_cast(result.outputTokenIds[0][0]); + } + + std::optional log_prob = std::nullopt; + if (result.logProbs && !result.logProbs->empty() && !result.logProbs.value()[0].empty()) { + log_prob = result.logProbs.value()[0].back(); + } + return generation_step_t{ reqId, - static_cast(result.outputTokenIds[0][0]), - logits.back(), + token_id.value_or(0), + log_prob.value_or(0.0), result.isFinal, as_finish_reason_t(result.finishReasons[0]), + token_id.has_value(), + log_prob.has_value(), false, std::string() }; @@ -72,6 +83,8 @@ namespace huggingface::tgi::backends::trtllm { 0.0, true, finish_reason_t::kNOT_FINISHED, + false, + false, true, std::move(r.getErrorMsg()) }; diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 52e48f913f9..b2a9274dd74 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -44,6 +44,8 @@ mod ffi { log_prob: f32, is_final: bool, finish_reason: FinishReason, + token_id_valid: bool, + log_prob_valid: bool, has_error: bool, error_msg: String, } diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 5fed954fff7..fd0bc967da2 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -49,16 +49,28 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { type Error = InferError; fn try_from(step: &'step GenerationStep) -> Result { - if !step.has_error { - Ok(Self { - id: step.token_id, - log_prob: step.log_prob, - is_final: step.is_final, - finish_reason: step.finish_reason, - }) - } else { - Err(GenerationError(step.error_msg.clone())) + if step.has_error { + return Err(GenerationError(step.error_msg.clone())); } + + if !step.token_id_valid { + return Err(GenerationError( + "GenerationStep contains no token_id".to_string(), + )); + } + + if !step.log_prob_valid { + return Err(GenerationError( + "GenerationStep contains no log_prob".to_string(), + )); + } + + Ok(Self { + id: step.token_id, + log_prob: step.log_prob, + is_final: step.is_final, + finish_reason: step.finish_reason, + }) } } @@ -151,7 +163,16 @@ fn executor_status_looper( let _ = in_flights.remove(&step.request_id); } } else { - warn!("Untracked request {}", step.request_id,); + match step.finish_reason { + FinishReason::Cancelled => { + // The client has canceled the request, so this should not generate a + // warning. + debug!("Cancelled request {}", step.request_id); + } + _ => { + warn!("Untracked request {}", step.request_id); + } + } } } } From 79de1c2cbca400b779dd5e59403fa11cc29afd46 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 18 May 2025 02:37:19 +0800 Subject: [PATCH 04/13] feat(trtllm): add stop sequence support Support per request stop sequences. 
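
The detection below keeps a small rolling buffer of decoded text so a stop sequence is still caught when it spans several tokens. A simplified sketch of that idea, with an illustrative function name and a whole-buffer scan instead of the windowed scan used in the diff:

    /// Append newly decoded text to a bounded tail buffer and report whether any
    /// stop sequence now appears in it. `max_len` should be at least the length of
    /// the longest stop sequence plus the longest expected token text.
    fn hits_stop_sequence(
        buffer: &mut String,
        new_text: &str,
        stop_sequences: &[String],
        max_len: usize,
    ) -> bool {
        buffer.push_str(new_text);
        if buffer.len() > max_len {
            // Drop the oldest bytes, stepping forward to a char boundary.
            let mut cut = buffer.len() - max_len;
            while cut < buffer.len() && !buffer.is_char_boundary(cut) {
                cut += 1;
            }
            buffer.drain(..cut);
        }
        stop_sequences.iter().any(|s| buffer.contains(s.as_str()))
    }
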
--- backends/trtllm/src/looper.rs | 41 ++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index fd0bc967da2..6d7f30c31d4 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -35,6 +35,9 @@ struct GenerationContext { tokens: Vec, start: Option, queued: Instant, + + /// output_buffer stores the output for detecting stop sequences + output_buffer: Option, } #[derive(Debug, Copy, Clone)] @@ -191,11 +194,39 @@ fn executor_status_looper( fn post_process_decoded_token( tokenizer: &Tokenizer, ctx: &mut GenerationContext, - decoded_token: DecodedToken, + mut decoded_token: DecodedToken, ) -> InferResult { match tokenizer.decode(&[decoded_token.id], false) { Ok(text) => { let is_special = tokenizer.get_added_vocabulary().is_special_token(&text); + + if let Some(buf) = ctx.output_buffer.as_mut() { + if buf.len() + text.len() > buf.capacity() { + let mut start = buf.len() + text.len() - buf.capacity(); + while start <= buf.len() && !buf.is_char_boundary(start) { + start += 1; + } + buf.drain(..start); + } + buf.push_str(&text); + + for stop_seq in &ctx.request.stopping_parameters.stop_sequences { + let start = if 1 + buf.len() > text.len() + stop_seq.len() { + let mut start = 1 + buf.len() - text.len() - stop_seq.len(); + while start > 0 && !buf.is_char_boundary(start) { + start -= 1; + } + start + } else { + 0 + }; + if buf[start..].contains(stop_seq) { + decoded_token.is_final = true; + decoded_token.finish_reason = FinishReason::StopWords; + } + } + } + let token = Token { id: decoded_token.id, text, @@ -344,12 +375,20 @@ impl Backend for TensorRtLlmBackendV2 { // Send the context to the executor for scheduling let queued = Instant::now(); + let output_buffer = request + .stopping_parameters + .stop_sequences + .iter() + .map(|x| x.len()) + .max() + .map(|m| String::with_capacity(m + 32)); // TODO: is this number enough? match self.0.send(GenerationContext { request, streamer, tokens: Vec::with_capacity(256), start: None, queued, + output_buffer, }) { Ok(_) => Ok(UnboundedReceiverStream::new(receiver)), Err(_) => Err(GenerationError( From b157cd00aa3ef3647b34c09ff838d48ac2290e66 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 18 May 2025 02:49:35 +0800 Subject: [PATCH 05/13] feat(trtllm): catch broader exception The trycatch only uses the `what()` method, which means we can catch the broader `std::exception` instead. This is beneficial because nlohmann/json also throws exception. --- backends/trtllm/csrc/ffi.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index 90d1b9d17bd..d31743ee37b 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -1,6 +1,7 @@ #ifndef TGI_BACKEND_TRTLLM_FFI #define TGI_BACKEND_TRTLLM_FFI +#include #include #include @@ -17,7 +18,7 @@ namespace rust::behavior { template static void trycatch(Try &&func, Fail &&fail) noexcept try { func(); - } catch (tensorrt_llm::common::TllmException &e) { + } catch (const std::exception &e) { fail(e.what()); } } From 8c4a14e3c6630f6b9d9bf169b2ed6a297cf0a344 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 18 May 2025 03:25:13 +0800 Subject: [PATCH 06/13] feat(trtllm): check existence of config files When the required config files are not present, nlohmann/json throws parsing error, which does not help much for identifying what was wrong. 
Check the existence of these files early and return specific error messages. --- backends/trtllm/src/errors.rs | 4 ++++ backends/trtllm/src/looper.rs | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/backends/trtllm/src/errors.rs b/backends/trtllm/src/errors.rs index 812fd6e30d8..3e6bd7430b7 100644 --- a/backends/trtllm/src/errors.rs +++ b/backends/trtllm/src/errors.rs @@ -19,4 +19,8 @@ pub enum TensorRtLlmBackendError { WebServer(#[from] server::WebServerError), #[error("Tokio runtime failed to start: {0}")] Tokio(#[from] std::io::Error), + #[error("config.json doesn't exist in engine folder {0}")] + ConfigNotFound(PathBuf), + #[error("generation_config.json doesn't exist in engine folder {0}")] + GenerationConfigNotFound(PathBuf), } diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 6d7f30c31d4..17030b211d6 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -3,7 +3,7 @@ use cxx::UniquePtr; use hashbrown::HashMap; use std::hint; use std::ops::Deref; -use std::path::Path; +use std::path::{Path, PathBuf}; use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::TryAcquireError; @@ -283,6 +283,26 @@ fn ensure_paths_exist, PP: AsRef>( return Err(err); } + let mut config_path = PathBuf::from(engine_folder); + config_path.push("config.json"); + + if !config_path.exists() { + let err = TensorRtLlmBackendError::ConfigNotFound(engine_folder.to_path_buf()); + + error!("Path validation failed: {}", err,); + return Err(err); + } + + let mut generation_config_path = PathBuf::from(engine_folder); + generation_config_path.push("generation_config.json"); + + if !generation_config_path.exists() { + let err = TensorRtLlmBackendError::GenerationConfigNotFound(engine_folder.to_path_buf()); + + error!("Path validation failed: {}", err,); + return Err(err); + } + // Ensure executor worker binary exists if !executor_worker_path.exists() { let err = TensorRtLlmBackendError::ExecutorWorkerNotFound(engine_folder.to_path_buf()); From c170c6621fa48e4ecef99808c9cfbd3340fae903 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 18 May 2025 18:22:02 +0800 Subject: [PATCH 07/13] fix(trtllm): fix do_sample being ignored Currently, the do_sample option is ignored and the executor will always sample. Set top_k to 1 if do_sample is false. --- backends/trtllm/src/looper.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 17030b211d6..a4b70ea99d7 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -98,12 +98,17 @@ fn executor_status_looper( let generation_params = &request.parameters; let stopping_params = &request.stopping_parameters; let input_ids = request.input_ids.as_deref(); + let top_k = if generation_params.do_sample { + generation_params.top_k + } else { + 1 + }; // Submit to the TensorRT-LLM executor for scheduling match backend.pin_mut().submit( &input_ids.unwrap(), // This is checked beforehand in validate() stopping_params.max_new_tokens, - generation_params.top_k, + top_k, generation_params.top_p, generation_params.temperature, generation_params.repetition_penalty, From ee82a0850791afac1a556f63f06e7d1d876381d2 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 25 May 2025 17:40:45 +0800 Subject: [PATCH 08/13] feat(trtllm): get more accurate start time Get a more accurate inference start time from the trtllm response. 
Because `Instant` does not expose absolute value, create reference points on both sides and return duration relative to the reference point instead. --- backends/trtllm/csrc/backend.cpp | 9 +++++++- backends/trtllm/csrc/ffi.hpp | 36 ++++++++++++++++++++++++++------ backends/trtllm/src/lib.rs | 4 ++++ backends/trtllm/src/looper.rs | 35 ++++++++++++++++++++++++------- 4 files changed, 70 insertions(+), 14 deletions(-) diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp index 2151466be6e..4a131e31749 100644 --- a/backends/trtllm/csrc/backend.cpp +++ b/backends/trtllm/csrc/backend.cpp @@ -59,7 +59,14 @@ namespace huggingface::tgi::backends::trtllm { static_cast(g_params.max_new_tokens), true, (tle::SamplingConfig) s_params, - tle::OutputConfig{ /* returnLogProbs= */ true}, + tle::OutputConfig{ + /* returnLogProbs= */ true, + false, + false, + false, + false, + /* returnPerfMetrics=*/ true, + }, std::nullopt, std::nullopt, std::nullopt, diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index d31743ee37b..624259cfe02 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -1,6 +1,7 @@ #ifndef TGI_BACKEND_TRTLLM_FFI #define TGI_BACKEND_TRTLLM_FFI +#include #include #include #include @@ -52,7 +53,7 @@ namespace huggingface::tgi::backends::trtllm { } } - static auto as_generation_step = [](const tle::Response &r) { + static auto as_generation_step = [](const tle::Response &r, const std::chrono::time_point created) { const auto reqId = r.getRequestId(); if (!r.hasError()) [[likely]] { const auto result = r.getResult(); @@ -66,14 +67,23 @@ namespace huggingface::tgi::backends::trtllm { log_prob = result.logProbs.value()[0].back(); } + std::optional first_scheduled_time_ns = std::nullopt; + if (result.requestPerfMetrics) { + const auto &t = result.requestPerfMetrics->timingMetrics; + const auto ns = std::chrono::duration_cast(t.firstScheduledTime - created).count(); + first_scheduled_time_ns = static_cast(ns); + } + return generation_step_t{ reqId, token_id.value_or(0), log_prob.value_or(0.0), + first_scheduled_time_ns.value_or(0), result.isFinal, as_finish_reason_t(result.finishReasons[0]), token_id.has_value(), log_prob.has_value(), + first_scheduled_time_ns.has_value(), false, std::string() }; @@ -82,10 +92,12 @@ namespace huggingface::tgi::backends::trtllm { reqId, 0, 0.0, + 0, true, finish_reason_t::kNOT_FINISHED, false, false, + false, true, std::move(r.getErrorMsg()) }; @@ -97,9 +109,16 @@ namespace huggingface::tgi::backends::trtllm { private: backend_t inner_; + // m_created_time is a reference point to convert time from c++ time_point + // to rust Instant. 
+ std::chrono::time_point m_created_time; + + public: - tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path) - : inner_(engine_folder, executor_worker_path) {} + tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point& created_time) + : inner_(engine_folder, executor_worker_path), + m_created_time {created_time} + {} size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); } @@ -139,13 +158,16 @@ namespace huggingface::tgi::backends::trtllm { SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); + auto f = [this](const tle::Response &r){ + return as_generation_step(r, m_created_time); + }; // Transform tle::Response to generation_step_t #ifdef __cpp_lib_ranges_to_container - auto steps = responses | std::views::transform(as_generation_step) | std::ranges::to(); + auto steps = responses | std::views::transform(f) | std::ranges::to(); #else auto steps = std::vector(); steps.reserve(responses.size()); - std::transform(responses.begin(), responses.end(), std::back_inserter(steps), as_generation_step); + std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f); #endif return std::make_unique>(steps); @@ -197,12 +219,14 @@ namespace huggingface::tgi::backends::trtllm { std::unique_ptr create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { + const auto created_time = std::chrono::steady_clock::now(); std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend); return std::make_unique( std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format), std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), - std::filesystem::path::format::auto_format) + std::filesystem::path::format::auto_format), + created_time ); } } diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index b2a9274dd74..3a245151b99 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -42,10 +42,14 @@ mod ffi { request_id: u64, token_id: u32, log_prob: f32, + + /// The time of first schedule since the creation of the backend + first_scheduled_time_ns: i64, is_final: bool, finish_reason: FinishReason, token_id_valid: bool, log_prob_valid: bool, + first_scheduled_time_ns_valid: bool, has_error: bool, error_msg: String, } diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index a4b70ea99d7..34fcf34ef68 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -8,7 +8,7 @@ use tokenizers::Tokenizer; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::TryAcquireError; use tokio::task::spawn_blocking; -use tokio::time::Instant; +use tokio::time::{Duration, Instant}; use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, warn}; @@ -82,6 +82,7 @@ fn executor_status_looper( tokenizer: Tokenizer, mut backend: UniquePtr, mut backlog: UnboundedReceiver, + created_time: Instant, ) { // Track the tuple (request_id, stream) for each request let mut in_flights = @@ -144,12 +145,22 @@ fn executor_status_looper( for step in responses.deref() { if let Some(ctx) = in_flights.get_mut(&step.request_id) { // Update the starting timestamp if not set - // This value might not be the actual real starting time of 
the request - // on the executor side - Need to expose more info from the executor to - // retrieve this value - // TODO : Expose actual real starting time for a request on FFI layer if ctx.start.is_none() { - ctx.start = Some(Instant::now()); + if step.first_scheduled_time_ns_valid { + if step.first_scheduled_time_ns >= 0 { + ctx.start = created_time.checked_add(Duration::from_nanos( + step.first_scheduled_time_ns as u64, + )); + } else { + ctx.start = created_time.checked_sub(Duration::from_nanos( + -step.first_scheduled_time_ns as u64, + )); + } + } + + if ctx.start.is_none() { + ctx.start = Some(Instant::now()); + } } // Try to map the generation step to a DecodedToken @@ -348,13 +359,23 @@ impl TensorRtLlmBackendV2 { // Allocate the IPC layer to communicate with the backend let (executor_sender, executor_receiver) = unbounded_channel(); + // This is a reference point to convert time from c++ time_point + // to rust Instant. + let created_time = Instant::now(); + // Create the FFI backend let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path) .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; // Executor looper is responsible for scheduling and pulling requests state at regular interval spawn_blocking(move || { - executor_status_looper(max_inflight_requests, tokenizer, backend, executor_receiver) + executor_status_looper( + max_inflight_requests, + tokenizer, + backend, + executor_receiver, + created_time, + ) }); Ok(TensorRtLlmBackendV2(executor_sender)) From 23b78029fe5e374788c6de4f125085bbd9beec86 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 25 May 2025 22:07:54 +0800 Subject: [PATCH 09/13] perf(trtllm): reduce futile loop iterations The executor_status_looper runs a spin loop, even if there are no active requests. This makes the service constantly wasting a CPU core. Make the loop block on receiving requests if there are no running ones to reduce CPU usage when idle. --- backends/trtllm/src/looper.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 34fcf34ef68..267f9fa9612 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -90,7 +90,12 @@ fn executor_status_looper( 'scheduler: loop { // Is there any request pending to be scheduled? - let awaiting_requests = backlog.len(); + let mut awaiting_requests = backlog.len(); + if awaiting_requests == 0 && in_flights.is_empty() { + // Wait for 1 request if we are not waiting for any response, + // so that the loop blocks at receive from backlog. + awaiting_requests += 1; + } for _ in 0..awaiting_requests { // Retrieve all the requests if let Some(ctx) = backlog.blocking_recv() { From 161f62e00a1c2b50b8bc7d3ddc2d122db9417362 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Fri, 30 May 2025 17:20:30 +0800 Subject: [PATCH 10/13] refactor: add interior mutability to tensorrt_llm_backend_t Make `tensorrt_llm_backend_t` interior mutable by marking the `inner_` struct as a `mutable` field, so we can make the methods `const`. This makes the pointer accessible from multiple threads at the Rust side without wrapping a Mutex. The underlying tensorrt_llm::executor::Executor already contains a mutex. 
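
With every FFI method taking `&self`, the Rust side can hold the backend behind a plain `Arc` and call it from more than one thread, relying on the executor's internal locking. A minimal sketch of that usage pattern, using a stand-in type instead of the cxx-generated `TensorRtLlmBackendImpl` (which additionally relies on the unsafe `Send`/`Sync` assertions in looper.rs):

    use std::sync::Arc;
    use std::thread;

    // Stand-in for the FFI backend: all methods take &self.
    struct Backend;

    impl Backend {
        fn submit(&self) { /* enqueue a request */ }
        fn pull_tokens(&self) { /* drain ready responses */ }
    }

    fn main() {
        let backend = Arc::new(Backend);
        let submitter = Arc::clone(&backend);
        let handle = thread::spawn(move || submitter.submit());
        backend.pull_tokens();
        handle.join().unwrap();
    }
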
--- backends/trtllm/csrc/ffi.hpp | 8 ++++---- backends/trtllm/src/lib.rs | 6 +++--- backends/trtllm/src/looper.rs | 9 ++++----- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index 624259cfe02..99513011c89 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -107,7 +107,7 @@ namespace huggingface::tgi::backends::trtllm { class tensorrt_llm_backend_t { private: - backend_t inner_; + mutable backend_t inner_; // m_created_time is a reference point to convert time from c++ time_point // to rust Instant. @@ -131,7 +131,7 @@ namespace huggingface::tgi::backends::trtllm { float_t repetition_penalty, float_t frequency_penalty, uint64_t seed - ) { + ) const { // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE) SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor")); @@ -152,7 +152,7 @@ namespace huggingface::tgi::backends::trtllm { } } - std::unique_ptr> pull_tokens() noexcept { + std::unique_ptr> pull_tokens() const noexcept { if (num_tokens_ready() > 0) [[likely]] { const auto responses = inner_.pull_tokens(); @@ -176,7 +176,7 @@ namespace huggingface::tgi::backends::trtllm { } } - void cancel(request_id_t request_id) noexcept { + void cancel(request_id_t request_id) const noexcept { SPDLOG_DEBUG("[FFI] cancelling request {:d}", request_id); inner_.cancel(request_id); } diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 3a245151b99..511893a6cc4 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -83,7 +83,7 @@ mod ffi { fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize; fn submit( - self: Pin<&mut TensorRtLlmBackendImpl>, + self: &TensorRtLlmBackendImpl, tokens: &[u32], max_new_tokens: u32, top_k: u32, @@ -95,10 +95,10 @@ mod ffi { ) -> Result; fn pull_tokens( - self: Pin<&mut TensorRtLlmBackendImpl>, + self: &TensorRtLlmBackendImpl, ) -> Result>>; - fn cancel(self: Pin<&mut TensorRtLlmBackendImpl>, request_id: u64); + fn cancel(self: &TensorRtLlmBackendImpl, request_id: u64); } } diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 267f9fa9612..e2ef0ee8223 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -80,7 +80,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { fn executor_status_looper( max_inflight_requests: usize, tokenizer: Tokenizer, - mut backend: UniquePtr, + backend: UniquePtr, mut backlog: UnboundedReceiver, created_time: Instant, ) { @@ -111,7 +111,7 @@ fn executor_status_looper( }; // Submit to the TensorRT-LLM executor for scheduling - match backend.pin_mut().submit( + match backend.submit( &input_ids.unwrap(), // This is checked beforehand in validate() stopping_params.max_new_tokens, top_k, @@ -143,8 +143,7 @@ fn executor_status_looper( } if backend.num_tokens_ready() > 0 { - let mut backend = backend.pin_mut(); - match backend.as_mut().pull_tokens() { + match backend.pull_tokens() { Ok(responses) => { // Iterate through all the decoded token for step in responses.deref() { @@ -183,7 +182,7 @@ fn executor_status_looper( "Client dropped - removing request {} from tracked requests", step.request_id ); - backend.as_mut().cancel(step.request_id); + backend.cancel(step.request_id); let _ = in_flights.remove(&step.request_id); } } else { From e2b0063c1e9ac18128e780e9664923ea572959b2 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sun, 1 Jun 2025 04:01:35 +0800 Subject: [PATCH 
11/13] feat(trtllm): separate request and response loop The executor_status_looper spend CPU time polling at the number of tokens. Because the function is protected by mutex inside, this also interferes with the Executor. Because now the TensorRtLlmBackendImpl is interior mutable, we can mark it as `Send` and share it in multiple threads. Therefore, the loop can be split into request and response parts, and we can await for tokens instead of constantly polling. --- backends/trtllm/csrc/backend.cpp | 4 - backends/trtllm/csrc/backend.hpp | 7 - backends/trtllm/csrc/ffi.hpp | 29 ++-- backends/trtllm/src/lib.rs | 2 - backends/trtllm/src/looper.rs | 265 +++++++++++++++++-------------- 5 files changed, 159 insertions(+), 148 deletions(-) diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp index 4a131e31749..0ff68716d50 100644 --- a/backends/trtllm/csrc/backend.cpp +++ b/backends/trtllm/csrc/backend.cpp @@ -46,10 +46,6 @@ namespace huggingface::tgi::backends::trtllm { backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {} - size_t backend_t::num_tokens_ready() const noexcept { - return executor_.getNumResponsesReady(); - } - std::expected backend_t::submit(std::span token_ids, const generation_params_t g_params, const sampling_params_t s_params) noexcept { diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp index 40b44a842b3..9f7067da64e 100644 --- a/backends/trtllm/csrc/backend.hpp +++ b/backends/trtllm/csrc/backend.hpp @@ -175,13 +175,6 @@ namespace huggingface::tgi::backends::trtllm { submit(std::span token_ids, generation_params_t generation_params, sampling_params_t sampling_params) noexcept; - /** - * Query the number of tokens available across all in-flight generations - * @return - */ - [[nodiscard("Pulling out the number of tokens")]] - size_t num_tokens_ready() const noexcept; - /** * Pull out newly generated tokens from the executor * @return diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index 99513011c89..c4b27400acf 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -120,8 +120,6 @@ namespace huggingface::tgi::backends::trtllm { m_created_time {created_time} {} - size_t num_tokens_ready() const noexcept { return inner_.num_tokens_ready(); } - request_id_t submit( rust::Slice tokens, uint32_t max_new_tokens, @@ -153,27 +151,22 @@ namespace huggingface::tgi::backends::trtllm { } std::unique_ptr> pull_tokens() const noexcept { - if (num_tokens_ready() > 0) [[likely]] { - const auto responses = inner_.pull_tokens(); + const auto responses = inner_.pull_tokens(); - SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); + SPDLOG_TRACE("[FFI] Successfully pulled out {:d} responses from executor", responses.size()); - auto f = [this](const tle::Response &r){ - return as_generation_step(r, m_created_time); - }; - // Transform tle::Response to generation_step_t + auto f = [this](const tle::Response &r){ + return as_generation_step(r, m_created_time); + }; + auto steps = std::make_unique>(); + // Transform tle::Response to generation_step_t #ifdef __cpp_lib_ranges_to_container - auto steps = responses | std::views::transform(f) | std::ranges::to(); + *steps = responses | std::views::transform(f) | std::ranges::to(); #else - auto steps = std::vector(); - steps.reserve(responses.size()); - 
std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f); + steps->reserve(responses.size()); + std::transform(responses.begin(), responses.end(), std::back_inserter(steps), f); #endif - return std::make_unique>(steps); - - } else { - return std::make_unique>(); - } + return steps; } void cancel(request_id_t request_id) const noexcept { diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 511893a6cc4..2127d13e3f4 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -80,8 +80,6 @@ mod ffi { executor_worker: &str, ) -> Result>; - fn num_tokens_ready(self: &TensorRtLlmBackendImpl) -> usize; - fn submit( self: &TensorRtLlmBackendImpl, tokens: &[u32], diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index e2ef0ee8223..43bacec79fc 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -1,10 +1,11 @@ use async_trait::async_trait; use cxx::UniquePtr; use hashbrown::HashMap; -use std::hint; use std::ops::Deref; use std::path::{Path, PathBuf}; +use std::sync::Arc; use tokenizers::Tokenizer; +use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tokio::sync::TryAcquireError; use tokio::task::spawn_blocking; @@ -77,137 +78,158 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken { } } -fn executor_status_looper( +struct InFlightRequest { + request_id: u64, + ctx: GenerationContext, +} + +/// request_looper reads from the backlog, sends the request to backend, +/// and then transfer the request context to the response_looper via in_flights. +fn request_looper( + backend: Arc>, + mut backlog: UnboundedReceiver, + in_flights: UnboundedSender, +) { + loop { + let Some(ctx) = backlog.blocking_recv() else { + break; + }; + // Submit all the request to the executor and move the context to the in-flight tracker + let request = &ctx.request; + let generation_params = &request.parameters; + let stopping_params = &request.stopping_parameters; + let input_ids = request.input_ids.as_deref(); + let top_k = if generation_params.do_sample { + generation_params.top_k + } else { + 1 + }; + + // Submit to the TensorRT-LLM executor for scheduling + match backend.submit( + &input_ids.unwrap(), // This is checked beforehand in validate() + stopping_params.max_new_tokens, + top_k, + generation_params.top_p, + generation_params.temperature, + generation_params.repetition_penalty, + generation_params.frequency_penalty, + generation_params.seed, + ) { + Ok(request_id) => { + // Insert the context linked to the generated request id in the tracker + debug!("[in-flight] Added {}", request_id); + if let Err(err) = in_flights.send(InFlightRequest { request_id, ctx }) { + error!("[in-flight] Send failed {}", err); + return; + } + } + Err(e) => { + // Return to the caller + let what = e.to_string(); + error!(error = what.as_str(), "Failed to schedule request"); + + let err = Err(InferError::Overloaded(TryAcquireError::NoPermits)); + if let Err(_) = ctx.streamer.send(err) { + error!("Failed to send back error to the client"); + } + } + }; + } +} + +/// response_looper awaits requests from in_flights if there are no active ones +/// or awaits for tokens from backend. The tokens are processed and sent back. 
+fn response_looper( max_inflight_requests: usize, tokenizer: Tokenizer, - backend: UniquePtr, - mut backlog: UnboundedReceiver, created_time: Instant, + backend: Arc>, + mut in_flight_recv: UnboundedReceiver, ) { - // Track the tuple (request_id, stream) for each request + // // Track the tuple (request_id, stream) for each request let mut in_flights = HashMap::::with_capacity(max_inflight_requests * 2); - - 'scheduler: loop { - // Is there any request pending to be scheduled? - let mut awaiting_requests = backlog.len(); - if awaiting_requests == 0 && in_flights.is_empty() { - // Wait for 1 request if we are not waiting for any response, - // so that the loop blocks at receive from backlog. - awaiting_requests += 1; - } - for _ in 0..awaiting_requests { - // Retrieve all the requests - if let Some(ctx) = backlog.blocking_recv() { - // Submit all the request to the executor and move the context to the in-flight tracker - let request = &ctx.request; - let generation_params = &request.parameters; - let stopping_params = &request.stopping_parameters; - let input_ids = request.input_ids.as_deref(); - let top_k = if generation_params.do_sample { - generation_params.top_k - } else { - 1 - }; - - // Submit to the TensorRT-LLM executor for scheduling - match backend.submit( - &input_ids.unwrap(), // This is checked beforehand in validate() - stopping_params.max_new_tokens, - top_k, - generation_params.top_p, - generation_params.temperature, - generation_params.repetition_penalty, - generation_params.frequency_penalty, - generation_params.seed, - ) { - Ok(request_id) => { - // Insert the context linked to the generated request id in the tracker - debug!("[in-flight] Added {}", request_id); - in_flights.insert(request_id, ctx); - } - Err(e) => { - // Return to the caller - let what = e.to_string(); - error!(error = what.as_str(), "Failed to schedule request"); - - let err = Err(InferError::Overloaded(TryAcquireError::NoPermits)); - if let Err(_) = ctx.streamer.send(err) { - error!("Failed to send back error to the client"); - } - } - }; - } else { - break 'scheduler; - } + loop { + if in_flights.is_empty() { + // If there are no active requests, block on Rust channel instead of C++ side. + let Some(req) = in_flight_recv.blocking_recv() else { + return; + }; + in_flights.insert(req.request_id, req.ctx); } + match backend.pull_tokens() { + Ok(responses) => { + // Fetch all pending requests, in case we are receiving tokens from them. 
+ loop { + match in_flight_recv.try_recv() { + Ok(req) => in_flights.insert(req.request_id, req.ctx), + Err(err) => match err { + TryRecvError::Empty => break, + TryRecvError::Disconnected => return, + }, + }; + } - if backend.num_tokens_ready() > 0 { - match backend.pull_tokens() { - Ok(responses) => { - // Iterate through all the decoded token - for step in responses.deref() { - if let Some(ctx) = in_flights.get_mut(&step.request_id) { - // Update the starting timestamp if not set - if ctx.start.is_none() { - if step.first_scheduled_time_ns_valid { - if step.first_scheduled_time_ns >= 0 { - ctx.start = created_time.checked_add(Duration::from_nanos( - step.first_scheduled_time_ns as u64, - )); - } else { - ctx.start = created_time.checked_sub(Duration::from_nanos( - -step.first_scheduled_time_ns as u64, - )); - } + // Iterate through all the decoded token + for step in responses.deref() { + if let Some(ctx) = in_flights.get_mut(&step.request_id) { + // Update the starting timestamp if not set + if ctx.start.is_none() { + if step.first_scheduled_time_ns_valid { + if step.first_scheduled_time_ns >= 0 { + ctx.start = created_time.checked_add(Duration::from_nanos( + step.first_scheduled_time_ns as u64, + )); + } else { + ctx.start = created_time.checked_sub(Duration::from_nanos( + -step.first_scheduled_time_ns as u64, + )); } + } - if ctx.start.is_none() { - ctx.start = Some(Instant::now()); - } + if ctx.start.is_none() { + ctx.start = Some(Instant::now()); } + } - // Try to map the generation step to a DecodedToken - let response = match DecodedToken::try_from(step) { - Ok(decoded_token) => { - post_process_decoded_token(&tokenizer, ctx, decoded_token) - } - Err(err) => Err(err), - }; - - // Attempt to send back the response to the client - if let Err(_) = ctx.streamer.send(response) { - // Client has dropped, remove from tracked requests - debug!( - "Client dropped - removing request {} from tracked requests", - step.request_id - ); - backend.cancel(step.request_id); - let _ = in_flights.remove(&step.request_id); + // Try to map the generation step to a DecodedToken + let response = match DecodedToken::try_from(step) { + Ok(decoded_token) => { + post_process_decoded_token(&tokenizer, ctx, decoded_token) } - } else { - match step.finish_reason { - FinishReason::Cancelled => { - // The client has canceled the request, so this should not generate a - // warning. - debug!("Cancelled request {}", step.request_id); - } - _ => { - warn!("Untracked request {}", step.request_id); - } + Err(err) => Err(err), + }; + + // Attempt to send back the response to the client + if let Err(_) = ctx.streamer.send(response) { + // Client has dropped, remove from tracked requests + debug!( + "Client dropped - removing request {} from tracked requests", + step.request_id + ); + backend.cancel(step.request_id); + let _ = in_flights.remove(&step.request_id); + } + } else { + match step.finish_reason { + FinishReason::Cancelled => { + // The client has canceled the request, so this should not generate a + // warning. 
+ debug!("Cancelled request {}", step.request_id); + } + _ => { + warn!("Untracked request {}", step.request_id); } } } } - Err(ref err) => { - error!("Failed to get responses from the executor: {}.", err.what()); - break 'scheduler; - } + } + Err(ref err) => { + error!("Failed to get responses from the executor: {}.", err.what()); + break; } } - - // Hint the CPU we are spin-locking - hint::spin_loop(); } } @@ -347,6 +369,7 @@ fn ensure_paths_exist, PP: AsRef>( } unsafe impl Send for TensorRtLlmBackendImpl {} +unsafe impl Sync for TensorRtLlmBackendImpl {} pub struct TensorRtLlmBackendV2(UnboundedSender); @@ -363,6 +386,8 @@ impl TensorRtLlmBackendV2 { // Allocate the IPC layer to communicate with the backend let (executor_sender, executor_receiver) = unbounded_channel(); + let (in_flight_sender, in_flight_receiver) = unbounded_channel(); + // This is a reference point to convert time from c++ time_point // to rust Instant. let created_time = Instant::now(); @@ -371,14 +396,20 @@ impl TensorRtLlmBackendV2 { let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path) .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; - // Executor looper is responsible for scheduling and pulling requests state at regular interval + let backend = Arc::new(backend); + let backend_response = backend.clone(); + + // Request looper is responsible for scheduling requests + spawn_blocking(move || request_looper(backend, executor_receiver, in_flight_sender)); + + // Response looper is responsible for awaiting tokens and send them back spawn_blocking(move || { - executor_status_looper( + response_looper( max_inflight_requests, tokenizer, - backend, - executor_receiver, created_time, + backend_response, + in_flight_receiver, ) }); From 34307a4282030ea8dc6c1ab5d7305f1e251f9a59 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Sat, 11 Oct 2025 03:58:14 +0800 Subject: [PATCH 12/13] fix(trtllm): handle single eos_token_id in generation_config The type of `eos_token_id` in `transformers.GenerationConfig` is `Union[int, list[int]]` (as of transformers 4.57.0). The original code only parses this field when the value is an array, so the stop_words is not populated for some models. Add code to handle the `int` case as well. 
--- backends/trtllm/csrc/backend.hpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp index 9f7067da64e..82443c4d149 100644 --- a/backends/trtllm/csrc/backend.hpp +++ b/backends/trtllm/csrc/backend.hpp @@ -69,14 +69,23 @@ namespace huggingface::tgi::backends::trtllm { constexpr explicit generation_config_t(const json &config) : top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) { - if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) { + if (!config.contains("/eos_token_id"_json_pointer)) { + return; + } + if (config["/eos_token_id"_json_pointer].is_array()) { + SPDLOG_DEBUG("generation config eos_token_id is array"); const auto &eos_token_id = config["/eos_token_id"_json_pointer]; std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) { stop_words.emplace_back(1, token_id.template get()); }); + } - SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size()); + if (config["/eos_token_id"_json_pointer].is_number()) { + SPDLOG_DEBUG("generation config eos_token_id is number"); + stop_words.emplace_back(1, config["/eos_token_id"_json_pointer].get()); } + + SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size()); } }; From bf040884b2756535b8c02d3453d1c4dd0b858e33 Mon Sep 17 00:00:00 2001 From: Tzu-Yu Lee Date: Mon, 13 Oct 2025 00:09:00 +0800 Subject: [PATCH 13/13] feat(trtllm): support guided decoding TGI already accepts grammar for guided decoding through its HTTP API, however, this feature has been disabled for the trtllm backend. To enable this feature: - Replace the hard-coded disable of the grammar support with the `disable_grammar_support` arg present in the v3 backend. - Pass tokenizer information when constructing the trtllm Executor and enable guided decoding by default. - Pass the validated grammar type and value from requests to the Executor. 
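
For context, this is the kind of request the HTTP API already validates and that the trtllm backend can now serve. The sketch builds the body with serde_json and assumes the grammar parameter shape from TGI's guidance documentation (a JSON schema under type "json", or a pattern under type "regex"):

    use serde_json::json;

    fn main() {
        // Example body for POST /generate with a JSON-schema grammar.
        let body = json!({
            "inputs": "Name and age of the person: Alice is 32 years old.",
            "parameters": {
                "max_new_tokens": 64,
                "grammar": {
                    "type": "json",
                    "value": {
                        "type": "object",
                        "properties": {
                            "name": { "type": "string" },
                            "age": { "type": "integer" }
                        },
                        "required": ["name", "age"]
                    }
                }
            }
        });
        println!("{body}");
    }
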
--- backends/trtllm/csrc/backend.cpp | 25 +++++++++++++++---- backends/trtllm/csrc/backend.hpp | 25 ++++++++++++------- backends/trtllm/csrc/ffi.hpp | 38 ++++++++++++++++++++++++----- backends/trtllm/src/lib.rs | 17 +++++++++++++ backends/trtllm/src/looper.rs | 42 +++++++++++++++++++++++++------- backends/trtllm/src/main.rs | 5 +++- 6 files changed, 122 insertions(+), 30 deletions(-) diff --git a/backends/trtllm/csrc/backend.cpp b/backends/trtllm/csrc/backend.cpp index 0ff68716d50..de975ad1d6e 100644 --- a/backends/trtllm/csrc/backend.cpp +++ b/backends/trtllm/csrc/backend.cpp @@ -26,7 +26,7 @@ namespace huggingface::tgi::backends::trtllm { } - tle::ExecutorConfig backend_workspace_t::executor_config() const { + tle::ExecutorConfig backend_workspace_t::executor_config(const std::vector& encoded_vocab, std::string_view tokenizer_str) const { // Retrieve the compute capabilities to enable some options at runtime const auto compute_capabilities = hardware::cuda::compute_capabilities_t(); @@ -40,17 +40,24 @@ namespace huggingface::tgi::backends::trtllm { executor_config.setKvCacheConfig(tle::KvCacheConfig(true)); executor_config.setEnableChunkedContext(compute_capabilities.is_at_least_ampere()); executor_config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)); + executor_config.setGuidedDecodingConfig(tle::GuidedDecodingConfig( + tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, + encoded_vocab, + std::string(tokenizer_str), + generation_config().eos_token_ids + )); return executor_config; } - backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path) - : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {} + backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path, const std::vector &encoded_vocab, std::string_view tokenizer_str) + : workspace(engines_folder, executor_worker_path), + executor_(executor_factory_initializer(workspace, encoded_vocab, tokenizer_str)) {} std::expected backend_t::submit(std::span token_ids, const generation_params_t g_params, const sampling_params_t s_params) noexcept { SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params); - return executor_.enqueueRequest(tle::Request{ + tle::Request req { {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens static_cast(g_params.max_new_tokens), true, @@ -68,7 +75,15 @@ namespace huggingface::tgi::backends::trtllm { std::nullopt, std::nullopt, workspace.generation_config().stop_words - }); + }; + + if (g_params.guide_type.has_value()) { + req.setGuidedDecodingParams(tle::GuidedDecodingParams( + g_params.guide_type.value(), + g_params.guide + )); + } + return executor_.enqueueRequest(req); } std::vector backend_t::pull_tokens() noexcept { diff --git a/backends/trtllm/csrc/backend.hpp b/backends/trtllm/csrc/backend.hpp index 82443c4d149..184bf26f9be 100644 --- a/backends/trtllm/csrc/backend.hpp +++ b/backends/trtllm/csrc/backend.hpp @@ -25,6 +25,8 @@ namespace huggingface::tgi::backends::trtllm { */ struct generation_params_t { uint32_t max_new_tokens; + std::optional guide_type; + std::string guide; }; /** @@ -66,9 +68,10 @@ namespace huggingface::tgi::backends::trtllm { float_t top_p; float_t temperature; std::list> stop_words; + std::vector eos_token_ids; constexpr explicit generation_config_t(const json &config) : - top_p(config.value("top_p", 1.0f)), 
temperature(config.value("temperature", 1.0f)), stop_words(0) { + top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0), eos_token_ids{} { if (!config.contains("/eos_token_id"_json_pointer)) { return; } @@ -76,13 +79,17 @@ namespace huggingface::tgi::backends::trtllm { SPDLOG_DEBUG("generation config eos_token_id is array"); const auto &eos_token_id = config["/eos_token_id"_json_pointer]; std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) { - stop_words.emplace_back(1, token_id.template get()); + const auto token = token_id.template get(); + stop_words.emplace_back(1, token); + eos_token_ids.emplace_back(token); }); } if (config["/eos_token_id"_json_pointer].is_number()) { SPDLOG_DEBUG("generation config eos_token_id is number"); - stop_words.emplace_back(1, config["/eos_token_id"_json_pointer].get()); + const auto token = config["/eos_token_id"_json_pointer].get(); + stop_words.emplace_back(1, token); + eos_token_ids.emplace_back(token); } SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size()); @@ -143,7 +150,7 @@ namespace huggingface::tgi::backends::trtllm { * to initialize `tensorrt_llm::executor::Executor` * @return `tensorrt_llm::executor::ExecutorConfig` instance */ - [[nodiscard]] tle::ExecutorConfig executor_config() const; + [[nodiscard]] tle::ExecutorConfig executor_config(const std::vector& encoded_vocab, std::string_view tokenizer_str) const; }; /** @@ -167,10 +174,10 @@ namespace huggingface::tgi::backends::trtllm { tle::Executor executor_; public: - backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path); + backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path, const std::vector &encoded_vocab, std::string_view tokenizer_str); - backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path) - : backend_t(engines_folder, executor_worker_path) {}; + backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path, const std::vector &encoded_vocab, std::string_view tokenizer_str) + : backend_t(engines_folder, executor_worker_path, encoded_vocab, tokenizer_str) {}; /** * Submit a new request to the executor @@ -201,9 +208,9 @@ namespace huggingface::tgi::backends::trtllm { /** * Create a TensorRT-LLM executor from a workspace */ - const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor { + const auto executor_factory_initializer = [](const backend_workspace_t &workspace, const std::vector &encoded_vocab, std::string_view tokenizer_str) -> tle::Executor { return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY, - workspace.executor_config()}; + workspace.executor_config(encoded_vocab, tokenizer_str)}; }; } diff --git a/backends/trtllm/csrc/ffi.hpp b/backends/trtllm/csrc/ffi.hpp index c4b27400acf..3a16630c017 100644 --- a/backends/trtllm/csrc/ffi.hpp +++ b/backends/trtllm/csrc/ffi.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -115,8 +116,8 @@ namespace huggingface::tgi::backends::trtllm { public: - tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point& created_time) - : inner_(engine_folder, executor_worker_path), + tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const 
std::chrono::time_point& created_time, const std::vector& encoded_vocab, std::string_view tokenizer_str) + : inner_(engine_folder, executor_worker_path, encoded_vocab, tokenizer_str), m_created_time {created_time} {} @@ -128,16 +129,31 @@ namespace huggingface::tgi::backends::trtllm { float_t temperature, float_t repetition_penalty, float_t frequency_penalty, - uint64_t seed + uint64_t seed, + grammar_type_t grammar_type, + rust::Str grammar_value ) const { // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE) SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor")); // Submit the request to the executor and get back a potential request_id used to track request status const auto signed_tokens = std::vector(tokens.begin(), tokens.end()); + + std::optional guide_type = std::nullopt; + switch (grammar_type) { + case grammar_type_t::kJSON: + guide_type = tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA; + break; + case grammar_type_t::kREGEX: + guide_type = tle::GuidedDecodingParams::GuideType::kREGEX; + break; + default: + break; + } + const auto maybe_request_id = inner_.submit( signed_tokens, - {max_new_tokens}, + {max_new_tokens, guide_type, std::string(grammar_value)}, {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed} ); @@ -211,15 +227,25 @@ namespace huggingface::tgi::backends::trtllm { } std::unique_ptr - create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) { + create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path, const rust::Str tokenizer_str, const rust::Vec encoded_vocab) { const auto created_time = std::chrono::steady_clock::now(); std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend); + + std::vector encoded_vocab_std{}; + encoded_vocab_std.reserve(encoded_vocab.size()); + + for (const auto& v : encoded_vocab) { + encoded_vocab_std.push_back(std::string(v)); + } + return std::make_unique( std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()), std::filesystem::path::format::auto_format), std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()), std::filesystem::path::format::auto_format), - created_time + created_time, + encoded_vocab_std, + std::string_view(tokenizer_str) ); } } diff --git a/backends/trtllm/src/lib.rs b/backends/trtllm/src/lib.rs index 2127d13e3f4..306f9f486b2 100644 --- a/backends/trtllm/src/lib.rs +++ b/backends/trtllm/src/lib.rs @@ -78,6 +78,8 @@ mod ffi { fn create_backend_from_engine_folder( engine_folder: &str, executor_worker: &str, + tokenizer_str: &str, + encoded_vocab: Vec, ) -> Result>; fn submit( @@ -90,6 +92,8 @@ mod ffi { repetition_penalty: f32, frequency_penalty: f32, seed: u64, + grammar_type: GrammarType, + grammar_value: &str, ) -> Result; fn pull_tokens( @@ -98,6 +102,19 @@ mod ffi { fn cancel(self: &TensorRtLlmBackendImpl, request_id: u64); } + + #[cxx_name = "grammar_type_t"] + #[derive(Debug, Clone, Copy)] + pub enum GrammarType { + #[cxx_name = "kNONE"] + None = 0u8, + + #[cxx_name = "kJSON"] + Json = 1u8, + + #[cxx_name = "kREGEX"] + Regex = 2u8, + } } use ffi::FinishReason; diff --git a/backends/trtllm/src/looper.rs b/backends/trtllm/src/looper.rs index 43bacec79fc..32de99f37a1 100644 --- a/backends/trtllm/src/looper.rs +++ b/backends/trtllm/src/looper.rs @@ -18,12 +18,13 @@ use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStr use 
text_generation_router::validation::ValidationError::{ EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality, }; -use text_generation_router::validation::{Chunk, ValidGenerateRequest}; +use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidGrammar}; use text_generation_router::Token; use crate::errors::TensorRtLlmBackendError; use crate::ffi::{ - create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl, + create_backend_from_engine_folder, FinishReason, GenerationStep, GrammarType, + TensorRtLlmBackendImpl, }; use crate::utils::first_line; @@ -105,6 +106,16 @@ fn request_looper( 1 }; + let (grammar_type, grammar_value): (GrammarType, &str) = + if let Some(grammar) = &generation_params.grammar { + match grammar { + ValidGrammar::Json(v) => (GrammarType::Json, v), + ValidGrammar::Regex(v) => (GrammarType::Regex, v), + } + } else { + (GrammarType::None, "") + }; + // Submit to the TensorRT-LLM executor for scheduling match backend.submit( &input_ids.unwrap(), // This is checked beforehand in validate() @@ -115,6 +126,8 @@ fn request_looper( generation_params.repetition_penalty, generation_params.frequency_penalty, generation_params.seed, + grammar_type, + grammar_value, ) { Ok(request_id) => { // Insert the context linked to the generated request id in the tracker @@ -392,9 +405,25 @@ impl TensorRtLlmBackendV2 { // to rust Instant. let created_time = Instant::now(); + let encoded_vocab = { + let vocab = tokenizer.get_vocab(true); + let mut tokens: Vec = vocab.keys().map(|x| x.clone()).collect(); + tokens.sort_by(|a, b| vocab.get(a).cmp(&vocab.get(b))); + tokens + }; + + let tokenizer_str = tokenizer + .to_string(false) + .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?; + // Create the FFI backend - let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path) - .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; + let backend = create_backend_from_engine_folder( + &engine_folder, + &executor_worker_path, + &tokenizer_str, + encoded_vocab, + ) + .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?; let backend = Arc::new(backend); let backend_response = backend.clone(); @@ -425,11 +454,6 @@ impl TensorRtLlmBackendV2 { return Err(ValidationError(TopNTokensDisabled)); } - // TODO: Is it really needed? How can it be validated before? - if request.parameters.grammar.is_some() { - return Err(ValidationError(Grammar)); - } - match request.inputs.len() { 0 => Err(ValidationError(EmptyInput)), 2.. => Err(GenerationError( diff --git a/backends/trtllm/src/main.rs b/backends/trtllm/src/main.rs index 81fca0e7e57..0e28ad02f24 100644 --- a/backends/trtllm/src/main.rs +++ b/backends/trtllm/src/main.rs @@ -67,6 +67,8 @@ struct Args { usage_stats: UsageStatsLevel, #[clap(default_value = "2000000", long, env)] payload_limit: usize, + #[clap(long, env, default_value_t = false)] + disable_grammar_support: bool, } async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option { @@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { executor_worker, usage_stats, payload_limit, + disable_grammar_support, } = args; // Launch Tokio runtime @@ -321,7 +324,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> { false, None, None, - true, + disable_grammar_support, max_client_batch_size, usage_stats, payload_limit,