Skip to content

Commit dcb2336

Browse files
authored
[CB] Add validation mode to Sampler (#904)
Tickets: * [153600](https://jira.devtools.intel.com/browse/CVS-153600) * [153601](https://jira.devtools.intel.com/browse/CVS-153601)
1 parent cfb6e02 commit dcb2336

File tree

7 files changed

+457
-40
lines changed

7 files changed

+457
-40
lines changed

src/cpp/src/continuous_batching_impl.hpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
5050
const std::string& device,
5151
const ov::AnyMap& llm_plugin_config,
5252
const ov::AnyMap& tokenizer_plugin_config)
53-
: ContinuousBatchingImpl{models_path, Tokenizer(models_path, tokenizer_plugin_config), scheduler_config, device, llm_plugin_config} {};
53+
: ContinuousBatchingImpl{ models_path,
54+
Tokenizer(models_path, tokenizer_plugin_config),
55+
scheduler_config,
56+
device,
57+
llm_plugin_config } {};
5458

5559

5660
GenerationHandle add_request(uint64_t request_id,

src/cpp/src/logit_processor.hpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,12 @@ class LogitProcessor {
365365
}
366366
}
367367

368-
void increment_gen_tokens() {
369-
++m_generated_tokens;
368+
void update_generated_len(size_t updated_len) {
369+
m_generated_tokens = updated_len;
370+
}
371+
372+
size_t get_generated_len() {
373+
return m_generated_tokens;
370374
}
371375

372376
void register_new_generated_token(int64_t new_token_id) {
@@ -377,4 +381,10 @@ class LogitProcessor {
377381
it->second++;
378382
}
379383
}
384+
385+
void decrease_generated_token_occurance(int64_t token_id) {
386+
OPENVINO_ASSERT(m_unique_generated_token_ids->count(token_id) > 0);
387+
m_unique_generated_token_ids->at(token_id)--;
388+
}
389+
380390
};

src/cpp/src/sampler.cpp

Lines changed: 145 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -465,12 +465,13 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
465465
}
466466
}
467467

468-
Logits Sampler::_get_logit_vector(ov::Tensor logits, size_t batch_idx) {
468+
Logits Sampler::_get_logit_vector(ov::Tensor logits, size_t batch_idx, size_t token_idx) {
469469
ov::Shape logits_shape = logits.get_shape();
470470
size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
471471
OPENVINO_ASSERT(batch_idx <= batch_size);
472+
OPENVINO_ASSERT(token_idx < seq_len);
472473
size_t batch_offset = batch_idx * seq_len * vocab_size;
473-
size_t sequence_offset = (seq_len - 1) * vocab_size;
474+
size_t sequence_offset = (seq_len - token_idx - 1) * vocab_size;
474475
float* logits_data = logits.data<float>() + batch_offset + sequence_offset;
475476

476477
return Logits{logits_data, vocab_size};
@@ -560,15 +561,89 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
560561
return dropped_seq_ids;
561562
}
562563

564+
void register_new_token(const Token& sampled_token_id,
565+
Sequence::Ptr running_sequence,
566+
LogitProcessor& logit_processor,
567+
bool is_extend_sequence,
568+
bool is_update_len_logit_processor) {
569+
logit_processor.register_new_generated_token(sampled_token_id.m_index);
570+
size_t generated_len = logit_processor.get_generated_len();
571+
if (is_extend_sequence) {
572+
running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
573+
} else {
574+
// just update the token log prob in case of successfully validated token
575+
OPENVINO_ASSERT(generated_len < running_sequence->get_generated_len());
576+
running_sequence->update_generated_log_prob(generated_len, sampled_token_id.m_log_prob);
577+
}
578+
// increment seq len only for one sequence in sequence group to sync them
579+
if (is_update_len_logit_processor) {
580+
logit_processor.update_generated_len(++generated_len);
581+
}
582+
};
583+
584+
std::list<uint64_t>
585+
create_n_forked_sequences(SequenceGroup::Ptr sequence_group,
586+
LogitProcessor& logit_processor,
587+
const std::vector<Token>& sampled_tokens) {
588+
const auto& running_sequences = sequence_group->get_running_sequences();
589+
OPENVINO_ASSERT(running_sequences.size() == 1);
590+
Sequence::Ptr sequence_to_fork = running_sequences.front();
591+
std::list<uint64_t> forked_seq_ids;
592+
for (size_t i = 1; i < sampled_tokens.size(); ++i) {
593+
const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
594+
const auto forked_seq_id = forked_sequence->get_id();
595+
forked_seq_ids.push_back(forked_seq_id);
596+
register_new_token(sampled_tokens[i], forked_sequence, logit_processor, true, false);
597+
}
598+
return forked_seq_ids;
599+
}
600+
601+
bool
602+
is_continue_to_sample_tokens(Sequence::Ptr running_sequence,
603+
size_t token_idx,
604+
size_t max_gen_len,
605+
size_t& decrease_context_len_per_seq_group) {
606+
if (max_gen_len == 0) {
607+
running_sequence->remove_last_tokens(token_idx);
608+
decrease_context_len_per_seq_group = std::max(decrease_context_len_per_seq_group, token_idx);
609+
return false;
610+
}
611+
return true;
612+
}
613+
614+
bool
615+
validate_candidate(Sequence::Ptr running_sequence,
616+
size_t& token_idx,
617+
Token& sampled_token,
618+
bool& is_extend_sequence,
619+
size_t& decrease_context_len_per_seq_group) {
620+
if (token_idx > 0) {
621+
const auto& generated_tokens = running_sequence->get_generated_ids();
622+
auto it = generated_tokens.rbegin();
623+
std::advance(it, token_idx - 1);
624+
// to validate candidates from assisting model and remove incorrect ones from generated sequence
625+
if (*it != sampled_token.m_index) {
626+
running_sequence->remove_last_tokens(token_idx);
627+
decrease_context_len_per_seq_group = std::max(decrease_context_len_per_seq_group, token_idx);
628+
is_extend_sequence = true;
629+
return false;
630+
} else {
631+
sampled_token.m_index = *it;
632+
}
633+
}
634+
return true;
563635

564-
SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits) {
636+
}
637+
638+
SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
639+
ov::Tensor logits,
640+
bool is_validation_mode_enabled) {
565641
const float * logits_data = logits.data<float>();
566642
ov::Shape logits_shape = logits.get_shape();
567643
OPENVINO_ASSERT(logits_shape.size() == 3);
568644
size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];
569645

