@@ -39,12 +39,12 @@ static std::ostringstream * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
     std::ifstream f(path.c_str());
     return f.good();
 }
 
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
     std::ifstream f;
     f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
     f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -117,6 +117,14 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
     LOG_TEE("%s", text);
 }
 
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+    llama_chat_msg new_msg{role, content};
+    auto formatted = llama_chat_format_single(
+        model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+    chat_msgs.push_back({role, content});
+    return formatted;
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
     g_params = &params;
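For reference, `llama_chat_format_single` (declared in `common/common.h`) applies the model's chat template to the newly added message, and the `role == "user"` flag asks it to append the assistant prefix so generation can begin right after the formatted text is evaluated. A minimal usage sketch of the new helper, assuming a loaded `llama_model *` and `g_params` already pointing at the parsed parameters as in `main()`; the prompt strings are invented for illustration:

```cpp
// Illustration only, not part of the patch.
static void example_chat_formatting(struct llama_model * model) {
    std::vector<llama_chat_msg> chat_msgs;

    // role != "user": formatted without a trailing assistant prefix
    const std::string sys_text = chat_add_and_format(model, chat_msgs, "system",
                                                     "You are a helpful assistant.");

    // role == "user": formatted with the assistant prefix appended, so sampling can
    // start right after these tokens are evaluated
    const std::string usr_text = chat_add_and_format(model, chat_msgs, "user", "Hello!");

    // chat_msgs now holds the raw (unformatted) transcript; sys_text + usr_text is
    // what actually gets tokenized and fed to the model
}
```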
@@ -190,6 +198,7 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
     llama_context * ctx_guidance = NULL;
+    std::vector<llama_chat_msg> chat_msgs;
     g_model = &model;
     g_ctx = &ctx;
 
@@ -215,6 +224,8 @@ int main(int argc, char ** argv) {
                 __func__, n_ctx_train, n_ctx);
     }
 
+    LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+
     // print system information
     {
         LOG_TEE("\n");
@@ -249,16 +260,21 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
-        LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
-    } else {
-        LOG("use session tokens\n");
-        embd_inp = session_tokens;
-    }
+    {
+        auto prompt = params.conversation
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+            : params.prompt;
+        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+            LOG("tokenize the prompt\n");
+            embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+        } else {
+            LOG("use session tokens\n");
+            embd_inp = session_tokens;
+        }
 
-    LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
-    LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        LOG("prompt: \"%s\"\n", log_tostr(prompt));
+        LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+    }
 
     // Should not run without any tokens
     if (embd_inp.empty()) {
@@ -478,6 +494,7 @@ int main(int argc, char ** argv) {
     std::vector<int>   input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int>   output_tokens; g_output_tokens = &output_tokens;
     std::ostringstream output_ss;     g_output_ss     = &output_ss;
+    std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
 
     // the first thing we will do is to output the prompt, so set color accordingly
     console::set_display(console::prompt);
@@ -793,11 +810,18 @@ int main(int argc, char ** argv) {
                         is_antiprompt = true;
                     }
 
+                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
                     is_interacting = true;
                     printf("\n");
                 }
             }
 
+            // if current token is not EOG, we add it to current assistant message
+            if (params.conversation) {
+                const auto id = llama_sampling_last(ctx_sampling);
+                assistant_ss << llama_token_to_piece(ctx, id, false);
+            }
+
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
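Taken together, the hunk above implements a buffer-and-commit pattern for the model's side of the conversation: each sampled piece is appended to `assistant_ss`, and when an EOG token ends the turn the buffered text is stored in `chat_msgs` as the assistant message. A condensed sketch of that flow, illustrative only; in `main.cpp` this logic is spread across the generation loop rather than collected in one function:

```cpp
// Condensed restatement of the buffer-and-commit flow (illustration, not part of the patch).
static void example_collect_assistant_turn(struct llama_model * model,
                                           struct llama_context * ctx,
                                           struct llama_sampling_context * ctx_sampling,
                                           std::vector<llama_chat_msg> & chat_msgs,
                                           std::ostringstream & assistant_ss) {
    const llama_token id = llama_sampling_last(ctx_sampling);

    if (!llama_token_is_eog(model, id)) {
        // mid-turn: keep buffering the raw piece (special tokens are not rendered)
        assistant_ss << llama_token_to_piece(ctx, id, false);
    } else {
        // end of turn: store the buffered text as the assistant message and reset the buffer
        chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
        assistant_ss.str("");
    }
}
```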
@@ -848,8 +872,12 @@ int main(int argc, char ** argv) {
                         string_process_escapes(buffer);
                     }
 
+                    std::string user_inp = params.conversation
+                        ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+                        : std::move(buffer);
+                    // TODO: one inconvenience of the current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                     const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                    const auto line_inp = ::llama_tokenize(ctx, buffer,              false, false);
+                    const auto line_inp = ::llama_tokenize(ctx, user_inp,            false, params.conversation);
                     const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
 
                     LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
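Note the changed last argument to `::llama_tokenize`: in conversation mode the formatted `user_inp` contains the template's control markers, so `parse_special` must be true for them to be mapped to their special token ids instead of being tokenized as literal text. A small sketch of the difference, assuming a loaded `llama_context` for a ChatML-style model; the marker strings are model-dependent examples:

```cpp
// Illustration only (not part of the patch): why parse_special matters for the
// formatted user input.
static void example_parse_special(struct llama_context * ctx) {
    const std::string formatted = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n";

    // parse_special = false: the markers are tokenized as literal text
    const auto as_plain = ::llama_tokenize(ctx, formatted, /* add_special */ false, /* parse_special */ false);

    // parse_special = true: each marker collapses to its single control-token id,
    // which is what conversation mode needs for the template to take effect
    const auto as_ctrl  = ::llama_tokenize(ctx, formatted, /* add_special */ false, /* parse_special */ true);

    LOG("plain: %zu tokens, with specials: %zu tokens\n", as_plain.size(), as_ctrl.size());
}
```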
@@ -864,6 +892,9 @@ int main(int argc, char ** argv) {
                         output_ss << llama_token_to_piece(ctx, token);
                     }
 
+                    // reset assistant message
+                    assistant_ss.str("");
+
                     n_remain -= line_inp.size();
                     LOG("n_remain: %d\n", n_remain);
                 } else {