Commit 0aa6742

Whisper pipeline: support long-form audio (#941)
This PR adds:
- [x] Long-form audio support with sequential chunking.

Common todos for Whisper support:
- [ ] Long-form audio support with [parallel chunking](https://huggingface.co/blog/asr-chunking)
- [ ] Add perf metrics
- [ ] Update documentation
- [ ] Add C++ and Python sample tests
- [ ] Support timestamps streaming
- [ ] Expose only meaningful parameters in `GenerationConfig` (`task`, `language`, `return_timestamps`, etc.)
- [ ] Move all Whisper pipeline files to a dedicated subfolder
- [ ] The Whisper pipeline doesn't need a tokenizer, it uses the detokenizer only. Implement detokenizer-only initialization for `ov::genai::Tokenizer`
- [ ] Check discrete GPU. Integrated GPU works as expected
- [ ] Investigate use of `RemoteTensor` for GPU
- [ ] Add batching
- [ ] Add sampler, inherit `WhisperGenerationConfig` from `GenerationConfig`
- [ ] Investigate language autodetection with a single decoder (without past) call
- [ ] Update Python bindings cmake to include the whole directory instead of an explicit list of files
- [ ] Add samples with audio preparation examples
- [ ] Add links to audio files so users can download them in samples
- [ ] Move the supported models list from the samples README to the common supported models section
- [ ] Avoid building GenAI in each test job, as it takes a lot of time
- [ ] Double-check FP32 support
- [ ] Fix sporadic test failures: sometimes the Whisper model cannot be downloaded from HF due to network issues
- [ ] Fix the stop criteria. The current approach stops on eos_token, which is the no-speech token, but there could be more speech tokens further on that are wrongly skipped now

Completed:
- [x] Support different languages, language autodetection
- [x] Support translation
- [x] Support timestamps

Current limitations:
- No resampling during preprocessing. Input raw speech should have a 16 kHz sampling rate
- No normalization during preprocessing. Input raw speech should be normalized to the near [-1, 1] range (see the preprocessing sketch below)

Tickets: CVS-147994, CVS-146010, CVS-152542

---------

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
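The preprocessing limitations above are left to the caller. A minimal sketch of that preparation step, assuming `librosa` (the dependency this PR adds to `samples/requirements.txt`) and an illustrative file name:

```python
import librosa

# librosa.load resamples to the requested rate and returns float32 samples
# already scaled to [-1, 1], which covers both limitations listed above.
# "sample.wav" is a placeholder, not a file shipped with the samples.
raw_speech, _ = librosa.load("sample.wav", sr=16000, mono=True)
```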
1 parent 00e532d commit 0aa6742

File tree

13 files changed: +303 -164 lines changed


.github/workflows/linux.yml

Lines changed: 1 addition & 1 deletion
@@ -347,7 +347,7 @@ jobs:
       run: |
         source ${OV_INSTALL_DIR}/setupvars.sh
         python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels --upgrade-strategy eager
-        python -m pytest ./tests/python_tests/test_whisper_generate_api.py
+        python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k test_smoke
       env:
         PYTHONPATH: "./build/:$PYTHONPATH"

.github/workflows/mac.yml

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ jobs:
       run: |
         source ${OV_INSTALL_DIR}/setupvars.sh
         python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${OV_INSTALL_DIR}/wheels --upgrade-strategy eager
-        python -m pytest ./tests/python_tests/test_whisper_generate_api.py
+        python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k test_smoke
       env:
         PYTHONPATH: "./build/:$PYTHONPATH"

.github/workflows/windows.yml

Lines changed: 2 additions & 2 deletions
@@ -301,7 +301,7 @@ jobs:
       run: |
         . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1"
         python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels --upgrade-strategy eager
-        python -m pytest ./tests/python_tests/test_whisper_generate_api.py
+        python -m pytest ./tests/python_tests/test_whisper_generate_api.py -k test_smoke
       env:
         PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that.
@@ -365,7 +365,7 @@ jobs:
     - name: Test bindings
       run: |
         . "${{ env.OV_INSTALL_DIR }}/setupvars.ps1"
-        python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/tools --upgrade-strategy eager
+        python -m pip install ./thirdparty/openvino_tokenizers/[transformers] -r ./tests/python_tests/requirements.txt --find-links ${env:OV_INSTALL_DIR}/wheels --upgrade-strategy eager
         python -m pytest ./tests/python_tests/test_vlm_api.py
       env:
         PYTHONPATH: "./build/" # cmd evaluates variables in a different way. Setting PYTHONPATH before setupvars.bat instead of doing that after solves that.

samples/python/whisper_speech_recognition/whisper_speech_recognition.py

Lines changed: 2 additions & 2 deletions
@@ -36,11 +36,11 @@ def streamer(word: str) -> bool:
         streamer=streamer,
     )

+    print()
+
     for chunk in result.chunks:
         print(f"timestamps: [{chunk.start_ts}, {chunk.end_ts}] text: {chunk.text}")

-    print()
-

 if "__main__" == __name__:
     main()

samples/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,4 +4,5 @@ numpy<2.0.0; sys_platform == 'darwin'
 einops==0.8.0 # For Qwen
 transformers_stream_generator==0.0.5 # For Qwen
 diffusers==0.30.3
+librosa # For Whisper
 torchvision # needed for mini-CPM export script. Need to remove when we switch to exporting with optimum-intel.

src/cpp/src/whisper/logit_processor.cpp

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ void process_whisper_timestamp_logits(ov::Tensor& logits,

     if (last_was_timestamp) {
         if (penultimate_was_timestamp) {
-            // has to be timestamp
+            // has to be non-timestamp
            for (size_t i = timestamp_begin; i < vocab_size; i++) {
                logits_data[i] = -std::numeric_limits<float>::infinity();
            }
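The corrected comment matches Whisper's timestamp decoding rule: timestamps are emitted in pairs (the end of one segment immediately followed by the start of the next), so after two consecutive timestamp tokens the boundary is complete and the next token must be text. A standalone sketch of the rule, assuming a NumPy formulation with illustrative names rather than the pipeline's actual API:

```python
import numpy as np

# Token ids >= timestamp_begin are timestamp tokens. After a <ts><ts> pair
# the boundary is complete, so ban every timestamp id; after a single <ts>
# the decoder still has to close the pair, so ban the text ids instead
# (eos handling is omitted for brevity).
def apply_timestamp_rule(logits: np.ndarray,
                         timestamp_begin: int,
                         last_was_timestamp: bool,
                         penultimate_was_timestamp: bool) -> None:
    if not last_was_timestamp:
        return
    if penultimate_was_timestamp:
        logits[timestamp_begin:] = -np.inf  # has to be non-timestamp
    else:
        logits[:timestamp_begin] = -np.inf  # has to be timestamp
```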

src/cpp/src/whisper/timestamps.cpp

Lines changed: 26 additions & 12 deletions
@@ -6,12 +6,11 @@
 namespace ov {
 namespace genai {

-std::pair<std::vector<int64_t>, std::vector<ov::genai::Segment>> extract_segments(
-    const std::vector<int64_t>& tokens,
-    const ov::genai::WhisperGenerationConfig& config,
-    const float time_precision) {
-    std::vector<int64_t> non_timestamp_tokens;
-    std::vector<ov::genai::Segment> segments;
+ov::genai::ExtractedSegments extract_segments(const std::vector<int64_t>& tokens,
+                                              const ov::genai::WhisperGenerationConfig& config,
+                                              const size_t nb_max_frames,
+                                              const float time_precision) {
+    ov::genai::ExtractedSegments extracted_segments;
     std::optional<int64_t> token_start = std::nullopt;
     size_t idx_start = 0;
@@ -41,9 +40,14 @@ std::pair<std::vector<int64_t>, std::vector<ov::genai::Segment>> extract_segments
             segment.m_tokens = {tokens.begin() + idx_start + 1, tokens.begin() + i};
             segment.m_start = (*token_start - config.begin_timestamps_token_id) * time_precision;
             segment.m_end = (token - config.begin_timestamps_token_id) * time_precision;
-            segments.push_back(segment);
+            extracted_segments.segments.push_back(segment);

-            non_timestamp_tokens.insert(non_timestamp_tokens.end(), tokens.begin() + idx_start + 1, tokens.begin() + i);
+            // each next timestamp token represents .02 time diff
+            extracted_segments.last_offset = (token - config.begin_timestamps_token_id) * 2;
+
+            extracted_segments.non_timestamp_tokens.insert(extracted_segments.non_timestamp_tokens.end(),
+                                                           tokens.begin() + idx_start + 1,
+                                                           tokens.begin() + i);

             token_start = std::nullopt;
         }
@@ -53,18 +57,28 @@ std::pair<std::vector<int64_t>, std::vector<ov::genai::Segment>> extract_segments
     // add new segment only if it has non timestamps tokens
     // do not add new segment if previous segments exists
     bool has_tokens_to_add = idx_start < tokens.size() - 1;
-    bool has_previous_segments = segments.size() > 0;
+    bool has_previous_segments = extracted_segments.segments.size() > 0;
     if (token_start.has_value() && has_tokens_to_add && !has_previous_segments) {
         ov::genai::Segment segment;
         segment.m_tokens = {tokens.begin() + idx_start + 1, tokens.end()};
         segment.m_start = (*token_start - config.begin_timestamps_token_id) * time_precision;
         segment.m_end = -1.0f;
-        segments.push_back(segment);
+        extracted_segments.segments.push_back(segment);
+
+        extracted_segments.last_offset = nb_max_frames;
+
+        extracted_segments.non_timestamp_tokens.insert(extracted_segments.non_timestamp_tokens.end(),
+                                                       tokens.begin() + idx_start + 1,
+                                                       tokens.end());
+    }

-        non_timestamp_tokens.insert(non_timestamp_tokens.end(), tokens.begin() + idx_start + 1, tokens.end());
+    // last timestamps generated in pairs <ts><ts><eos> -> speech segment continuation to the next chunk -> token_start will have value
+    // single ending timestamp <ts><eos> -> no more speech till the end of current chunk -> set offset to the end of frame
+    if (!token_start.has_value()) {
+        extracted_segments.last_offset = nb_max_frames;
     }

-    return {non_timestamp_tokens, segments};
+    return extracted_segments;
 }

 } // namespace genai
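The `last_offset` arithmetic above drives the sliding window: each timestamp token sits one `time_precision` step (0.02 s by default) above `begin_timestamps_token_id`, and one step spans two 10 ms mel frames, hence the `* 2`. A worked example, assuming ids that mirror the multilingual Whisper vocabulary (the concrete numbers are assumptions, not pipeline API):

```python
begin_ts_id = 50364           # assumed id of <|0.00|> in multilingual Whisper
time_precision = 30.0 / 1500  # chunk_length / max_source_positions = 0.02 s

# Suppose a chunk's decoding ends with the pair <|7.24|><|7.24|><eos>,
# i.e. the speech segment continues into the next window:
last_ts = begin_ts_id + 362   # <|7.24|>

segment_end = (last_ts - begin_ts_id) * time_precision  # ~7.24 s
last_offset = (last_ts - begin_ts_id) * 2               # 724 mel frames

print(segment_end, last_offset)
```

A single ending timestamp (`<ts><eos>`) means no further speech in the chunk, so `last_offset` falls back to `nb_max_frames` and the window advances by the full 30 s.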

src/cpp/src/whisper/timestamps.hpp

Lines changed: 10 additions & 4 deletions
@@ -10,10 +10,16 @@
 namespace ov {
 namespace genai {

-std::pair<std::vector<int64_t>, std::vector<ov::genai::Segment>> extract_segments(
-    const std::vector<int64_t>& tokens,
-    const ov::genai::WhisperGenerationConfig& config,
-    const float time_precision);
+struct ExtractedSegments {
+    std::vector<ov::genai::Segment> segments;
+    size_t last_offset;
+    std::vector<int64_t> non_timestamp_tokens;
+};
+
+ExtractedSegments extract_segments(const std::vector<int64_t>& tokens,
+                                   const ov::genai::WhisperGenerationConfig& config,
+                                   const size_t nb_max_frames,
+                                   const float time_precision);

 } // namespace genai
 } // namespace ov

src/cpp/src/whisper/whisper.cpp

Lines changed: 70 additions & 46 deletions
@@ -76,7 +76,8 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
                ov::InferRequest& decoder,
                std::vector<int64_t>& input_ids,
                const ov::genai::WhisperGenerationConfig& config,
-               bool apply_logit_processors = true) {
+               const bool apply_logit_processors = true,
+               const bool return_timestamps = false) {
     decoder.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

     ov::Tensor input_ids_tensor(ov::element::i64, {1, input_ids.size()}, input_ids.data());
@@ -90,7 +91,7 @@ int64_t decode(ov::Tensor& encoder_hidden_state,
         ov::genai::do_suppress_tokens(output_tensor, 0, config.begin_suppress_tokens);
         ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens);

-        if (config.return_timestamps) {
+        if (return_timestamps) {
             ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, {}, true);
         }
     }
@@ -105,6 +106,7 @@ int64_t decode_with_past(ov::Tensor& encoder_hidden_state,
                          int64_t input_id,
                          const size_t cache_position,
                          const ov::genai::WhisperGenerationConfig& config,
+                         const bool return_timestamps,
                          const std::vector<int64_t>& generated_tokens) {
     decoder_with_past.set_tensor("encoder_hidden_states", ov::Tensor{encoder_hidden_state});

@@ -122,7 +124,7 @@ int64_t decode_with_past(ov::Tensor& encoder_hidden_state,

     ov::genai::do_suppress_tokens(output_tensor, 0, config.suppress_tokens);

-    if (config.return_timestamps) {
+    if (return_timestamps) {
         ov::genai::process_whisper_timestamp_logits(output_tensor, 0, config, generated_tokens);
     }

@@ -135,14 +137,15 @@ int64_t detect_language(ov::Tensor& encoder_hidden_state,
                         ov::InferRequest decoder,
                         const ov::genai::WhisperGenerationConfig& config) {
     std::vector<int64_t> input_ids{config.decoder_start_token_id};
-    int64_t output_token = decode(encoder_hidden_state, decoder, input_ids, config, false);
+    int64_t output_token = decode(encoder_hidden_state, decoder, input_ids, config, false, false);

     return output_token;
 }

-std::vector<int64_t> prepare_input_ids(ov::Tensor& encoder_hidden_state,
-                                       ov::InferRequest decoder,
-                                       const ov::genai::WhisperGenerationConfig& config) {
+std::vector<int64_t> prepare_init_ids(ov::Tensor& encoder_hidden_state,
+                                      ov::InferRequest decoder,
+                                      const ov::genai::WhisperGenerationConfig& config,
+                                      const bool return_timestamps) {
     if (!config.is_multilingual) {
         return std::vector<int64_t>{config.decoder_start_token_id, config.no_timestamps_token_id};
     }
@@ -162,7 +165,7 @@ std::vector<int64_t> prepare_input_ids(ov::Tensor& encoder_hidden_state,
         task_token_id = config.translate_token_id;
     }

-    if (config.return_timestamps) {
+    if (return_timestamps) {
         return std::vector<int64_t>{config.decoder_start_token_id, language_token_id, task_token_id};
     }

@@ -175,11 +178,11 @@ std::vector<int64_t> prepare_input_ids(ov::Tensor& encoder_hidden_state,
 std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_state,
                                                   const ov::genai::WhisperGenerationConfig& config,
                                                   ov::genai::WhisperInitializedModels& models,
+                                                  std::vector<int64_t> init_ids,
                                                   const size_t max_new_tokens,
+                                                  const bool return_timestamps,
                                                   const std::shared_ptr<ov::genai::StreamerBase> streamer) {
-    std::vector<int64_t> input_ids = prepare_input_ids(encoder_hidden_state, models.decoder, config);
-
-    int64_t output_token = decode(encoder_hidden_state, models.decoder, input_ids, config);
+    int64_t output_token = decode(encoder_hidden_state, models.decoder, init_ids, config, true, return_timestamps);

     std::vector<int64_t> output_tokens{output_token};

@@ -198,8 +201,9 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta
         auto output_token = decode_with_past(encoder_hidden_state,
                                              models.decoder_with_past,
                                              output_tokens.back(),
-                                             input_ids.size() + output_tokens.size() - 1,
+                                             init_ids.size() + output_tokens.size() - 1,
                                              config,
+                                             return_timestamps,
                                              output_tokens);

         if (i == 0) {
@@ -225,52 +229,75 @@ std::pair<bool, std::vector<int64_t>> full_decode(ov::Tensor& encoder_hidden_sta

 namespace ov {
 namespace genai {
-// hf hash 2 algos for handling long (>30s) audios https://huggingface.co/openai/whisper-large-v3#chunked-long-form
-// Sequential: uses a "sliding window" for buffered inference, transcribing 30-second slices one after the other
-// Chunked: splits long audio files into shorter ones (with a small overlap between segments), transcribes each segment
-// independently, and stitches the resulting transcriptions at the boundaries
-
-// By default, Transformers uses the sequential algorithm. To enable the chunked algorithm, pass the chunk_length_s
-// parameter to the pipeline. A chunk length of 30-seconds is optimal. Sequential algo:
-// 1. Process whole raw speech into mel spectrogram
-// 2. Chunk mel spectrogram into 30s
-// 3. Enable timestamps
-// 4. Process each chunk sequentially.
-// 5. For each chunk stop at first eos token. Start next window from last timestamp found.
-//    remove eos tokens if not finished yet
-//    remove pad tokens
-// 7. Concatenate output tokens
+
 std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_generate(
     const ov::genai::WhisperGenerationConfig& config,
     const ov::genai::WhisperConfig& model_config,
     const RawSpeechInput& raw_speech,
     ov::genai::WhisperInitializedModels& models,
     WhisperFeatureExtractor& feature_extractor,
     const std::shared_ptr<StreamerBase> streamer) {
+    auto input_features = feature_extractor.extract(raw_speech);
+
+    const bool is_shortform = input_features.n_frames <= feature_extractor.nb_max_frames;
+    // long-form audio processing requires timestamps to be enabled
+    const bool return_timestamps = config.return_timestamps || !is_shortform;
+
+    std::vector<int64_t> init_ids;
     std::vector<int64_t> output_tokens;
     size_t max_new_tokens = config.get_max_new_tokens();

-    for (size_t chunk_offset = 0; chunk_offset < raw_speech.size(); chunk_offset += feature_extractor.n_samples) {
+    std::vector<Segment> segments;
+
+    // 0.02 by default
+    const float time_precision = static_cast<float>(feature_extractor.chunk_length) / model_config.max_source_positions;
+    size_t segment_offset = 0;
+
+    for (size_t chunk_offset = 0; chunk_offset < input_features.n_frames; chunk_offset += segment_offset) {
         if (output_tokens.size() >= max_new_tokens) {
             break;
         }

-        // Split audio data into fixed feature_extractor.chunk_size windows.
-        size_t copy_size = std::min((raw_speech.size() - chunk_offset), size_t(feature_extractor.n_samples));
-        std::vector<float> input_features_sub_chunk(raw_speech.begin() + chunk_offset,
-                                                    raw_speech.begin() + chunk_offset + copy_size);
+        auto input_features_chunk = input_features.get_data_with_offset(chunk_offset, feature_extractor.nb_max_frames);

-        auto input_features = feature_extractor.extract(input_features_sub_chunk);
+        ov::Tensor hidden_state_tensor = encode(models.encoder,
+                                                input_features_chunk,
+                                                feature_extractor.feature_size,
+                                                feature_extractor.nb_max_frames);

-        ov::Tensor hidden_state_tensor =
-            encode(models.encoder, input_features, feature_extractor.feature_size, feature_extractor.nb_max_frames);
+        // prepare init_ids just once for whole input
+        if (init_ids.empty()) {
+            init_ids = prepare_init_ids(hidden_state_tensor, models.decoder, config, return_timestamps);
+        }

-        bool cancelled;
-        std::vector<int64_t> chunk_output_tokens;
-        std::tie(cancelled, chunk_output_tokens) =
-            full_decode(hidden_state_tensor, config, models, max_new_tokens - output_tokens.size(), streamer);
+        auto [cancelled, chunk_output_tokens] = full_decode(hidden_state_tensor,
+                                                            config,
+                                                            models,
+                                                            init_ids,
+                                                            max_new_tokens - output_tokens.size(),
+                                                            return_timestamps,
+                                                            streamer);
+
+        if (return_timestamps) {
+            auto extracted_segments = ov::genai::extract_segments(chunk_output_tokens,
+                                                                  config,
+                                                                  feature_extractor.nb_max_frames,
+                                                                  time_precision);
+
+            segments.insert(segments.end(), extracted_segments.segments.begin(), extracted_segments.segments.end());
+
+            output_tokens.insert(output_tokens.end(),
+                                 extracted_segments.non_timestamp_tokens.begin(),
+                                 extracted_segments.non_timestamp_tokens.end());
+
+            segment_offset = extracted_segments.last_offset;
+        } else {
+            output_tokens.insert(output_tokens.end(), chunk_output_tokens.begin(), chunk_output_tokens.end());
+        }

-        output_tokens.insert(output_tokens.end(), chunk_output_tokens.begin(), chunk_output_tokens.end());
+        if (is_shortform) {
+            segment_offset = input_features.n_frames;
+        }

         if (cancelled) {
             break;
@@ -281,12 +308,9 @@ std::pair<std::vector<int64_t>, std::optional<std::vector<Segment>>> whisper_gen
         streamer->end();
     }

-    std::optional<std::vector<Segment>> segments = std::nullopt;
-    if (config.return_timestamps) {
-        // 0.02 by default
-        const float time_precision =
-            static_cast<float>(feature_extractor.chunk_length) / model_config.max_source_positions;
-        std::tie(output_tokens, segments) = ov::genai::extract_segments(output_tokens, config, time_precision);
+    // if return_timestamps wasn't enabled by user
+    if (!config.return_timestamps) {
+        return {output_tokens, std::nullopt};
     }

     return {output_tokens, segments};
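Taken together, the rewritten `whisper_generate` featurizes the input once, then slides a 30 s window whose step is the `last_offset` that `extract_segments` reports for each chunk. A self-contained sketch of just that bookkeeping, with decoding faked (the per-chunk offsets below are made-up stand-ins for real timestamp output):

```python
nb_max_frames = 3000  # one 30 s window of 10 ms mel frames
n_frames = 7500       # 75 s of input audio

# pretend extract_segments() reported these last_offset values per chunk
fake_last_offsets = [2840, 2910, 3000]

chunk_offset, i = 0, 0
while chunk_offset < n_frames:
    window_end = min(chunk_offset + nb_max_frames, n_frames)
    print(f"chunk {i}: frames [{chunk_offset}, {window_end}) "
          f"= [{chunk_offset / 100:.1f} s, {window_end / 100:.1f} s)")
    # the next window resumes at the last completed segment of this chunk
    chunk_offset += fake_last_offsets[min(i, len(fake_last_offsets) - 1)]
    i += 1
```

Because each step lands on a completed segment boundary rather than a fixed 30 s stride, no speech is cut mid-segment and the per-chunk transcripts concatenate without overlap.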