570646
SamplerOutput sampler_output;
571-
572647
for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
573648
SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
574649
if (!sequence_group->is_scheduled())
@@ -587,45 +662,67 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
587662

588663
const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
589664
ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void *)sequence_group_logits_data);
590-
665+
size_t decrease_context_len_per_seq_group = 0;
591666
if (sequence_group->requires_sampling()) {
667+
// get number of tokens to be validated
668+
auto num_tokens_to_process = sequence_group->get_num_tokens_to_validate();
592669
if (sampling_params.is_greedy_decoding() || sampling_params.is_multinomial()) {
593670
std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
594671
if (sampling_params.is_greedy_decoding()) {
595672
OPENVINO_ASSERT(num_running_sequences == 1);
596673
}
597-
auto register_new_token = [&](const Token& sampled_token_id, Sequence::Ptr running_sequence) {
598-
logit_processor.register_new_generated_token(sampled_token_id.m_index);
599-
running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
600-
};
601674
for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) {
602-
auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id);
603-
logit_processor.apply(logit_vector);
604-
Token sampled_token_id;
605-
if (sampling_params.is_greedy_decoding()) {
606-
sampled_token_id = _greedy_sample(logit_vector);
607-
} else {
608-
// is_multinomial()
609-
const bool is_generate_n_tokens = sequence_group->num_total_seqs() == 1;
610-
const size_t num_tokens_per_sequence = is_generate_n_tokens ? sampling_params.num_return_sequences : 1;
611-
auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence);
612-
sampled_token_id = sampled_token_ids[0];
613-
614-
if (is_generate_n_tokens) {
615-
auto sequence_to_fork = running_sequences[0];
616-
std::list<uint64_t> forked_seq_ids;
617-
for (size_t i = num_running_sequences; i < num_tokens_per_sequence; ++i) {
618-
const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
619-
forked_seq_ids.push_back(forked_sequence->get_id());
620-
register_new_token(sampled_token_ids[i], forked_sequence);
675+
auto& running_sequence = running_sequences[running_sequence_id];
676+
// make `num_tokens_to_process` iterations to validate candidates generated by `draft_model` + 1 iteration to generate one more token by `main_model`
677+
for (size_t i = 0; i <= num_tokens_to_process; ++i) {
678+
// calculate token offset from the end of logit
679+
size_t token_offset = num_tokens_to_process - i;
680+
// maximum number of tokens still needed to be sampled
681+
size_t max_num_sampled_token = sampling_params.max_new_tokens + token_offset - running_sequence->get_generated_len();
682+
if (!is_continue_to_sample_tokens(running_sequence, token_offset, max_num_sampled_token, decrease_context_len_per_seq_group)) {
683+
break;
684+
}
685+
686+
// do sampling only for token validation/generation.
687+
// continue in case of extending draft model sequences by main model generated tokens which
688+
// should be taken to KV cache without validation
689+
if (!is_validation_mode_enabled && token_offset > 0) {
690+
continue;
691+
}
692+
693+
auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id, token_offset);
694+
logit_processor.apply(logit_vector);
695+
696+
Token sampled_token_id;
697+
if (sampling_params.is_greedy_decoding()) {
698+
sampled_token_id = _greedy_sample(logit_vector);
699+
} else {
700+
// is_multinomial()
701+
const bool is_generate_n_tokens = sequence_group->num_total_seqs() == 1;
702+
const size_t num_tokens_per_sequence = is_generate_n_tokens ? sampling_params.num_return_sequences : 1;
703+
auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence);
704+
OPENVINO_ASSERT(sampled_token_ids.size(), num_tokens_per_sequence);
705+
if (is_generate_n_tokens) {
706+
const auto forked_seq_ids = create_n_forked_sequences(sequence_group, logit_processor, sampled_token_ids);
707+
sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
621708
}
622-
sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
709+
sampled_token_id = sampled_token_ids.front();
710+
}
711+
// flag to add sampled token to generated sequence or extend logit processors only
712+
bool is_extend_sequence = token_offset == 0,
713+
// flag to update generated length of sequence group in logit processor
714+
is_update_len_logit_processor = running_sequence_id == num_running_sequences - 1,
715+
is_validation_passed = true;
716+
if (is_validation_mode_enabled) {
717+
is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token_id, is_extend_sequence, decrease_context_len_per_seq_group);
718+
}
719+
register_new_token(sampled_token_id, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_update_len_logit_processor);
720+
// to exit from sampling in case of failed token validation
721+
if (!is_validation_passed) {
722+
break;
623723
}
624724
}
625-
626-
register_new_token(sampled_token_id, running_sequences[running_sequence_id]);
627725
}
628-
logit_processor.increment_gen_tokens();
629726
for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) {
630727
sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
631728
}
@@ -658,14 +755,29 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
658755
// NOTE: it should be before 'get_num_scheduled_tokens' is used
659756
// update internal state of sequence group to reset scheduler tokens and update currently processed ones
660757
sequence_group->finish_iteration();
758+
// decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model
759+
if (decrease_context_len_per_seq_group) {
760+
const auto num_processed_tokens = sequence_group->get_num_processed_tokens();
761+
OPENVINO_ASSERT(num_processed_tokens >= decrease_context_len_per_seq_group);
762+
OPENVINO_ASSERT(sequence_group->get_context_len() >= decrease_context_len_per_seq_group);
763+
sequence_group->update_processed_tokens_num(num_processed_tokens - decrease_context_len_per_seq_group);
764+
}
661765

