Skip to content

Commit f38a700

Browse files
committed
fix(trtllm): handle single eos_token_id in generation_config
The type of `eos_token_id` in `transformers.GenerationConfig` is `Union[int, list[int]]` (as of transformers 4.57.0). The original code only parses this field when the value is an array, so the stop_words is not populated for some models. Add code to handle the `int` case as well.
1 parent 0012ef2 commit f38a700

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

backends/trtllm/csrc/backend.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,23 @@ namespace huggingface::tgi::backends::trtllm {
6969

7070
constexpr explicit generation_config_t(const json &config) :
7171
top_p(config.value("top_p", 1.0f)), temperature(config.value("temperature", 1.0f)), stop_words(0) {
72-
if (config.contains("/eos_token_id"_json_pointer) && config["/eos_token_id"_json_pointer].is_array()) {
72+
if (!config.contains("/eos_token_id"_json_pointer)) {
73+
return;
74+
}
75+
if (config["/eos_token_id"_json_pointer].is_array()) {
76+
SPDLOG_DEBUG("generation config eos_token_id is array");
7377
const auto &eos_token_id = config["/eos_token_id"_json_pointer];
7478
std::for_each(eos_token_id.begin(), eos_token_id.end(), [this](const auto token_id) {
7579
stop_words.emplace_back(1, token_id.template get<int32_t>());
7680
});
81+
}
7782

78-
SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size());
83+
if (config["/eos_token_id"_json_pointer].is_number()) {
84+
SPDLOG_DEBUG("generation config eos_token_id is number");
85+
stop_words.emplace_back(1, config["/eos_token_id"_json_pointer].get<int32_t>());
7986
}
87+
88+
SPDLOG_DEBUG("Detected {:d} predefined stop_words from generation_config.json", stop_words.size());
8089
}
8190
};
8291

0 commit comments

Comments
 (0)