Merged
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -2239,6 +2239,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.is_pp_shared = true;
         }
     ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
+    add_opt(common_arg(
+        {"-tgs"},
+        string_format("is the text generation separated across the different sequences (default: %s)", params.is_tg_separate ? "true" : "false"),
+        [](common_params & params) {
+            params.is_tg_separate = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
     add_opt(common_arg(
         {"-npp"}, "n0,n1,...",
         "number of prompt tokens",
3 changes: 2 additions & 1 deletion common/common.h
@@ -460,7 +460,8 @@ struct common_params {
     float slot_prompt_similarity = 0.1f;
 
     // batched-bench params
-    bool is_pp_shared = false;
+    bool is_pp_shared   = false;
+    bool is_tg_separate = false;
 
     std::vector<int32_t> n_pp;
     std::vector<int32_t> n_tg;
42 changes: 31 additions & 11 deletions tools/batched-bench/batched-bench.cpp
@@ -23,7 +23,8 @@ int main(int argc, char ** argv) {
 
     common_init();
 
-    int is_pp_shared = params.is_pp_shared;
+    int is_pp_shared   = params.is_pp_shared;
+    int is_tg_separate = params.is_tg_separate;
 
     std::vector<int> n_pp = params.n_pp;
     std::vector<int> n_tg = params.n_tg;
@@ -72,8 +73,8 @@ int main(int argc, char ** argv) {
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             llama_batch batch_view = {
                 n_tokens,
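The remainder of decode_helper is truncated in this view. For orientation only, here is a minimal sketch of the chunked-decode pattern the helper implements: the populated benchmark batch is split into views of at most n_batch tokens and each view is decoded in order. The llama_batch field order and the llama_decode/llama_synchronize calls follow llama.h, but the function name and body below are an illustration, not the file's exact contents.

// Illustrative sketch (not the PR's code): decode a populated llama_batch in
// chunks of at most n_batch tokens, as decode_helper does above.
#include <algorithm>
#include <cstdint>

#include "llama.h"

static bool decode_in_chunks(llama_context * ctx, llama_batch & batch, int32_t n_batch, bool synchronize) {
    for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

        // view into the original batch, offset by i tokens
        llama_batch batch_view = {
            n_tokens,
            batch.token    + i,
            nullptr,             // no embeddings input
            batch.pos      + i,
            batch.n_seq_id + i,
            batch.seq_id   + i,
            batch.logits   + i,
        };

        if (llama_decode(ctx, batch_view) != 0) {
            return false;
        }

        if (synchronize) {
            llama_synchronize(ctx); // make timings reflect completed work
        }
    }

    return true;
}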
@@ -113,7 +114,7 @@ int main(int argc, char ** argv) {
 
     if (!params.batched_bench_output_jsonl) {
         LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, is_tg_separate = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, int(params.flash_attn_type), is_pp_shared, is_tg_separate, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
         LOG("\n");
         LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
         LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -172,16 +173,35 @@ int main(int argc, char ** argv) {
 
                 const auto t_tg_start = ggml_time_us();
 
-                for (int i = 0; i < tg; ++i) {
-                    common_batch_clear(batch);
-
+                if (is_tg_separate) {
+                    // decode pattern:
+                    // 0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
                     for (int j = 0; j < pl; ++j) {
-                        common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                        for (int i = 0; i < tg; ++i) {
+                            common_batch_clear(batch);
+
+                            common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+
+                            if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                                LOG_ERR("%s: llama_decode() failed\n", __func__);
+                                return 1;
+                            }
+                        }
                     }
+                } else {
+                    // decode pattern:
+                    // 0123 0123 0123 ...
+                    for (int i = 0; i < tg; ++i) {
+                        common_batch_clear(batch);
 
-                    if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
-                        LOG_ERR("%s: llama_decode() failed\n", __func__);
-                        return 1;
+                        for (int j = 0; j < pl; ++j) {
+                            common_batch_add(batch, get_token_rand(), pp + i, { j }, true);
+                        }
+
+                        if (!decode_helper(ctx, batch, ctx_params.n_batch, true)) {
+                            LOG_ERR("%s: llama_decode() failed\n", __func__);
+                            return 1;
+                        }
                     }
                 }
 
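To make the two decode schedules concrete: in the default mode each llama_decode call carries one token for every parallel sequence (pl tokens per step), while with -tgs each sequence generates its tg tokens alone, one token per call. The standalone sketch below is not part of the PR; pl, tg and pp are made-up example values, and it only prints the submission order implied by the two loop nests above.

// Standalone illustration (not part of the PR): print the (sequence, position)
// submission order for the two batched-bench text-generation schedules.
#include <cstdio>

int main() {
    const int pl = 4; // number of parallel sequences             (example value)
    const int tg = 3; // tokens generated per sequence            (example value)
    const int pp = 8; // prompt length = first generated position (example value)

    // default schedule: 0123 0123 0123 ...
    // one decode step per position i, carrying one token for every sequence j
    std::printf("interleaved:");
    for (int i = 0; i < tg; ++i) {
        for (int j = 0; j < pl; ++j) {
            std::printf(" s%d@%d", j, pp + i);
        }
    }
    std::printf("\n");

    // -tgs schedule: 0 0 0 ... 1 1 1 ... 2 2 2 ... 3 3 3 ...
    // each sequence generates its tg tokens alone, one token per decode step
    std::printf("separate:   ");
    for (int j = 0; j < pl; ++j) {
        for (int i = 0; i < tg; ++i) {
            std::printf(" s%d@%d", j, pp + i);
        }
    }
    std::printf("\n");

    return 0;
}

Reading the output, the first schedule batches pl tokens per decode call, while the second issues one token per call and never mixes sequences within a step.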