
Commit bf04088

feat(trtllm): support guided decoding
TGI already accepts a grammar for guided decoding through its HTTP API; however, this feature has been disabled for the trtllm backend. To enable it:

- Replace the hard-coded disabling of grammar support with the `disable_grammar_support` arg already present in the v3 backend.
- Pass tokenizer information when constructing the trtllm Executor and enable guided decoding by default.
- Pass the validated grammar type and value from requests to the Executor.
1 parent 34307a4 commit bf04088
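
Taken together, the change wires guided decoding through two layers of the TensorRT-LLM executor API: tokenizer material is handed to the executor once at startup, and each request optionally carries a guide. Below is a minimal, self-contained sketch of that flow, not code from this commit; the engine path, vocabulary, tokenizer JSON, token ids and schema are placeholders.

#include <cstdint>
#include <string>
#include <vector>

#include <tensorrt_llm/executor/executor.h>

namespace tle = tensorrt_llm::executor;

// Startup: build an executor whose guided-decoding (xgrammar) backend knows the
// tokenizer. `vocab` is the encoded vocabulary ordered by token id, `tokenizer_json`
// the serialized tokenizer, `eos_ids` the end-of-sequence token ids.
tle::Executor make_guided_executor(const std::string &engine_dir,
                                   const std::vector<std::string> &vocab,
                                   const std::string &tokenizer_json,
                                   const std::vector<int32_t> &eos_ids) {
    tle::ExecutorConfig config{};
    config.setGuidedDecodingConfig(tle::GuidedDecodingConfig(
            tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR,
            vocab, tokenizer_json, eos_ids));
    return tle::Executor(engine_dir, tle::ModelType::kDECODER_ONLY, config);
}

// Per request: attach the validated grammar (here a JSON schema) so sampling is
// constrained to outputs matching it; requests without a grammar skip this step.
uint64_t submit_guided(tle::Executor &executor,
                       const std::vector<int32_t> &prompt_ids,
                       const std::string &json_schema) {
    tle::Request request(prompt_ids, /*maxNewTokens=*/128);
    request.setGuidedDecodingParams(tle::GuidedDecodingParams(
            tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA, json_schema));
    return executor.enqueueRequest(request);
}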

File tree

6 files changed (+122 lines, -30 lines)

backends/trtllm/csrc/backend.cpp
backends/trtllm/csrc/backend.hpp
backends/trtllm/csrc/ffi.hpp
backends/trtllm/src/lib.rs
backends/trtllm/src/looper.rs
backends/trtllm/src/main.rs

backends/trtllm/csrc/backend.cpp

Lines changed: 20 additions & 5 deletions
@@ -26,7 +26,7 @@ namespace huggingface::tgi::backends::trtllm {
     }


-    tle::ExecutorConfig backend_workspace_t::executor_config() const {
+    tle::ExecutorConfig backend_workspace_t::executor_config(const std::vector<std::string>& encoded_vocab, std::string_view tokenizer_str) const {
         // Retrieve the compute capabilities to enable some options at runtime
         const auto compute_capabilities = hardware::cuda::compute_capabilities_t();

@@ -40,17 +40,24 @@ namespace huggingface::tgi::backends::trtllm {
         executor_config.setKvCacheConfig(tle::KvCacheConfig(true));
         executor_config.setEnableChunkedContext(compute_capabilities.is_at_least_ampere());
         executor_config.setSchedulerConfig(tle::SchedulerConfig(tle::CapacitySchedulerPolicy::kMAX_UTILIZATION));
+        executor_config.setGuidedDecodingConfig(tle::GuidedDecodingConfig(
+                tle::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR,
+                encoded_vocab,
+                std::string(tokenizer_str),
+                generation_config().eos_token_ids
+        ));
         return executor_config;
     }

-    backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path)
-            : workspace(engines_folder, executor_worker_path), executor_(executor_factory_initializer(workspace)) {}
+    backend_t::backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path, const std::vector<std::string> &encoded_vocab, std::string_view tokenizer_str)
+            : workspace(engines_folder, executor_worker_path),
+              executor_(executor_factory_initializer(workspace, encoded_vocab, tokenizer_str)) {}

     std::expected<request_id_t, backend_error_t>
     backend_t::submit(std::span<const token_id_t> token_ids, const generation_params_t g_params,
                       const sampling_params_t s_params) noexcept {
         SPDLOG_DEBUG("Submit {:d} tokens for scheduling ({}, {})", token_ids.size(), g_params, s_params);
-        return executor_.enqueueRequest(tle::Request{
+        tle::Request req {
             {token_ids.begin(), token_ids.end()}, // Making actual copy of the tokens
             static_cast<tle::SizeType32>(g_params.max_new_tokens),
             true,
@@ -68,7 +75,15 @@ namespace huggingface::tgi::backends::trtllm {
             std::nullopt,
             std::nullopt,
             workspace.generation_config().stop_words
-        });
+        };
+
+        if (g_params.guide_type.has_value()) {
+            req.setGuidedDecodingParams(tle::GuidedDecodingParams(
+                    g_params.guide_type.value(),
+                    g_params.guide
+            ));
+        }
+        return executor_.enqueueRequest(req);
     }

     std::vector<tle::Response> backend_t::pull_tokens() noexcept {

backends/trtllm/csrc/backend.hpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ namespace huggingface::tgi::backends::trtllm {
2525
*/
2626
struct generation_params_t {
2727
uint32_t max_new_tokens;
28+
std::optional<tle::GuidedDecodingParams::GuideType> guide_type;
29+
std::string guide;
2830
};
2931

3032
/**
@@ -66,23 +68,28 @@ namespace huggingface::tgi::backends::trtllm {
6668
float_t top_p;
6769
float_t temperature;
6870
std::list<std::vector<int32_t>> stop_words;
71+
std::vector<int32_t> eos_token_ids;
6972

7073
constexpr explicit generation_config_t(const json &config) :
71-
top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) {
74+
top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0), eos_token_ids{} {
7275
if (!config.contains("/eos_token_id"_json_pointer)) {
7376
return;
7477
}
7578
if (config["/eos_token_id"_json_pointer].is_array()) {
7679
SPDLOG_DEBUG("generation config eos_token_id is array");
7780
const auto &eos_token_id = config["/eos_token_id"_json_pointer];
7881
std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {
79-
stop_words.emplace_back(1, token_id.template get<int32_t>());
82+
const auto token = token_id.template get<int32_t>();
83+
stop_words.emplace_back(1, token);
84+
eos_token_ids.emplace_back(token);
8085
});
8186
}
8287

8388
if (config["/eos_token_id"_json_pointer].is_number()) {
8489
SPDLOG_DEBUG("generation config eos_token_id is number");
85-
stop_words.emplace_back(1, config["/eos_token_id"_json_pointer].get<int32_t>());
90+
const auto token = config["/eos_token_id"_json_pointer].get<int32_t>();
91+
stop_words.emplace_back(1, token);
92+
eos_token_ids.emplace_back(token);
8693
}
8794

8895
SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size());
@@ -143,7 +150,7 @@ namespace huggingface::tgi::backends::trtllm {
143150
* to initialize `tensorrt_llm::executor::Executor`
144151
* @return `tensorrt_llm::executor::ExecutorConfig` instance
145152
*/
146-
[[nodiscard]] tle::ExecutorConfig executor_config() const;
153+
[[nodiscard]] tle::ExecutorConfig executor_config(const std::vector<std::string>& encoded_vocab, std::string_view tokenizer_str) const;
147154
};
148155

149156
/**
@@ -167,10 +174,10 @@ namespace huggingface::tgi::backends::trtllm {
167174
tle::Executor executor_;
168175

169176
public:
170-
backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path);
177+
backend_t(std::filesystem::path &engines_folder, std::filesystem::path &executor_worker_path, const std::vector<std::string> &encoded_vocab, std::string_view tokenizer_str);
171178

172-
backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path)
173-
: backend_t(engines_folder, executor_worker_path) {};
179+
backend_t(std::filesystem::path &&engines_folder, std::filesystem::path &&executor_worker_path, const std::vector<std::string> &encoded_vocab, std::string_view tokenizer_str)
180+
: backend_t(engines_folder, executor_worker_path, encoded_vocab, tokenizer_str) {};
174181

175182
/**
176183
* Submit a new request to the executor
@@ -201,9 +208,9 @@ namespace huggingface::tgi::backends::trtllm {
201208
/**
202209
* Create a TensorRT-LLM executor from a workspace
203210
*/
204-
const auto executor_factory_initializer = [](const backend_workspace_t &workspace) -> tle::Executor {
211+
const auto executor_factory_initializer = [](const backend_workspace_t &workspace, const std::vector<std::string> &encoded_vocab, std::string_view tokenizer_str) -> tle::Executor {
205212
return {workspace.engines_folder(), tensorrt_llm::executor::ModelType::kDECODER_ONLY,
206-
workspace.executor_config()};
213+
workspace.executor_config(encoded_vocab, tokenizer_str)};
207214
};
208215
}
209216
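
For context on the new eos_token_ids field above: generation_config.json may carry eos_token_id either as a single integer or as an array, and the constructor now records the flat id list (handed to the guided-decoding config) alongside the per-id stop_words. A small standalone sketch of the two accepted shapes, using nlohmann::json as the header does; the token ids are made-up examples.

#include <cstdint>
#include <iostream>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Sketch only: mirror the eos_token_id handling of generation_config_t above.
std::vector<int32_t> collect_eos_ids(const json &config) {
    std::vector<int32_t> eos_token_ids;
    if (!config.contains("/eos_token_id"_json_pointer)) return eos_token_ids;

    const auto &eos = config["/eos_token_id"_json_pointer];
    if (eos.is_array()) {
        // Several end tokens, e.g. chat models with both an eos and an end-of-turn token
        for (const auto &token_id : eos) eos_token_ids.push_back(token_id.get<int32_t>());
    } else if (eos.is_number()) {
        // Single end token
        eos_token_ids.push_back(eos.get<int32_t>());
    }
    return eos_token_ids;
}

int main() {
    const auto single   = json::parse(R"({"eos_token_id": 2})");
    const auto multiple = json::parse(R"({"eos_token_id": [128001, 128009]})");

    std::cout << collect_eos_ids(single).size() << "\n";    // 1
    std::cout << collect_eos_ids(multiple).size() << "\n";  // 2
    return 0;
}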

backends/trtllm/csrc/ffi.hpp

Lines changed: 32 additions & 6 deletions
@@ -4,6 +4,7 @@
 #include <chrono>
 #include <exception>
 #include <memory>
+#include <optional>
 #include <thread>

 #include <nvml.h>
@@ -115,8 +116,8 @@ namespace huggingface::tgi::backends::trtllm {


     public:
-        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point<std::chrono::steady_clock>& created_time)
-            : inner_(engine_folder, executor_worker_path),
+        tensorrt_llm_backend_t(std::filesystem::path &&engine_folder, std::filesystem::path &&executor_worker_path, const std::chrono::time_point<std::chrono::steady_clock>& created_time, const std::vector<std::string>& encoded_vocab, std::string_view tokenizer_str)
+            : inner_(engine_folder, executor_worker_path, encoded_vocab, tokenizer_str),
              m_created_time {created_time}
         {}

@@ -128,16 +129,31 @@ namespace huggingface::tgi::backends::trtllm {
             float_t temperature,
             float_t repetition_penalty,
             float_t frequency_penalty,
-            uint64_t seed
+            uint64_t seed,
+            grammar_type_t grammar_type,
+            rust::Str grammar_value
         ) const {
             // This is enabled only if using add_compile_definitions(SPDLOG_ACTIVE_LEVEL=SPDLOG_LEVEL_TRACE)
             SPDLOG_TRACE(FMT_STRING("[FFI] Submitting {:d} prompt tokens to the executor"));

             // Submit the request to the executor and get back a potential request_id used to track request status
             const auto signed_tokens = std::vector<int32_t>(tokens.begin(), tokens.end());
+
+            std::optional<tle::GuidedDecodingParams::GuideType> guide_type = std::nullopt;
+            switch (grammar_type) {
+                case grammar_type_t::kJSON:
+                    guide_type = tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA;
+                    break;
+                case grammar_type_t::kREGEX:
+                    guide_type = tle::GuidedDecodingParams::GuideType::kREGEX;
+                    break;
+                default:
+                    break;
+            }
+
             const auto maybe_request_id = inner_.submit(
                 signed_tokens,
-                {max_new_tokens},
+                {max_new_tokens, guide_type, std::string(grammar_value)},
                 {top_k, top_p, repetition_penalty, frequency_penalty, temperature, seed}
             );

@@ -211,15 +227,25 @@ namespace huggingface::tgi::backends::trtllm {
     }

     std::unique_ptr<tensorrt_llm_backend_t>
-    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path) {
+    create_backend_from_engine_folder(const rust::Str engines_folder, const rust::Str executor_worker_path, const rust::Str tokenizer_str, const rust::Vec<rust::String> encoded_vocab) {
         const auto created_time = std::chrono::steady_clock::now();
         std::call_once(backend_initialized_flag, initialize_tensorrt_llm_backend);
+
+        std::vector<std::string> encoded_vocab_std{};
+        encoded_vocab_std.reserve(encoded_vocab.size());
+
+        for (const auto& v : encoded_vocab) {
+            encoded_vocab_std.push_back(std::string(v));
+        }
+
         return std::make_unique<tensorrt_llm_backend_t>(
             std::filesystem::path(std::string_view(engines_folder.begin(), engines_folder.end()),
                                   std::filesystem::path::format::auto_format),
             std::filesystem::path(std::string_view(executor_worker_path.begin(), executor_worker_path.end()),
                                   std::filesystem::path::format::auto_format),
-            created_time
+            created_time,
+            encoded_vocab_std,
+            std::string_view(tokenizer_str)
         );
     }
 }

backends/trtllm/src/lib.rs

Lines changed: 17 additions & 0 deletions
@@ -78,6 +78,8 @@ mod ffi {
         fn create_backend_from_engine_folder(
             engine_folder: &str,
             executor_worker: &str,
+            tokenizer_str: &str,
+            encoded_vocab: Vec<String>,
         ) -> Result<UniquePtr<TensorRtLlmBackendImpl>>;

         fn submit(
@@ -90,6 +92,8 @@ mod ffi {
             repetition_penalty: f32,
             frequency_penalty: f32,
             seed: u64,
+            grammar_type: GrammarType,
+            grammar_value: &str,
         ) -> Result<u64>;

         fn pull_tokens(
@@ -98,6 +102,19 @@ mod ffi {

         fn cancel(self: &TensorRtLlmBackendImpl, request_id: u64);
     }
+
+    #[cxx_name = "grammar_type_t"]
+    #[derive(Debug, Clone, Copy)]
+    pub enum GrammarType {
+        #[cxx_name = "kNONE"]
+        None = 0u8,
+
+        #[cxx_name = "kJSON"]
+        Json = 1u8,
+
+        #[cxx_name = "kREGEX"]
+        Regex = 2u8,
+    }
 }

 use ffi::FinishReason;

backends/trtllm/src/looper.rs

Lines changed: 33 additions & 9 deletions
@@ -18,12 +18,13 @@ use text_generation_router::infer::{Backend, GeneratedText, InferError, InferStr
 use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
-use text_generation_router::validation::{Chunk, ValidGenerateRequest};
+use text_generation_router::validation::{Chunk, ValidGenerateRequest, ValidGrammar};
 use text_generation_router::Token;

 use crate::errors::TensorRtLlmBackendError;
 use crate::ffi::{
-    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+    create_backend_from_engine_folder, FinishReason, GenerationStep, GrammarType,
+    TensorRtLlmBackendImpl,
 };
 use crate::utils::first_line;

@@ -105,6 +106,16 @@ fn request_looper(
                 1
             };

+            let (grammar_type, grammar_value): (GrammarType, &str) =
+                if let Some(grammar) = &generation_params.grammar {
+                    match grammar {
+                        ValidGrammar::Json(v) => (GrammarType::Json, v),
+                        ValidGrammar::Regex(v) => (GrammarType::Regex, v),
+                    }
+                } else {
+                    (GrammarType::None, "")
+                };
+
             // Submit to the TensorRT-LLM executor for scheduling
             match backend.submit(
                 &input_ids.unwrap(), // This is checked beforehand in validate()
@@ -115,6 +126,8 @@ fn request_looper(
                 generation_params.repetition_penalty,
                 generation_params.frequency_penalty,
                 generation_params.seed,
+                grammar_type,
+                grammar_value,
             ) {
                 Ok(request_id) => {
                     // Insert the context linked to the generated request id in the tracker
@@ -392,9 +405,25 @@ impl TensorRtLlmBackendV2 {
         // to rust Instant.
         let created_time = Instant::now();

+        let encoded_vocab = {
+            let vocab = tokenizer.get_vocab(true);
+            let mut tokens: Vec<String> = vocab.keys().map(|x| x.clone()).collect();
+            tokens.sort_by(|a, b| vocab.get(a).cmp(&vocab.get(b)));
+            tokens
+        };
+
+        let tokenizer_str = tokenizer
+            .to_string(false)
+            .map_err(|e| TensorRtLlmBackendError::Tokenizer(e.to_string()))?;
+
         // Create the FFI backend
-        let backend = create_backend_from_engine_folder(&engine_folder, &executor_worker_path)
-            .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;
+        let backend = create_backend_from_engine_folder(
+            &engine_folder,
+            &executor_worker_path,
+            &tokenizer_str,
+            encoded_vocab,
+        )
+        .map_err(|e| TensorRtLlmBackendError::Runtime(first_line(e.what(), "Unknown error")))?;

         let backend = Arc::new(backend);
         let backend_response = backend.clone();
@@ -425,11 +454,6 @@ impl TensorRtLlmBackendV2 {
             return Err(ValidationError(TopNTokensDisabled));
         }

-        // TODO: Is it really needed? How can it be validated before?
-        if request.parameters.grammar.is_some() {
-            return Err(ValidationError(Grammar));
-        }
-
         match request.inputs.len() {
             0 => Err(ValidationError(EmptyInput)),
             2.. => Err(GenerationError(

backends/trtllm/src/main.rs

Lines changed: 4 additions & 1 deletion
@@ -67,6 +67,8 @@ struct Args {
     usage_stats: UsageStatsLevel,
     #[clap(default_value = "2000000", long, env)]
     payload_limit: usize,
+    #[clap(long, env, default_value_t = false)]
+    disable_grammar_support: bool,
 }

 async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
@@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         executor_worker,
         usage_stats,
         payload_limit,
+        disable_grammar_support,
     } = args;

     // Launch Tokio runtime
@@ -321,7 +324,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         false,
         None,
         None,
-        true,
+        disable_grammar_support,
         max_client_batch_size,
         usage_stats,
         payload_limit,