662766
// accumulate a number of processed tokens
663-
currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
767+
currently_processed_tokens += (padded_amount_of_processed_tokens - decrease_context_len_per_seq_group) * num_running_sequences;
664768
}
665769

666770
return sampler_output;
667771
}
668772

773+
void Sampler::update_logit_processor(uint64_t request_id, uint64_t token_id) {
774+
OPENVINO_ASSERT(m_logit_processors.count(request_id));
775+
auto& logit_processor = m_logit_processors.at(request_id);
776+
logit_processor.decrease_generated_token_occurance(token_id);
777+
auto gen_size = logit_processor.get_generated_len();
778+
logit_processor.update_generated_len(gen_size - 1);
779+
}
780+
669781
void Sampler::clear_beam_search_info(uint64_t request_id) {
670782
m_beam_search_info.erase(request_id);
671783
}

src/cpp/src/sampler.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ struct SamplerOutput {
4141
class Sampler {
4242
class GroupBeamSearcher;
4343

44-
Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx = 1);
44+
Logits _get_logit_vector(ov::Tensor logits, size_t batch_idx, size_t token_idx);
4545
Token _greedy_sample(const Logits& logits) const;
4646
std::vector<Token> _multinomial_sample(const Logits& logits, size_t num_tokens_per_sequence);
4747
std::vector<int64_t> _try_finish_generation(SequenceGroup::Ptr & sequence_group);
48+
void update_logit_processor(uint64_t request_id, uint64_t token_id);
4849

4950
// request ID => beam search tracking information
5051
std::map<uint64_t, GroupBeamSearcher> m_beam_search_info;
@@ -56,9 +57,10 @@ class Sampler {
5657
Tokenizer m_tokenizer;
5758

5859
public:
60+
Sampler() = default;
5961
Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {};
6062

61-
SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits);
63+
SamplerOutput sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false);
6264
void set_seed(size_t seed) { rng_engine.seed(seed); }
6365
void clear_beam_search_info(uint64_t request_id);
6466
};

