This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 9998c48

fix: support Mistral v0.3
1 parent c7703f1 commit 9998c48

File tree (3 files changed: +111 -34 lines changed)

cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h
cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h


cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h

Lines changed: 6 additions & 6 deletions
@@ -8,9 +8,9 @@ struct LoadModelRequest {
     int ctx_len = 2048;
     int n_parallel = 1;
     std::string model_path;
-    std::string user_prompt = "<|im_end|>\n<|im_start|>user\n";
-    std::string ai_prompt = "<|im_end|>\n<|im_start|>user\n";
-    std::string system_prompt = "<|im_end|>\n<|im_start|>user\n";
+    std::string user_prompt = "";
+    std::string ai_prompt = "";
+    std::string system_prompt = "";
 };

 inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
@@ -19,9 +19,9 @@ inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
         request.ctx_len = json_body->get("ctx_len", 2048).asInt();
         request.n_parallel = json_body->get("n_parallel", 1).asInt();
         request.model_path = json_body->get("model_path", "").asString();
-        request.user_prompt = json_body->get("user_prompt", "<|im_end|>\n<|im_start|>user\n").asString();
-        request.ai_prompt = json_body->get("ai_prompt", "<|im_end|>\n<|im_start|>assistant\n").asString();
-        request.system_prompt = json_body->get("system_prompt", "<|im_start|>system\n").asString();
+        request.user_prompt = json_body->get("user_prompt", "").asString();
+        request.ai_prompt = json_body->get("ai_prompt", "").asString();
+        request.system_prompt = json_body->get("system_prompt", "").asString();
     }
     return request;
 }
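Note (illustration, not part of this commit): with the prompt defaults changed to empty strings, a client can omit the template fields entirely and let the engine substitute a model-specific template later in LoadModel. A minimal sketch, assuming jsoncpp and a hypothetical include path and model path, using only field names visible above:

    // Sketch only: field names match load_model_request.h above; the model path and
    // include path are hypothetical.
    #include <json/json.h>
    #include <iostream>
    #include <memory>
    #include "models/load_model_request.h"  // assumed include path

    int main() {
        auto body = std::make_shared<Json::Value>();
        (*body)["model_path"] = "/models/mistral-7b-instruct-v0.3-trt";  // hypothetical path
        (*body)["ctx_len"] = 4096;
        (*body)["n_parallel"] = 1;
        // No user_prompt / ai_prompt / system_prompt: they now default to "" here, and
        // TensorrtllmEngine::LoadModel (see tensorrt-llm_engine.cc below) fills in the
        // per-model template when the fields are empty.
        model::LoadModelRequest req = model::fromJson(body);
        std::cout << std::boolalpha << req.user_prompt.empty() << "\n";  // prints: true
    }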

cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc

Lines changed: 87 additions & 25 deletions
@@ -21,35 +21,82 @@
 using json = nlohmann::json;
 using namespace tensorrtllm;

+namespace {
+constexpr const int k200OK = 200;
+constexpr const int k400BadRequest = 400;
+constexpr const int k409Conflict = 409;
+constexpr const int k500InternalServerError = 500;
+
+// https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+// stopWordsList
+// 'im', '_', 'end', '</s>', '<|im_end|>'
+const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+const std::string kOhSystemPrompt = "<|im_start|>system\n";
+const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000}, {"<|im_start|>", 32001}};
+
+// '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
+const std::vector<int32_t> kMistral_V0_3_StopWords
+    = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
+const std::string kMistralUserPrompt = "[INST] ";
+const std::string kMistralAiPrompt = "[/INST] ";
+const std::string kMistralSystemPrompt = "<s>";
+const std::unordered_map<std::string, int> kMistralTemplate = {{"[INST]", 3}, {"[/INST]", 4}};
+
+// TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
+bool IsOpenhermes(const std::string& s) {
+    if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
+        return false;
+    }
+    return true;
+}
+
+std::string GetUserPrompt(bool is_openhermes) {
+    if (is_openhermes) {
+        return kOhUserPrompt;
+    }
+    return kMistralUserPrompt;
+}

-constexpr const int k200OK = 200;
-constexpr const int k400BadRequest = 400;
-constexpr const int k409Conflict = 409;
-constexpr const int k500InternalServerError = 500;
+std::string GetAiPrompt(bool is_openhermes) {
+    if (is_openhermes) {
+        return kOhAiPrompt;
+    }
+    return kMistralAiPrompt;
+}

+std::string GetSystemPrompt(bool is_openhermes) {
+    if (is_openhermes) {
+        return kOhSystemPrompt;
+    }
+    return kMistralSystemPrompt;
+}
+}
 TensorrtllmEngine::~TensorrtllmEngine() {}

 void RemoveId(std::vector<int>& vec, int id) {
     vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
 }

-bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
-    if (infer_state->IsComplete()) {
+bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state, bool is_openhermes) {
+    if (infer_state->IsComplete(is_openhermes)) {
         return false;
     }
     if (infer_state->stop_word_match_len == 0) {
-        if (rew_text.find('<') != std::string::npos) { // Found "<" anywhere in the text
+        if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
+            (!is_openhermes && rew_text.find('[') != std::string::npos)) {
             infer_state->stop_word_match_len++; // Move to next state
             infer_state->prev_text = rew_text;
             return true;
         }
     }
-    else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
+    else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
         infer_state->stop_word_match_len++; // Move to next state
         infer_state->prev_text = rew_text;
         return true;
     }
-    else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
+    else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
         infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
         infer_state->prev_text = rew_text;
         return true;
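Note (illustration, not part of this commit): the Mistral v0.3 template constants above are consumed by the prompt assembly in HandleChatCompletion, which appends the system prompt, then the user/assistant markers per message, and finally the assistant marker (see the formatted_input += ai_prompt line in a later hunk). A minimal sketch of what a one-turn prompt would roughly look like, using only the constants shown above; the message text is made up and the exact assembly in the engine may differ:

    // Sketch: approximate formatted_input for a single Mistral v0.3 user turn.
    #include <iostream>
    #include <string>

    int main() {
        const std::string system_prompt = "<s>";        // kMistralSystemPrompt
        const std::string user_prompt = "[INST] ";      // kMistralUserPrompt
        const std::string ai_prompt = "[/INST] ";       // kMistralAiPrompt

        std::string formatted_input = system_prompt;
        formatted_input += user_prompt + "Hello, who are you?";  // hypothetical user message
        formatted_input += ai_prompt;                             // generation continues from here

        std::cout << formatted_input << "\n";
        // Prints: <s>[INST] Hello, who are you?[/INST]
    }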
@@ -67,9 +114,11 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
 }

 GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
-    std::vector<int32_t> stop_words_tokens
-        = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}; // Extend with -1 for increased length
-    return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
+    if (is_openhermes_) {
+        return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size() / 2)}), MemoryType::kGPU);
+    } else {
+        return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size() / 2)}), MemoryType::kGPU);
+    }
 }

 GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
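Aside (illustration, not part of this commit): as I read the TensorRT-LLM runtime docs linked in the diff, the stop-words tensors above follow the stopWordsList layout of shape {1, 2, N}, where the first row holds the token ids of all stop sequences concatenated and the second row holds their exclusive end offsets, padded with -1 to the same length. That is why the shape uses size()/2. A small self-contained sketch that flattens a set of stop sequences into this assumed layout:

    // Sketch: flatten stop sequences into the [tokens | offsets] layout assumed above.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    std::vector<int32_t> FlattenStopWords(const std::vector<std::vector<int32_t>>& stop_words) {
        std::vector<int32_t> tokens;
        std::vector<int32_t> offsets;
        for (const auto& w : stop_words) {
            tokens.insert(tokens.end(), w.begin(), w.end());
            offsets.push_back(static_cast<int32_t>(tokens.size()));  // exclusive end offset
        }
        while (offsets.size() < tokens.size()) offsets.push_back(-1);  // pad second row
        std::vector<int32_t> flat(tokens);
        flat.insert(flat.end(), offsets.begin(), offsets.end());
        return flat;  // size() == 2 * N, matching ITensor::makeShape({1, 2, N})
    }

    int main() {
        // 'im' '_' 'end', '</s>', '<|im_end|>'  ->  321 28730 416 2 32000 3 4 5 -1 -1
        for (int32_t v : FlattenStopWords({{321, 28730, 416}, {2}, {32000}})) std::cout << v << ' ';
        std::cout << '\n';  // reproduces kOpenhermesStopWords from the diff
    }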
@@ -102,27 +151,35 @@ void InferenceThread(
     TensorrtllmEngine* self,
     SamplingConfig sampling_config,
     int input_len,
-    int outputLen) {
+    int outputLen, bool is_openhermes) {

     // Input preparation
     LOG_INFO << "Inference thread started";
     GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
     GenerationOutput generation_output = self->CreateGenerationOutput();

     // Define the callback to stream each generated token
-    generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
+    generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
         GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) {
-        LOG_INFO << "Generating tokenizer in thread";
+        // LOG_INFO << "Generating tokenizer in thread";
         // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
         int output_length = output_ids->getShape().d[2]; // Get the length of output IDs based on the tensor shape
         // Copy output IDs from GPU to host for printing
         std::vector<int32_t> output_idsHost(output_length);
         self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
         // Find the last non-zero value in the output IDs starting from the end of the input sequence
         std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());
+
         RemoveId(output_idsHostDecode, 0);
-        RemoveId(output_idsHostDecode, 32000);
-        RemoveId(output_idsHostDecode, 32001);
+        if (is_openhermes) {
+            for (auto const& [_, v] : kOpenhermesTemplate) {
+                RemoveId(output_idsHostDecode, v);
+            }
+        } else {
+            for (auto const& [_, v] : kMistralTemplate) {
+                RemoveId(output_idsHostDecode, v);
+            }
+        }
         std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);

         if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
@@ -225,6 +282,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
         }
     }
     formatted_input += ai_prompt;
+    // LOG_INFO << formatted_input;
     // Format the input from user

     std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();
@@ -243,23 +301,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
     sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
     // Input preparation

-    std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
+    std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
     inference_thread.detach(); // Detach the thread to allow it to run independently

-    q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
+    q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
+        // std::string res_str;
         LOG_INFO << "Preparing to run inference task queue...";
         while (true) { // Continuously check if the queue is not empty
             std::unique_lock<std::mutex> lock(infer_state->queue_mutex); // Lock the queue for exclusive access
             if (!infer_state->texts_to_stream.empty()) {
                 std::string rew_text = infer_state->texts_to_stream.front();
+                // res_str += rew_text;
                 infer_state->texts_to_stream.pop();
-                if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
+                if (HandleMatch(rew_text, infer_state, is_openhermes_) && rew_text != "[DONE]") {
                     continue;
                 };

                 if (rew_text == "[DONE]") {
                     const std::string str
-                        = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
+                        = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
                         + "\n\n" + "data: [DONE]" + "\n\n";

                     infer_state->is_finished = true;
@@ -275,7 +335,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
                     break;
                 }
                 const std::string text_to_stream
-                    = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
+                    = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";

                 lock.unlock(); // Unlock as soon as possible
                 infer_state->prev_text = rew_text;
@@ -293,6 +353,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
                 lock.unlock();
             }
         }
+        // LOG_INFO << res_str;
     });

     LOG_INFO << "Inference completed";
@@ -302,11 +363,12 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
 void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
     model::LoadModelRequest request = model::fromJson(json_body);
     std::filesystem::path model_dir = request.model_path;
+    is_openhermes_ = IsOpenhermes(request.model_path);

     int ctx_len = request.ctx_len;
-    this->user_prompt = request.user_prompt;
-    this->ai_prompt = request.ai_prompt;
-    this->system_prompt = request.system_prompt;
+    this->user_prompt = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt;
+    this->ai_prompt = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt;
+    this->system_prompt = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt;
     this->model_id_ = GetModelId(*json_body);

     logger = std::make_shared<TllmLogger>();
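Note (illustration, not part of this commit): the template choice hinges entirely on whether the model path contains "mistral" or "Mistral". A standalone sketch mirroring the IsOpenhermes heuristic above, with hypothetical paths:

    // Sketch: mirrors the IsOpenhermes() heuristic from tensorrt-llm_engine.cc.
    #include <cassert>
    #include <string>

    static bool LooksLikeOpenhermes(const std::string& s) {
        return s.find("mistral") == std::string::npos && s.find("Mistral") == std::string::npos;
    }

    int main() {
        assert(!LooksLikeOpenhermes("/models/Mistral-7B-Instruct-v0.3-trt"));  // Mistral template
        assert(LooksLikeOpenhermes("/models/openhermes-2.5-trt"));             // ChatML template
        return 0;
    }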

cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h

Lines changed: 18 additions & 3 deletions
@@ -73,16 +73,30 @@ struct InferenceState {
     std::queue<std::string> texts_to_stream;
     std::mutex queue_mutex; // Mutex to protect access to textsToStream
     size_t stop_word_match_len = 0;
-    std::vector<std::string> sequence{"<", "|", "im", "_", "end", "|", ">"};
+    std::vector<std::string> sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"};
+    std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
     int token_gen_count = 0;

     void Reset() {
         stop_word_match_len = 0;
         prev_text = "";
     }

-    bool IsComplete() const {
-        return stop_word_match_len >= sequence.size();
+    bool IsComplete(bool is_openhermes) const {
+        if (is_openhermes) {
+            return stop_word_match_len >= sequence_openhermes.size();
+        } else {
+            return stop_word_match_len >= sequence_mistral.size();
+        }
+    }
+
+    const std::string& GetSequence(bool is_openhermes, size_t index) {
+        if (is_openhermes) {
+            return sequence_openhermes[index];
+        } else {
+            return sequence_mistral[index];
+        }
+
     }
 };
@@ -138,6 +152,7 @@ class TensorrtllmEngine : public EngineI {
     uint64_t start_time_;
     std::atomic<bool> model_loaded_;
     std::unique_ptr<trantor::ConcurrentTaskQueue> q_;
+    bool is_openhermes_ = true;
 };

 } // namespace inferences
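Note (illustration, not part of this commit): the new sequence_mistral pieces are matched one streamed chunk at a time, in the spirit of HandleMatch() in tensorrt-llm_engine.cc, so that a "[INST]" marker emitted at the end of generation is held back rather than streamed to the client. A simplified, self-contained sketch of that matching loop; the chunk stream is made up and the reset-on-mismatch handling is reduced compared to the real HandleMatch:

    // Sketch: stop-sequence matching over streamed chunks, Mistral pieces from the diff.
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
        const std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
        const std::vector<std::string> chunks = {"The", " answer", " is", " 42.", "[", "INST", "]"};  // hypothetical stream
        std::size_t match_len = 0;

        for (const std::string& chunk : chunks) {
            if (match_len < sequence_mistral.size() && chunk == sequence_mistral[match_len]) {
                ++match_len;      // possible stop marker: hold the chunk back instead of streaming it
            } else {
                match_len = 0;    // not part of the stop marker: stream the chunk to the client
                std::cout << chunk;
            }
        }
        std::cout << (match_len >= sequence_mistral.size() ? "\n[stop marker matched]\n" : "\n");
    }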
