 using json = nlohmann::json;
 using namespace tensorrtllm;
 
+namespace {
+constexpr const int k200OK = 200;
+constexpr const int k400BadRequest = 400;
+constexpr const int k409Conflict = 409;
+constexpr const int k500InternalServerError = 500;
+
+// https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+// stopWordsList
+// 'im', '_', 'end', '</s>', '<|im_end|>'
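+// Layout note (per the stopWordsList format in the GenerationInput docs linked above): the
+// first half of the vector holds the concatenated stop-sequence token IDs, the second half
+// holds each sequence's cumulative end offset, padded with -1 (here: {321, 28730, 416}, {2}, {32000}).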
+const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+const std::string kOhSystemPrompt = "<|im_start|>system\n";
+const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000}, {"<|im_start|>", 32001}};
+
+// '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
+const std::vector<int32_t> kMistral_V0_3_StopWords
+    = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
+const std::string kMistralUserPrompt = "[INST] ";
+const std::string kMistralAiPrompt = "[/INST] ";
+const std::string kMistralSystemPrompt = "<s>";
+const std::unordered_map<std::string, int> kMistralTemplate = {{"[INST]", 3}, {"[/INST]", 4}};
+
+// TODO(sang): This is fragile and only a temporary solution. Consider reading this from a
+// config file or from the model architecture instead.
+bool IsOpenhermes(const std::string& s) {
+  if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) {
+    return false;
+  }
+  return true;
+}
+
+std::string GetUserPrompt(bool is_openhermes) {
+  if (is_openhermes) {
+    return kOhUserPrompt;
+  }
+  return kMistralUserPrompt;
+}
 
-constexpr const int k200OK = 200;
-constexpr const int k400BadRequest = 400;
-constexpr const int k409Conflict = 409;
-constexpr const int k500InternalServerError = 500;
+std::string GetAiPrompt(bool is_openhermes) {
+  if (is_openhermes) {
+    return kOhAiPrompt;
+  }
+  return kMistralAiPrompt;
+}
 
+std::string GetSystemPrompt(bool is_openhermes) {
+  if (is_openhermes) {
+    return kOhSystemPrompt;
+  }
+  return kMistralSystemPrompt;
+}
+}  // namespace
 TensorrtllmEngine::~TensorrtllmEngine() {}
 
 void RemoveId(std::vector<int>& vec, int id) {
   vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end());
 }
 
-bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state) {
-  if (infer_state->IsComplete()) {
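+// Tracks partial matches of the active chat template's stop sequence across streamed chunks,
+// so template text is held back from the client until the match either completes or breaks.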
+bool HandleMatch(std::string const& rew_text, std::shared_ptr<InferenceState> infer_state, bool is_openhermes) {
+  if (infer_state->IsComplete(is_openhermes)) {
     return false;
   }
   if (infer_state->stop_word_match_len == 0) {
-    if (rew_text.find('<') != std::string::npos) {  // Found "<" anywhere in the text
+    if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
+        (!is_openhermes && rew_text.find('[') != std::string::npos)) {
       infer_state->stop_word_match_len++;  // Move to next state
       infer_state->prev_text = rew_text;
       return true;
     }
   }
-  else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) {
+  else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
     infer_state->stop_word_match_len++;  // Move to next state
     infer_state->prev_text = rew_text;
     return true;
   }
-  else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) {
+  else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
     infer_state->stop_word_match_len = 1;  // Restart from first match if sequence breaks but matches start
     infer_state->prev_text = rew_text;
     return true;
@@ -67,9 +114,11 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st
 }
 
 GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() {
-  std::vector<int32_t> stop_words_tokens
-      = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};  // Extend with -1 for increased length
-  return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU);
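+  // The stop-words tensor is shaped [1, 2, len]; each kXStopWords vector stores token IDs and
+  // end offsets back to back, hence the size() / 2 in the shape below.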
+  if (is_openhermes_) {
+    return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast<int>(kOpenhermesStopWords.size() / 2)}), MemoryType::kGPU);
+  } else {
+    return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast<int>(kMistral_V0_3_StopWords.size() / 2)}), MemoryType::kGPU);
+  }
 }
 
 GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector<int32_t> input_ids_host) {
@@ -102,27 +151,35 @@ void InferenceThread(
     TensorrtllmEngine* self,
     SamplingConfig sampling_config,
     int input_len,
-    int outputLen) {
+    int outputLen, bool is_openhermes) {
 
   // Input preparation
   LOG_INFO << "Inference thread started";
   GenerationInput generation_input = self->CreateGenerationInput(input_ids_host);
   GenerationOutput generation_output = self->CreateGenerationOutput();
 
   // Define the callback to stream each generated token
-  generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output](
+  generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes](
       GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) {
-    LOG_INFO << "Generating tokenizer in thread";
+    // LOG_INFO << "Generating tokenizer in thread";
     // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens
     int output_length = output_ids->getShape().d[2];  // Get the length of output IDs based on the tensor shape
     // Copy output IDs from GPU to host for printing
     std::vector<int32_t> output_idsHost(output_length);
     self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU);
     // Find the last non-zero value in the output IDs starting from the end of the input sequence
     std::vector<int> output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end());
+
     RemoveId(output_idsHostDecode, 0);
-    RemoveId(output_idsHostDecode, 32000);
-    RemoveId(output_idsHostDecode, 32001);
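+    // Strip the active model family's chat-template special tokens so they never show up in
+    // the decoded text.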
+    if (is_openhermes) {
+      for (auto const& [_, v] : kOpenhermesTemplate) {
+        RemoveId(output_idsHostDecode, v);
+      }
+    } else {
+      for (auto const& [_, v] : kMistralTemplate) {
+        RemoveId(output_idsHostDecode, v);
+      }
+    }
     std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);
 
     if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) {
@@ -225,6 +282,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
     }
   }
   formatted_input += ai_prompt;
+  // LOG_INFO << formatted_input;
   // Format the input from user
 
   std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();
@@ -243,23 +301,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
   sampling_config.repetitionPenalty = std::vector{request.frequency_penalty};
   // Input preparation
 
-  std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen);
+  std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_);
   inference_thread.detach();  // Detach the thread to allow it to run independently
 
-  q_->runTaskInQueue([cb = std::move(callback), infer_state]() {
+  q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() {
+    // std::string res_str;
     LOG_INFO << "Preparing to run inference task queue...";
     while (true) {  // Continuously check if the queue is not empty
       std::unique_lock<std::mutex> lock(infer_state->queue_mutex);  // Lock the queue for exclusive access
       if (!infer_state->texts_to_stream.empty()) {
         std::string rew_text = infer_state->texts_to_stream.front();
+        // res_str += rew_text;
         infer_state->texts_to_stream.pop();
-        if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") {
+        if (HandleMatch(rew_text, infer_state, is_openhermes_) && rew_text != "[DONE]") {
           continue;
         };
 
         if (rew_text == "[DONE]") {
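+          // Terminate the OpenAI-style SSE stream: a final chunk with "stop", then the [DONE] sentinel.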
           const std::string str
-              = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop")
+              = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop")
               + "\n\n" + "data: [DONE]" + "\n\n";
 
           infer_state->is_finished = true;
@@ -275,7 +335,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
           break;
         }
         const std::string text_to_stream
-            = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n";
+            = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n";
 
         lock.unlock();  // Unlock as soon as possible
         infer_state->prev_text = rew_text;
@@ -293,6 +353,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
         lock.unlock();
       }
     }
+    // LOG_INFO << res_str;
   });
 
   LOG_INFO << "Inference completed";
@@ -302,11 +363,12 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
 void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::function<void(Json::Value&&, Json::Value&&)>&& callback) {
   model::LoadModelRequest request = model::fromJson(json_body);
   std::filesystem::path model_dir = request.model_path;
+  is_openhermes_ = IsOpenhermes(request.model_path);
 
   int ctx_len = request.ctx_len;
-  this->user_prompt = request.user_prompt;
-  this->ai_prompt = request.ai_prompt;
-  this->system_prompt = request.system_prompt;
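+  // Fall back to the detected model family's default chat template when the request does not
+  // supply explicit prompt strings.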
+  this->user_prompt = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt;
+  this->ai_prompt = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt;
+  this->system_prompt = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt;
   this->model_id_ = GetModelId(*json_body);
 
   logger = std::make_shared<TllmLogger>();