src/cpp/src/sequence_group.hpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ class Sequence {
152152
return m_cumulative_log_prob;
153153
}
154154

155+
void update_generated_log_prob(size_t idx, float log_prob) {
156+
OPENVINO_ASSERT(idx < m_generated_log_probs.size());
157+
m_generated_log_probs[idx] = log_prob;
158+
}
159+
155160
float get_beam_search_score(const ov::genai::GenerationConfig& sampling_params) const {
156161
float cumulative_log_prob = get_cumulative_log_probs(), current_length = get_generated_len();
157162
float score = cumulative_log_prob / std::pow(current_length, sampling_params.length_penalty);
@@ -199,6 +204,8 @@ class SequenceGroup {
199204
size_t m_num_scheduled_tokens = 0;
200205
// context length of longest sequence within a group
201206
size_t m_max_content_len = 0;
207+
// max validation length within a group to check generated tokens
208+
size_t m_num_validated_tokens = 0;
202209

203210

204211
SequenceGroup(uint64_t request_id, const ov::genai::GenerationConfig& sampling_params, std::size_t block_size, bool enable_prefix_caching)
@@ -413,18 +420,29 @@ class SequenceGroup {
413420

414421
void clear_scheduled_tokens() {
415422
m_num_scheduled_tokens = 0;
423+
m_num_validated_tokens = 0;
416424
}
417425

418426
bool is_scheduled() const {
419427
return m_num_scheduled_tokens > 0;
420428
}
421429

430+
void set_num_validated_tokens(size_t k) {
431+
// in case of non-prompt we need to take prev tokens + token to validate
432+
// m_num_validated_tokens = get_num_processed_tokens() ? k + 1 : k;
433+
m_num_validated_tokens = k;
434+
}
435+
436+
size_t get_num_tokens_to_validate() {
437+
return m_num_validated_tokens;
438+
}
439+
422440
size_t get_num_available_tokens_for_batching() const {
423441
OPENVINO_ASSERT(!has_finished(), "Internal error: this function cannot be called on finished sequence group");
424442
OPENVINO_ASSERT(get_num_scheduled_tokens() == 0, "Internal error: this function cannot be called when we are already in scheduling phase");
425443
// if sequence group has not finished, it has at least one token to process
426444
size_t num_available_tokens = std::max(get_prompt_len(), m_max_content_len);
427-
return std::max<size_t>(num_available_tokens - m_num_processed_tokens, 1u);
445+
return std::max<size_t>(num_available_tokens - m_num_processed_tokens, 1u) + m_num_validated_tokens;
428446
}
429447

430448
// mark current schedule phase as finished and updates internal counters

tests/cpp/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ endif()
1616
set(TEST_TARGET_NAME "tests_continuous_batching")
1717
file(GLOB tests_src "*.cpp")
1818
file(GLOB src_files "${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/sequence_group.cpp"
19-
"${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/cache_eviction.cpp")
19+
"${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/cache_eviction.cpp"
20+
"${OpenVINOGenAI_SOURCE_DIR}/src/cpp/src/sampler.cpp")
2021

2122
add_executable(${TEST_TARGET_NAME} ${tests_src}
2223
block_allocator.cpp)

0 commit comments

Comments
 (0)