@@ -23,7 +23,8 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    int is_pp_shared = params.is_pp_shared;
+    int is_pp_shared   = params.is_pp_shared;
+    int is_tg_separate = params.is_tg_separate;
 
     std::vector<int> n_pp = params.n_pp;
     std::vector<int> n_tg = params.n_tg;
@@ -72,8 +73,8 @@ int main(int argc, char ** argv) {
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             llama_batch batch_view = {
                 n_tokens,
@@ -113,7 +114,7 @@ int main(int argc, char ** argv) {
 
     if (!params.batched_bench_output_jsonl) {
         LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, is_tg_separate = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
         LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
         LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -172,16 +173,35 @@ int main(int argc, char ** argv) {
 
                 const auto t_tg_start = ggml_time_us();
 
-                for (int i = 0; i < tg; ++i) {
-                    common_batch_clear(batch);
-
+                if (is_tg_separate) {
+                    // decode pattern:
+                    // 0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
                     for (int j = 0; j < pl; ++j) {
-                        common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                        for (int i = 0; i < tg; ++i) {
+                            common_batch_clear(batch);
+
+                            common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+
+                            if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                                LOG_ERR("%s: llama_decode() failed\n", __func__);
+                                return 1;
+                            }
+                        }
                     }
+                } else {
+                    // decode pattern:
+                    // 0123 0123 0123 ...
+                    for (int i = 0; i < tg; ++i) {
+                        common_batch_clear(batch);
 
-                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
-                        LOG_ERR("%s: llama_decode() failed\n", __func__);
-                        return 1;
+                        for (int j = 0; j < pl; ++j) {
+                            common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                        }
+
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                            LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            return 1;
+                        }
                     }
                 }
 
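
For reference, here is a minimal, self-contained sketch (not part of the patch) that prints the order in which (sequence, position) pairs are submitted for token generation in the two modes above, mirroring the loop structure of the last hunk. The values of pl (parallel sequences), tg (tokens generated per sequence), and pp (prompt length) are made-up examples, and the actual decode calls are omitted, so only standard C++ is involved.

// Illustration only: show the submission order for the two TG modes.
#include <cstdio>

int main() {
    const int pl = 4; // parallel sequences
    const int tg = 3; // tokens generated per sequence
    const int pp = 8; // prompt tokens already in the KV cache

    // is_tg_separate = true: finish one sequence before starting the next,
    // one token per decode call -> "0 0 0 ... 1 1 1 ... 2 2 2 ..."
    std::printf("separate  :");
    for (int j = 0; j < pl; ++j) {
        for (int i = 0; i < tg; ++i) {
            std::printf(" (seq %d, pos %d)", j, pp + i);
        }
    }
    std::printf("\n");

    // is_tg_separate = false: one token for every sequence per decode call
    // -> "0123 0123 0123 ..."
    std::printf("interleave:");
    for (int i = 0; i < tg; ++i) {
        for (int j = 0; j < pl; ++j) {
            std::printf(" (seq %d, pos %d)", j, pp + i);
        }
    }
    std::printf("\n");

    return 0;
}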
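The decode_helper change earlier in the diff is only a simplification: batch.n_tokens is already int32_t, so the casts can go. A minimal sketch of the chunking pattern the lambda implements follows; it assumes the llama.cpp C API (llama_batch, llama_decode, llama_synchronize) and the llama_batch field order declared in llama.h, the view body past "n_tokens" is not visible in the hunk and is an approximation, and the helper name decode_in_chunks is made up for illustration.

// Sketch: split a populated llama_batch into views of at most n_batch tokens
// and decode them one by one. Assumes the llama.cpp C API.
#include <algorithm>
#include <cstdint>

#include "llama.h"

static bool decode_in_chunks(llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        // tokens in this view: a full n_batch, or the remainder at the end
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        // the view shares the parent batch's buffers, offset by i (no copies);
        // field order follows the llama_batch declaration in llama.h
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr, // token ids are used, not embeddings
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        if (llama_decode(ctx, batch_view) != 0) {
            return false;
        }

        if (synchronize) {
            // wait for the backend so per-chunk timings are meaningful
            llama_synchronize(ctx);
        }
    }

    return true;
}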