Commit 8143634
Whisper pipeline: return timestamps (#910)
This PR adds `return_timestamps` support for the Whisper pipeline.

Common TODOs for Whisper support:
- [ ] Longer audio inputs (>30s) give poor-quality results at chunk borders. Long audio inputs are split into 30s chunks, which loses context at each chunk border. This could be partially solved by [chunking with stride](https://huggingface.co/blog/asr-chunking).
- [ ] Add perf metrics
- [ ] Update documentation
- [ ] Add cpp, python samples tests
- [x] Support different languages, language autodetection
- [x] Support translation
- [x] Support timestamps
- [ ] Support timestamps streaming
- [ ] Expose only meaningful parameters in `GenerationConfig` (`task`, `language`, `return_timestamps`, etc.)
- [ ] Move all whisper pipeline files to a dedicated subfolder
- [ ] The Whisper pipeline doesn't need a tokenizer, it uses the detokenizer only. Implement detokenizer-only initialization for `ov::genai::Tokenizer`
- [ ] Check discrete GPU. Integrated GPU works as expected.
- [ ] Investigate use of `RemoteTensor` for GPU
- [ ] Add batch support
- [ ] Add sampler, inherit WhisperGenerationConfig from GenerationConfig
- [ ] Investigate language autodetection with a single decoder (without past) call
- [ ] Update python bindings cmake to include the whole directory instead of an explicit list of files
- [ ] Add samples with audio preparation examples
- [ ] Add links to audio files so users can download them in samples
- [ ] Move the supported models list from the samples README to the common supported models section
- [ ] Avoid building GenAI in each tests job, as it takes a lot of time
- [ ] Double-check FP32 support
- [ ] Fix sporadic test failures. Sometimes the whisper model cannot be downloaded from HF due to network issues
- [ ] Fix stop criteria. The current approach stops on eos_token, which is the no-speech token, but there could be more speech tokens afterwards that are wrongly skipped now.

Current limitations:
- No resampling during preprocessing. Input raw speech should have a 16 kHz sampling rate.
- No normalization during preprocessing. Input raw speech should be normalized to near the [-1, 1] range.

Tickets: CVS-147994, CVS-146010, CVS-152543
1 parent dcb2336 commit 8143634
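Given the preprocessing limitations above, callers currently have to resample to 16 kHz and normalize the waveform themselves before calling `generate`. A minimal peak-normalization sketch (illustrative only, not part of this commit; resampling is left out):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative peak normalization to roughly [-1, 1]. Assumes the input is
// already sampled at 16 kHz, since the pipeline does not resample.
void normalize_raw_speech(std::vector<float>& raw_speech) {
    float peak = 0.0f;
    for (float sample : raw_speech) {
        peak = std::max(peak, std::abs(sample));
    }
    if (peak > 0.0f) {
        for (float& sample : raw_speech) {
            sample /= peak;
        }
    }
}
```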

File tree: 17 files changed, +594 −65 lines


samples/cpp/whisper_speech_recognition/whisper_speech_recognition.cpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -21,15 +21,20 @@ int main(int argc, char* argv[]) try {
     // 'task' and 'language' parameters are supported for multilingual models only
     config.language = "<|en|>";
     config.task = "transcribe";
+    config.return_timestamps = true;
 
     auto streamer = [](std::string word) {
         std::cout << word;
         return false;
     };
 
-    pipeline.generate(raw_speech, config, streamer);
+    auto result = pipeline.generate(raw_speech, config, streamer);
 
-    std::cout << std::endl;
+    std::cout << "\n";
+
+    for (auto& chunk : *result.chunks) {
+        std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n";
+    }
 } catch (const std::exception& error) {
     try {
         std::cerr << error.what() << '\n';
```
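One caveat when adapting this sample: `WhisperDecodedResults::chunks` is a `std::optional`, so `*result.chunks` assumes `return_timestamps` was set, as it is here. A defensive variant (hypothetical `print_chunks` helper, same API otherwise):

```cpp
#include <iostream>

#include "openvino/genai/whisper_pipeline.hpp"

// Guard the optional before iterating: chunks stays std::nullopt when
// return_timestamps was not requested.
void print_chunks(const ov::genai::WhisperDecodedResults& result) {
    if (!result.chunks.has_value()) {
        return;
    }
    for (const auto& chunk : *result.chunks) {
        std::cout << "timestamps: [" << chunk.start_ts << ", " << chunk.end_ts << "] text: " << chunk.text << "\n";
    }
}
```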

samples/python/whisper_speech_recognition/whisper_speech_recognition.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -26,15 +26,19 @@ def streamer(word: str) -> bool:
         print(word, end="")
         return False
 
-    pipe.generate(
+    result = pipe.generate(
         raw_speech,
         max_new_tokens=100,
         # 'task' and 'language' parameters are supported for multilingual models only
         language="<|en|>",
         task="transcribe",
+        return_timestamps=True,
         streamer=streamer,
     )
 
+    for chunk in result.chunks:
+        print(f"timestamps: [{chunk.start_ts}, {chunk.end_ts}] text: {chunk.text}")
+
     print()
```

src/cpp/include/openvino/genai/whisper_generation_config.hpp

Lines changed: 13 additions & 0 deletions
```diff
@@ -51,6 +51,8 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
     // Begin timestamps token id.
     int64_t begin_timestamps_token_id = 50364;
 
+    size_t max_initial_timestamp_index = 50;
+
     bool is_multilingual = true;
 
     // Language token to use for generation in the form of <|en|>.
@@ -65,6 +67,16 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig {
     // Can be set for multilingual models only.
     std::optional<std::string> task = std::nullopt;
 
+    // If `true` the pipeline will return timestamps along the text for *segments* of words in the text.
+    // For instance, if you get
+    // WhisperDecodedResultChunk
+    //     start_ts = 0.5
+    //     end_ts = 1.5
+    //     text = " Hi there!"
+    // then it means the model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
+    // Note that a segment of text refers to a sequence of one or more words, rather than individual words.
+    bool return_timestamps = false;
+
     // A list containing tokens that will be supressed at the beginning of the sampling process.
     std::vector<int64_t> begin_suppress_tokens;
 
@@ -105,6 +117,7 @@ static constexpr ov::Property<int64_t> no_timestamps_token_id{"no_timestamps_tok
 static constexpr ov::Property<int64_t> begin_timestamps_token_id{"begin_timestamps_token_id"};
 static constexpr ov::Property<std::string> language{"language"};
 static constexpr ov::Property<std::string> task{"task"};
+static constexpr ov::Property<bool> return_timestamps{"return_timestamps"};
 static constexpr ov::Property<std::map<std::string, int64_t>> lang_to_id{"lang_to_id"};
 
 } // namespace genai
```
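Since `return_timestamps` is also registered as an `ov::Property`, timestamps can be requested through the property-based `generate` overload instead of a config object. A short sketch (assumes `pipeline` and `raw_speech` are set up as in the samples):

```cpp
// Property-style call; each property expands into an AnyMap entry.
auto result = pipeline.generate(raw_speech,
                                ov::genai::language("<|en|>"),
                                ov::genai::task("transcribe"),
                                ov::genai::return_timestamps(true));
```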

src/cpp/include/openvino/genai/whisper_pipeline.hpp

Lines changed: 23 additions & 8 deletions
```diff
@@ -17,6 +17,21 @@ using OptionalWhisperGenerationConfig = std::optional<WhisperGenerationConfig>;
 
 using RawSpeechInput = std::vector<float>;
 
+struct WhisperDecodedResultChunk {
+    // start of chunk in seconds
+    float start_ts;
+
+    // end of chunk in seconds
+    // -1.0f if chunk started but model did not predict an ending timestamp
+    // can happen if audio is cut off in the middle of a word
+    float end_ts = -1.0f;
+    std::string text;
+};
+
+struct WhisperDecodedResults : public DecodedResults {
+    std::optional<std::vector<WhisperDecodedResultChunk>> chunks = std::nullopt;
+};
+
 class OPENVINO_GENAI_EXPORTS WhisperPipeline {
     class Impl;
     std::unique_ptr<Impl> m_impl;
@@ -57,11 +72,11 @@
      * sampling rate.
      * @param generation_config optional GenerationConfig
      * @param streamer optional streamer
-     * @return DecodedResults decoded resulting text transcription
+     * @return WhisperDecodedResults decoded resulting text transcription
      */
-    DecodedResults generate(const RawSpeechInput& raw_speech_input,
-                            OptionalWhisperGenerationConfig generation_config = std::nullopt,
-                            StreamerVariant streamer = std::monostate());
+    WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input,
+                                   OptionalWhisperGenerationConfig generation_config = std::nullopt,
+                                   StreamerVariant streamer = std::monostate());
 
     /**
      * @brief High level generate that receives raw speech as a vector of floats and returns decoded output.
@@ -70,14 +85,14 @@
      *
      * @param raw_speech_input raw speech input
      * @param properties properties
-     * @return DecodedResults decoded resulting text transcription
+     * @return WhisperDecodedResults decoded resulting text transcription
      */
     template <typename... Properties>
-    util::EnableIfAllStringAny<DecodedResults, Properties...> generate(const RawSpeechInput& raw_speech_input,
-                                                                       Properties&&... properties) {
+    util::EnableIfAllStringAny<WhisperDecodedResults, Properties...> generate(const RawSpeechInput& raw_speech_input,
+                                                                              Properties&&... properties) {
         return generate(raw_speech_input, AnyMap{std::forward<Properties>(properties)...});
     }
-    DecodedResults generate(const RawSpeechInput& raw_speech_input, const ov::AnyMap& config_map);
+    WhisperDecodedResults generate(const RawSpeechInput& raw_speech_input, const ov::AnyMap& config_map);
 
     ov::genai::Tokenizer get_tokenizer();
     WhisperGenerationConfig get_generation_config() const;
```
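Because `WhisperDecodedResults` derives from `DecodedResults`, the plain transcription remains available next to the optional chunks. A two-line sketch (assumes `pipeline`, `raw_speech`, and `config` as in the samples; `texts` comes from the `DecodedResults` base, as elsewhere in GenAI):

```cpp
ov::genai::WhisperDecodedResults result = pipeline.generate(raw_speech, config);
std::cout << result.texts[0] << "\n";  // full transcription via the DecodedResults base
```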

src/cpp/src/sampler.hpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -30,6 +30,8 @@ inline bool is_stop_token_id_hit(int64_t generated_token, const std::set<int64_t
     return false;
 }
 
+std::vector<Token> log_softmax(const ov::Tensor& logits, size_t batch_idx);
+
 struct SamplerOutput {
     // IDs of sequences that need to be dropped
     std::vector<uint64_t> m_dropped_sequences;
```
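`log_softmax` is exposed here because the whisper timestamp logit processor needs per-token log-probabilities. As a reminder of the underlying math, `log_softmax(x)[i] = x[i] - log(sum_j exp(x[j]))`; a freestanding sketch over a plain vector (hypothetical helper, not the library's tensor-based signature):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical scalar illustration of log-softmax over one row of logits.
std::vector<float> log_softmax_row(const std::vector<float>& logits) {
    // Subtract the max logit first for numerical stability.
    float max_logit = *std::max_element(logits.begin(), logits.end());
    float sum_exp = 0.0f;
    for (float logit : logits) {
        sum_exp += std::exp(logit - max_logit);
    }
    float log_sum_exp = max_logit + std::log(sum_exp);
    std::vector<float> result;
    result.reserve(logits.size());
    for (float logit : logits) {
        result.push_back(logit - log_sum_exp);
    }
    return result;
}
```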
Lines changed: 118 additions & 0 deletions
```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include <openvino/openvino.hpp>

#include "openvino/genai/whisper_generation_config.hpp"
#include "sampler.hpp"

namespace ov {
namespace genai {

void do_suppress_tokens(ov::Tensor& logits, const size_t batch_idx, const std::vector<int64_t>& suppress_tokens) {
    OPENVINO_ASSERT(logits.get_shape()[0] >= batch_idx, "logits batch size doesn't match the batch number");

    size_t vocab_size = logits.get_shape().back();
    size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
    size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;
    float* logits_data = logits.data<float>() + batch_offset + sequence_offset;

    for (auto suppress_token : suppress_tokens) {
        logits_data[suppress_token] = -std::numeric_limits<float>::infinity();
    }
}

void process_whisper_timestamp_logits(ov::Tensor& logits,
                                      const size_t batch_idx,
                                      const ov::genai::WhisperGenerationConfig& config,
                                      const std::vector<int64_t>& generated_tokens,
                                      bool initial_step = false) {
    const size_t batch_size = logits.get_shape().at(0);
    OPENVINO_ASSERT(batch_size == 1, "Batch != 1 is not supported");

    size_t vocab_size = logits.get_shape().back();
    size_t batch_offset = batch_idx * logits.get_shape()[1] * vocab_size;
    size_t sequence_offset = (logits.get_shape()[1] - 1) * vocab_size;
    float* logits_data = logits.data<float>() + batch_offset + sequence_offset;

    // suppress <|notimestamps|>
    logits_data[config.no_timestamps_token_id] = -std::numeric_limits<float>::infinity();

    size_t timestamp_begin = config.no_timestamps_token_id + 1;

    // timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
    size_t generated_length = generated_tokens.size();
    bool last_was_timestamp = generated_length >= 1 && generated_tokens[generated_length - 1] >= timestamp_begin;
    bool penultimate_was_timestamp = generated_length < 2 || generated_tokens[generated_length - 2] >= timestamp_begin;

    if (last_was_timestamp) {
        if (penultimate_was_timestamp) {
            // has to be timestamp
            for (size_t i = timestamp_begin; i < vocab_size; i++) {
                logits_data[i] = -std::numeric_limits<float>::infinity();
            }
        } else {
            // cannot be normal text token
            for (size_t i = 0; i < config.eos_token_id; i++) {
                logits_data[i] = -std::numeric_limits<float>::infinity();
            }
        }
    }

    // filter generated timestamps
    std::vector<int64_t> timestamps;
    for (const auto token : generated_tokens) {
        if (token >= timestamp_begin) {
            timestamps.push_back(token);
        }
    }

    if (timestamps.size() > 0) {
        size_t timestamp_last;
        // `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
        // The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
        if (last_was_timestamp && !penultimate_was_timestamp) {
            timestamp_last = timestamps.back();
        } else {
            // Avoid emitting <|0.00|> again
            timestamp_last = timestamps.back() + 1;
        }

        for (size_t i = timestamp_begin; i < timestamp_last; i++) {
            logits_data[i] = -std::numeric_limits<float>::infinity();
        }
    }

    // apply the `max_initial_timestamp` option
    if (initial_step) {
        for (size_t i = 0; i < timestamp_begin; i++) {
            logits_data[i] = -std::numeric_limits<float>::infinity();
        }

        size_t last_allowed = timestamp_begin + config.max_initial_timestamp_index;
        for (size_t i = last_allowed + 1; i < vocab_size; i++) {
            logits_data[i] = -std::numeric_limits<float>::infinity();
        }
    }

    auto tokens = ov::genai::log_softmax(logits, 0);
    float timestamp_exp_prob_sum = 0;

    for (size_t i = timestamp_begin; i < vocab_size; i++) {
        timestamp_exp_prob_sum += std::exp(tokens[i].m_log_prob);
    }
    float timestamp_logprob = std::log(timestamp_exp_prob_sum);

    auto max_logprob_token = std::max_element(tokens.begin(), tokens.end(), [](const Token& left, const Token& right) {
        return left.m_log_prob < right.m_log_prob;
    });

    if (timestamp_logprob > max_logprob_token->m_log_prob) {
        for (size_t i = 0; i < timestamp_begin; i++) {
            logits_data[i] = -std::numeric_limits<float>::infinity();
        }
    }
}

} // namespace genai
} // namespace ov
```
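The final block mirrors OpenAI's Whisper decoding heuristic: when the combined probability mass of all timestamp tokens exceeds the probability of the single most likely token, text tokens are masked so the next sampled token must be a timestamp. In log space, with `logprob[i]` taken from `log_softmax`:

```
log( sum_{i >= timestamp_begin} exp(logprob[i]) ) > max_i logprob[i]   =>   logits[0 .. timestamp_begin) = -inf
```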
Lines changed: 22 additions & 0 deletions
```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <openvino/openvino.hpp>

#include "openvino/genai/whisper_generation_config.hpp"

namespace ov {
namespace genai {

void do_suppress_tokens(ov::Tensor& logits, const size_t batch_idx, const std::vector<int64_t>& suppress_tokens);

void process_whisper_timestamp_logits(ov::Tensor& logits,
                                      const size_t batch_idx,
                                      const ov::genai::WhisperGenerationConfig& config,
                                      const std::vector<int64_t>& generated_tokens,
                                      bool initial_step = false);

} // namespace genai
} // namespace ov
```
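From these declarations, the helpers slot in between the decoder's forward pass and sampling. A sketch of the presumed call pattern (hypothetical `decode_step` and `sample_token`; the real decoding loop is not part of this diff):

```cpp
// Hypothetical decode loop: decode_step() and sample_token() stand in for the
// pipeline's actual inference and sampling code.
std::vector<int64_t> generated_tokens;
for (size_t step = 0; step < config.max_new_tokens; ++step) {
    ov::Tensor logits = decode_step(generated_tokens);
    if (step == 0) {
        // Tokens listed in begin_suppress_tokens are banned at the start of sampling.
        ov::genai::do_suppress_tokens(logits, 0, config.begin_suppress_tokens);
    }
    if (config.return_timestamps) {
        ov::genai::process_whisper_timestamp_logits(logits, 0, config, generated_tokens, step == 0);
    }
    int64_t token = sample_token(logits);
    if (token == config.eos_token_id) {
        break;
    }
    generated_tokens.push_back(token);
}
```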

src/cpp/src/whisper/timestamps.cpp

Lines changed: 71 additions & 0 deletions
```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#include "timestamps.hpp"

namespace ov {
namespace genai {

std::pair<std::vector<int64_t>, std::vector<ov::genai::Segment>> extract_segments(
    const std::vector<int64_t>& tokens,
    const ov::genai::WhisperGenerationConfig& config,
    const float time_precision) {
    std::vector<int64_t> non_timestamp_tokens;
    std::vector<ov::genai::Segment> segments;
    std::optional<int64_t> token_start = std::nullopt;
    size_t idx_start = 0;

    for (size_t i = 0; i < tokens.size(); i++) {
        int64_t token = tokens[i];

        bool is_timestamp = token >= config.begin_timestamps_token_id;

        if (!is_timestamp) {
            continue;
        }

        if (!token_start.has_value()) {
            token_start = token;
            idx_start = i;
        } else {
            if (token_start == token) {
                // from HF:
                // https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/tokenization_whisper.py#L1020
                // This is a bug in timestamp token output where we're taking the duplicate token as a stop where it
                // should be a start. This is an issue in the underlying model output. Let's just skip it so it becomes
                // de-facto a start again.
                continue;
            }

            ov::genai::Segment segment;
            segment.m_tokens = {tokens.begin() + idx_start + 1, tokens.begin() + i};
            segment.m_start = (*token_start - config.begin_timestamps_token_id) * time_precision;
            segment.m_end = (token - config.begin_timestamps_token_id) * time_precision;
            segments.push_back(segment);

            non_timestamp_tokens.insert(non_timestamp_tokens.end(), tokens.begin() + idx_start + 1, tokens.begin() + i);

            token_start = std::nullopt;
        }
    }

    // segment started but has no closing timestamp
    // add a new segment only if it has non-timestamp tokens
    // do not add a new segment if previous segments exist
    bool has_tokens_to_add = idx_start < tokens.size() - 1;
    bool has_previous_segments = segments.size() > 0;
    if (token_start.has_value() && has_tokens_to_add && !has_previous_segments) {
        ov::genai::Segment segment;
        segment.m_tokens = {tokens.begin() + idx_start + 1, tokens.end()};
        segment.m_start = (*token_start - config.begin_timestamps_token_id) * time_precision;
        segment.m_end = -1.0f;
        segments.push_back(segment);

        non_timestamp_tokens.insert(non_timestamp_tokens.end(), tokens.begin() + idx_start + 1, tokens.end());
    }

    return {non_timestamp_tokens, segments};
}

} // namespace genai
} // namespace ov
```
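For intuition, a hypothetical walk-through of `extract_segments` (word token ids are invented; `begin_timestamps_token_id` defaults to 50364, i.e. `<|0.00|>`, and Whisper uses a time precision of 0.02 seconds per timestamp step):

```cpp
// <|0.00|> " Hi" " there" <|1.50|>   (50439 = 50364 + 75, and 75 * 0.02 = 1.5)
std::vector<int64_t> tokens{50364, 2425, 612, 50439};
auto [text_tokens, segments] = extract_segments(tokens, config, 0.02f);
// text_tokens == {2425, 612}
// segments[0].m_start == 0.0f and segments[0].m_end == 1.5f
```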

src/cpp/src/whisper/timestamps.hpp

Lines changed: 19 additions & 0 deletions
```cpp
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <openvino/openvino.hpp>

#include "whisper.hpp"

namespace ov {
namespace genai {

std::pair<std::vector<int64_t>, std::vector<ov::genai::Segment>> extract_segments(
    const std::vector<int64_t>& tokens,
    const ov::genai::WhisperGenerationConfig& config,
    const float time_precision);

} // namespace genai
} // namespace ov
```
