diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7abd..002e7b9f3ca 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -625,6 +625,14 @@ extern "C" { int n_samples, int n_processors); + WHISPER_API int whisper_full_batch_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * const * batches, + const int * size_per_batch, + int n_batches, + int n_processors); + // Number of generated text segments // A segment can be a few words, a sentence, or even a paragraph. WHISPER_API int whisper_full_n_segments (struct whisper_context * ctx); diff --git a/src/whisper.cpp b/src/whisper.cpp index f6793cb237b..987aa4875f7 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -7893,6 +7893,129 @@ int whisper_full_parallel( return ret; } + +int whisper_full_batch_parallel( + struct whisper_context *ctx, + struct whisper_full_params params, + const float *const *batches, + const int *size_per_batch, + int n_batches, + int n_processors) +{ + int ret = 0; + n_processors = std::min(n_processors, n_batches); + if (n_batches > n_processors) + { + throw std::runtime_error("batch size must be equal to number of processors"); + } + // prepare separate states for each thread + std::vector states; + std::vector> batches_vector; + batches_vector.reserve(n_batches); + for (int i = 0; i < n_batches; ++i) + { + int batch_size = size_per_batch[i]; + batches_vector.emplace_back(batches[i], batches[i] + batch_size); + } + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + const int n_parallel_processes = n_processors - 1; + std::vector workers(n_parallel_processes); + for (int i = 0; i < n_parallel_processes; ++i) + { + if (i + 1 > n_batches - 1) + { + // break when batch not exist for parallel process + break; + } + const float *samples = batches_vector[i + 1].data(); + const int n_samples = batches_vector[i + 1].size(); + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples, n_samples); + } + + { + auto params_cur = params; + + // We need to disable the print real-time for this one as well, otherwise it will show only for the first chunk. + params_cur.print_realtime = false; + + const float *samples = batches_vector[0].data(); + const int n_samples = batches_vector[0].size(); + + // Run the first transformation using default state but only for the first chunk. + ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, n_samples); + } + + for (int i = 0; i < n_parallel_processes; ++i) + { + workers[i].join(); + } + + // combine results into result_state->result_all from all other states + for (int i = 0; i < n_processors - 1; ++i) + { + auto &results_i = states[i]->result_all; + + for (auto &result : results_i) + { + + // make sure that segments are not overlapping + if (!ctx->state->result_all.empty()) + { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); + } + + ctx->state->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) + { + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); + } + } + + ctx->state->t_mel_us += states[i]->t_mel_us; + + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; + ctx->state->n_prompt += states[i]->n_prompt; + + whisper_free_state(states[i]); + } + + // average the timings + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; + + return ret; +} + int whisper_full_n_segments_from_state(struct whisper_state * state) { return state->result_all.size(); }