From cf19f430e9971183ff56ed65a017a01260ca1090 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 8 Nov 2025 14:09:22 +0200
Subject: [PATCH] batched-bench : add "separate text gen" mode

---
 common/arg.cpp                        |  7 +++++
 common/common.h                       |  3 +-
 tools/batched-bench/batched-bench.cpp | 42 ++++++++++++++++++++-------
 3 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 5597de121c132..ba67f26605506 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2239,6 +2239,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.is_pp_shared = true;
         }
     ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
+    add_opt(common_arg(
+        {"-tgs"},
+        string_format("is the text generation separated across the different sequences (default: %s)", params.is_tg_separate ? "true" : "false"),
+        [](common_params & params) {
+            params.is_tg_separate = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
     add_opt(common_arg(
         {"-npp"}, "n0,n1,...",
         "number of prompt tokens",
diff --git a/common/common.h b/common/common.h
index 54b7849b17448..9e18d0f07c39f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -460,7 +460,8 @@ struct common_params {
     float slot_prompt_similarity = 0.1f;
 
     // batched-bench params
-    bool is_pp_shared = false;
+    bool is_pp_shared   = false;
+    bool is_tg_separate = false;
 
     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;
diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp
index f1ab27cd54d0a..2032a386bb4d2 100644
--- a/tools/batched-bench/batched-bench.cpp
+++ b/tools/batched-bench/batched-bench.cpp
@@ -23,7 +23,8 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    int is_pp_shared = params.is_pp_shared;
+    int is_pp_shared   = params.is_pp_shared;
+    int is_tg_separate = params.is_tg_separate;
 
     std::vector<int> n_pp = params.n_pp;
     std::vector<int> n_tg = params.n_tg;
@@ -72,8 +73,8 @@ int main(int argc, char ** argv) {
 
     // decode in batches of ctx_params.n_batch tokens
    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             llama_batch batch_view = {
                 n_tokens,
@@ -113,7 +114,7 @@ int main(int argc, char ** argv) {
 
     if (!params.batched_bench_output_jsonl) {
         LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, is_tg_separate = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
         LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
         LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -172,16 +173,35 @@ int main(int argc, char ** argv) {
 
             const auto t_tg_start = ggml_time_us();
 
-            for (int i = 0; i < tg; ++i) {
-                common_batch_clear(batch);
-
+            if (is_tg_separate) {
+                // decode pattern:
+                //   0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
                 for (int j = 0; j < pl; ++j) {
-                    common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                    for (int i = 0; i < tg; ++i) {
+                        common_batch_clear(batch);
+
+                        common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                            LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            return 1;
+                        }
+                    }
                 }
+            } else {
+                // decode pattern:
+                //   0123 0123 0123 ...
+                for (int i = 0; i < tg; ++i) {
+                    common_batch_clear(batch);
 
-                if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
-                    LOG_ERR("%s: llama_decode() failed\n", __func__);
-                    return 1;
+                    for (int j = 0; j < pl; ++j) {
+                        common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                    }
+
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                        LOG_ERR("%s: llama_decode() failed\n", __func__);
+                        return 1;
+                    }
                 }
             }