
Commit f914544

batched-bench : add "separate text gen" mode (#17103)
1 parent 4b13a68 commit f914544

File tree

3 files changed: +40 / -12 lines


common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2253,6 +2253,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.is_pp_shared = true;
         }
     ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
+    add_opt(common_arg(
+        {"-tgs"},
+        string_format("is the text generation separated across the different sequences (default: %s)", params.is_tg_separate ? "true" : "false"),
+        [](common_params & params) {
+            params.is_tg_separate = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
     add_opt(common_arg(
         {"-npp"}, "n0,n1,...",
         "number of prompt tokens",

common/common.h

Lines changed: 2 additions & 1 deletion
@@ -460,7 +460,8 @@ struct common_params {
     float slot_prompt_similarity = 0.1f;

     // batched-bench params
-    bool is_pp_shared = false;
+    bool is_pp_shared   = false;
+    bool is_tg_separate = false;

     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;

tools/batched-bench/batched-bench.cpp

Lines changed: 31 additions & 11 deletions
@@ -23,7 +23,8 @@ int main(int argc, char ** argv) {

     common_init();

-    int is_pp_shared = params.is_pp_shared;
+    int is_pp_shared   = params.is_pp_shared;
+    int is_tg_separate = params.is_tg_separate;

     std::vector<int> n_pp = params.n_pp;
     std::vector<int> n_tg = params.n_tg;
@@ -72,8 +73,8 @@

     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

             llama_batch batch_view = {
                 n_tokens,
@@ -113,7 +114,7 @@

     if (!params.batched_bench_output_jsonl) {
         LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, is_tg_separate = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
         LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
         LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -172,16 +173,35 @@

             const auto t_tg_start = ggml_time_us();

-            for (int i = 0; i < tg; ++i) {
-                common_batch_clear(batch);
-
+            if (is_tg_separate) {
+                // decode pattern:
+                // 0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
                 for (int j = 0; j < pl; ++j) {
-                    common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                    for (int i = 0; i < tg; ++i) {
+                        common_batch_clear(batch);
+
+                        common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                            LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            return 1;
+                        }
+                    }
                 }
+            } else {
+                // decode pattern:
+                // 0123 0123 0123 ...
+                for (int i = 0; i < tg; ++i) {
+                    common_batch_clear(batch);

-                if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
-                    LOG_ERR("%s: llama_decode() failed\n", __func__);
-                    return 1;
+                    for (int j = 0; j < pl; ++j) {
+                        common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                    }
+
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                        LOG_ERR("%s: llama_decode() failed\n", __func__);
+                        return 1;
+                    }
                 }
             }
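To make the two decode orderings concrete, here is a minimal standalone sketch (illustrative values only, no llama.cpp dependency) that prints which sequence each generation step feeds into llama_decode(), assuming pl = 4 parallel sequences and tg = 3 generated tokens per sequence.

// sketch_tg_order.cpp -- illustrative only; prints the sequence id decoded at each step
#include <cstdio>

int main() {
    const int pl = 4; // parallel sequences (assumed value)
    const int tg = 3; // tokens to generate per sequence (assumed value)

    // is_tg_separate = true (-tgs): finish one sequence before starting the next,
    // so every llama_decode() call carries a single token for a single sequence.
    std::printf("separate   :");
    for (int j = 0; j < pl; ++j) {
        for (int i = 0; i < tg; ++i) {
            std::printf(" %d", j);
        }
    }
    std::printf("\n"); // -> 0 0 0 1 1 1 2 2 2 3 3 3

    // is_tg_separate = false (default): each llama_decode() call carries one new
    // token for every sequence, so generation advances in lock-step.
    std::printf("interleaved:");
    for (int i = 0; i < tg; ++i) {
        for (int j = 0; j < pl; ++j) {
            std::printf(" %d", j);
        }
    }
    std::printf("\n"); // -> 0 1 2 3 0 1 2 3 0 1 2 3
    return 0;
}

Presumably the separate mode is meant to benchmark single-sequence generation speed while the KV cache still holds the other sequences' tokens, whereas the default mode measures lock-step batched generation throughput.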
