@@ -465,12 +465,13 @@ void Sampler::GroupBeamSearcher::select_next_tokens(const ov::Tensor& logits, Sa
     }
 }

-Logits Sampler::_get_logit_vector(ov::Tensor logits, size_t batch_idx) {
+Logits Sampler::_get_logit_vector(ov::Tensor logits, size_t batch_idx, size_t token_idx) {
     ov::Shape logits_shape = logits.get_shape();
     size_t batch_size = logits_shape[0], seq_len = logits_shape[1], vocab_size = logits_shape[2];
     OPENVINO_ASSERT(batch_idx <= batch_size);
+    OPENVINO_ASSERT(token_idx < seq_len);
     size_t batch_offset = batch_idx * seq_len * vocab_size;
-    size_t sequence_offset = (seq_len - 1) * vocab_size;
+    size_t sequence_offset = (seq_len - token_idx - 1) * vocab_size;
     float* logits_data = logits.data<float>() + batch_offset + sequence_offset;

     return Logits{logits_data, vocab_size};
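For reference, the indexing this hunk introduces can be exercised in isolation: with a contiguous row-major `[batch, seq_len, vocab]` layout, `token_idx` counts back from the last position, so `token_idx == 0` selects the newest token's logits row. A minimal sketch assuming that layout; `LogitsView`, `get_logit_vector`, and the `main()` driver below are hypothetical stand-ins, not the library types:

    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for the library's Logits view: a pointer into a larger buffer.
    struct LogitsView {
        const float* data;
        size_t vocab_size;
    };

    // Mirrors the offset arithmetic in _get_logit_vector: batch_offset skips whole
    // sequences, sequence_offset counts token_idx rows back from the end of the sequence.
    LogitsView get_logit_vector(const std::vector<float>& logits,
                                size_t batch_size, size_t seq_len, size_t vocab_size,
                                size_t batch_idx, size_t token_idx) {
        assert(batch_idx < batch_size && token_idx < seq_len);
        size_t batch_offset = batch_idx * seq_len * vocab_size;
        size_t sequence_offset = (seq_len - token_idx - 1) * vocab_size;
        return LogitsView{logits.data() + batch_offset + sequence_offset, vocab_size};
    }

    int main() {
        // 1 batch, 3 positions, vocab of 2; row i holds the value i for easy inspection.
        std::vector<float> logits = {0, 0, 1, 1, 2, 2};
        for (size_t token_idx = 0; token_idx < 3; ++token_idx) {
            LogitsView v = get_logit_vector(logits, 1, 3, 2, 0, token_idx);
            std::cout << "token_idx " << token_idx << " -> row " << v.data[0] << "\n";
        }
        // Prints rows 2, 1, 0: offsets walk backwards from the newest position.
    }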
@@ -560,15 +561,89 @@ std::vector<int64_t> Sampler::_try_finish_generation(SequenceGroup::Ptr & sequen
     return dropped_seq_ids;
 }

+void register_new_token(const Token& sampled_token_id,
+                        Sequence::Ptr running_sequence,
+                        LogitProcessor& logit_processor,
+                        bool is_extend_sequence,
+                        bool is_update_len_logit_processor) {
+    logit_processor.register_new_generated_token(sampled_token_id.m_index);
+    size_t generated_len = logit_processor.get_generated_len();
+    if (is_extend_sequence) {
+        running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
+    } else {
+        // for a successfully validated token, only update its log prob
+        OPENVINO_ASSERT(generated_len < running_sequence->get_generated_len());
+        running_sequence->update_generated_log_prob(generated_len, sampled_token_id.m_log_prob);
+    }
+    // increment the generated length for only one sequence in the sequence group to keep the sequences in sync
+    if (is_update_len_logit_processor) {
+        logit_processor.update_generated_len(++generated_len);
+    }
+}
+
+std::list<uint64_t>
+create_n_forked_sequences(SequenceGroup::Ptr sequence_group,
+                          LogitProcessor& logit_processor,
+                          const std::vector<Token>& sampled_tokens) {
+    const auto& running_sequences = sequence_group->get_running_sequences();
+    OPENVINO_ASSERT(running_sequences.size() == 1);
+    Sequence::Ptr sequence_to_fork = running_sequences.front();
+    std::list<uint64_t> forked_seq_ids;
+    for (size_t i = 1; i < sampled_tokens.size(); ++i) {
+        const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
+        const auto forked_seq_id = forked_sequence->get_id();
+        forked_seq_ids.push_back(forked_seq_id);
+        register_new_token(sampled_tokens[i], forked_sequence, logit_processor, true, false);
+    }
+    return forked_seq_ids;
+}
+
+bool
+is_continue_to_sample_tokens(Sequence::Ptr running_sequence,
+                             size_t token_idx,
+                             size_t max_gen_len,
+                             size_t& decrease_context_len_per_seq_group) {
+    if (max_gen_len == 0) {
+        running_sequence->remove_last_tokens(token_idx);
+        decrease_context_len_per_seq_group = std::max(decrease_context_len_per_seq_group, token_idx);
+        return false;
+    }
+    return true;
+}
+
+bool
+validate_candidate(Sequence::Ptr running_sequence,
+                   size_t& token_idx,
+                   Token& sampled_token,
+                   bool& is_extend_sequence,
+                   size_t& decrease_context_len_per_seq_group) {
+    if (token_idx > 0) {
+        const auto& generated_tokens = running_sequence->get_generated_ids();
+        auto it = generated_tokens.rbegin();
+        std::advance(it, token_idx - 1);
+        // validate candidates from the assisting model and remove incorrect ones from the generated sequence
+        if (*it != sampled_token.m_index) {
+            running_sequence->remove_last_tokens(token_idx);
+            decrease_context_len_per_seq_group = std::max(decrease_context_len_per_seq_group, token_idx);
+            is_extend_sequence = true;
+            return false;
+        } else {
+            sampled_token.m_index = *it;
+        }
+    }
+    return true;

-SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups, ov::Tensor logits) {
+}
+
+SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
+                              ov::Tensor logits,
+                              bool is_validation_mode_enabled) {
     const float* logits_data = logits.data<float>();
     ov::Shape logits_shape = logits.get_shape();
     OPENVINO_ASSERT(logits_shape.size() == 3);
     size_t batch_seq_len = logits_shape[1], vocab_size = logits_shape[2];

     SamplerOutput sampler_output;
-
     for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) {
         SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id];
         if (!sequence_group->is_scheduled())
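The helpers above implement the speculative-decoding acceptance rule: draft candidates are compared position by position against what the main model samples, and the first mismatch truncates the draft tail (`remove_last_tokens`) before the main-model token is appended instead. A self-contained sketch of that rule under simplifying assumptions (plain token vectors instead of `Sequence`/`LogitProcessor`; `accept_candidates` is a hypothetical helper):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Walk the draft candidates from oldest to newest; each position is compared
    // with the token the main model sampled there. On the first mismatch, the
    // remaining draft tokens are dropped and the main-model token is appended.
    size_t accept_candidates(std::vector<int64_t>& generated,          // sequence incl. draft tail
                             size_t num_draft,                         // trailing draft token count
                             const std::vector<int64_t>& main_samples) { // main-model token per position
        size_t accepted = 0;
        for (size_t i = 0; i < num_draft; ++i) {
            size_t pos = generated.size() - num_draft + i;
            if (generated[pos] != main_samples[i]) {
                generated.resize(pos);                 // drop mismatch and everything after it
                generated.push_back(main_samples[i]);  // extend with the main model's token
                return accepted;
            }
            ++accepted;
        }
        return accepted;
    }

    int main() {
        std::vector<int64_t> seq = {5, 7, 9, 11};        // last three tokens are draft candidates
        std::vector<int64_t> main_samples = {7, 8, 13};  // main model agrees on 7, disagrees on 9
        size_t n = accept_candidates(seq, 3, main_samples);
        std::cout << "accepted " << n << ", seq len now " << seq.size() << "\n"; // accepted 1, len 3
    }

The `decrease_context_len_per_seq_group` bookkeeping in the real code plays the role of the `resize` here: it records how far the KV-cache context has to be rolled back.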
@@ -587,45 +662,67 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,

         const void* sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens;
         ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, actual_seq_len, vocab_size}, (void*)sequence_group_logits_data);
-
+        size_t decrease_context_len_per_seq_group = 0;
         if (sequence_group->requires_sampling()) {
+            // get the number of tokens to be validated
+            auto num_tokens_to_process = sequence_group->get_num_tokens_to_validate();
             if (sampling_params.is_greedy_decoding() || sampling_params.is_multinomial()) {
                 std::vector<Sequence::Ptr> running_sequences = sequence_group->get_running_sequences();
                 if (sampling_params.is_greedy_decoding()) {
                     OPENVINO_ASSERT(num_running_sequences == 1);
                 }
-                auto register_new_token = [&](const Token& sampled_token_id, Sequence::Ptr running_sequence) {
-                    logit_processor.register_new_generated_token(sampled_token_id.m_index);
-                    running_sequence->append_token(sampled_token_id.m_index, sampled_token_id.m_log_prob);
-                };
                 for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) {
-                    auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id);
-                    logit_processor.apply(logit_vector);
-                    Token sampled_token_id;
-                    if (sampling_params.is_greedy_decoding()) {
-                        sampled_token_id = _greedy_sample(logit_vector);
-                    } else {
-                        // is_multinomial()
-                        const bool is_generate_n_tokens = sequence_group->num_total_seqs() == 1;
-                        const size_t num_tokens_per_sequence = is_generate_n_tokens ? sampling_params.num_return_sequences : 1;
-                        auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence);
-                        sampled_token_id = sampled_token_ids[0];
-
-                        if (is_generate_n_tokens) {
-                            auto sequence_to_fork = running_sequences[0];
-                            std::list<uint64_t> forked_seq_ids;
-                            for (size_t i = num_running_sequences; i < num_tokens_per_sequence; ++i) {
-                                const auto forked_sequence = sequence_group->fork_sequence(sequence_to_fork);
-                                forked_seq_ids.push_back(forked_sequence->get_id());
-                                register_new_token(sampled_token_ids[i], forked_sequence);
+                    auto& running_sequence = running_sequences[running_sequence_id];
+                    // run `num_tokens_to_process` iterations to validate the candidates generated by the `draft_model`, plus one iteration to generate a new token with the `main_model`
+                    for (size_t i = 0; i <= num_tokens_to_process; ++i) {
+                        // token offset from the end of the logits tensor
+                        size_t token_offset = num_tokens_to_process - i;
+                        // maximum number of tokens that still need to be sampled
+                        size_t max_num_sampled_token = sampling_params.max_new_tokens + token_offset - running_sequence->get_generated_len();
+                        if (!is_continue_to_sample_tokens(running_sequence, token_offset, max_num_sampled_token, decrease_context_len_per_seq_group)) {
+                            break;
+                        }
+
+                        // sample only when validating or generating a token;
+                        // skip when draft-model sequences are merely extended with main-model generated tokens,
+                        // which go into the KV cache without validation
+                        if (!is_validation_mode_enabled && token_offset > 0) {
+                            continue;
+                        }
+
+                        auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id, token_offset);
+                        logit_processor.apply(logit_vector);
+
+                        Token sampled_token_id;
+                        if (sampling_params.is_greedy_decoding()) {
+                            sampled_token_id = _greedy_sample(logit_vector);
+                        } else {
+                            // is_multinomial()
+                            const bool is_generate_n_tokens = sequence_group->num_total_seqs() == 1;
+                            const size_t num_tokens_per_sequence = is_generate_n_tokens ? sampling_params.num_return_sequences : 1;
+                            auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence);
+                            OPENVINO_ASSERT(sampled_token_ids.size() == num_tokens_per_sequence);
+                            if (is_generate_n_tokens) {
+                                const auto forked_seq_ids = create_n_forked_sequences(sequence_group, logit_processor, sampled_token_ids);
+                                sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
                             }
-                            sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids});
+                            sampled_token_id = sampled_token_ids.front();
+                        }
+                        // whether to append the sampled token to the generated sequence or only update the logit processor
+                        bool is_extend_sequence = token_offset == 0,
+                             // whether to update the generated length of the sequence group in the logit processor
+                             is_update_len_logit_processor = running_sequence_id == num_running_sequences - 1,
+                             is_validation_passed = true;
+                        if (is_validation_mode_enabled) {
+                            is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token_id, is_extend_sequence, decrease_context_len_per_seq_group);
+                        }
+                        register_new_token(sampled_token_id, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_update_len_logit_processor);
+                        // stop sampling for this sequence once token validation fails
+                        if (!is_validation_passed) {
+                            break;
                        }
                    }
-
-                    register_new_token(sampled_token_id, running_sequences[running_sequence_id]);
                }
-                logit_processor.increment_gen_tokens();
                for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) {
                    sampler_output.m_dropped_sequences.push_back(dropped_seq_id);
                }
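One way to read the loop above: with `num_tokens_to_process` candidates, it makes `num_tokens_to_process + 1` iterations, and `token_offset` counts down from the oldest candidate to 0, where only the final iteration samples a genuinely new token. A hypothetical standalone trace of that ordering:

    #include <cstddef>
    #include <iostream>

    int main() {
        const size_t num_tokens_to_process = 3; // e.g. three draft-model candidates
        for (size_t i = 0; i <= num_tokens_to_process; ++i) {
            size_t token_offset = num_tokens_to_process - i;
            if (token_offset > 0)
                std::cout << "iter " << i << ": validate candidate at offset " << token_offset << "\n";
            else
                std::cout << "iter " << i << ": sample a new token (offset 0)\n";
        }
    }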
@@ -658,14 +755,29 @@ SamplerOutput Sampler::sample(std::vector<SequenceGroup::Ptr> & sequence_groups,
         // NOTE: it should be before 'get_num_scheduled_tokens' is used
         // update internal state of sequence group to reset scheduler tokens and update currently processed ones
         sequence_group->finish_iteration();
+        // shrink the sequence_group context when candidates generated by the draft_model were not accepted by the main_model
+        if (decrease_context_len_per_seq_group) {
+            const auto num_processed_tokens = sequence_group->get_num_processed_tokens();
+            OPENVINO_ASSERT(num_processed_tokens >= decrease_context_len_per_seq_group);
+            OPENVINO_ASSERT(sequence_group->get_context_len() >= decrease_context_len_per_seq_group);
+            sequence_group->update_processed_tokens_num(num_processed_tokens - decrease_context_len_per_seq_group);
+        }

         // accumulate a number of processed tokens
-        currently_processed_tokens += padded_amount_of_processed_tokens * num_running_sequences;
+        currently_processed_tokens += (padded_amount_of_processed_tokens - decrease_context_len_per_seq_group) * num_running_sequences;
     }

     return sampler_output;
 }

+void Sampler::update_logit_processor(uint64_t request_id, uint64_t token_id) {
+    OPENVINO_ASSERT(m_logit_processors.count(request_id));
+    auto& logit_processor = m_logit_processors.at(request_id);
+    logit_processor.decrease_generated_token_occurance(token_id);
+    auto gen_size = logit_processor.get_generated_len();
+    logit_processor.update_generated_len(gen_size - 1);
+}
+
 void Sampler::clear_beam_search_info(uint64_t request_id) {
     m_beam_search_info.erase(request_id);
 }
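`Sampler::update_logit_processor` is the rollback hook for a token rejected after sampling: both the occurrence count that feeds repetition/presence penalties and the recorded generated length step back by one, otherwise the penalties would keep punishing a token that is no longer in the sequence. A toy model of that contract; `ToyLogitProcessor` and `rollback_token` are hypothetical, not the library API:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <map>

    struct ToyLogitProcessor {
        std::map<int64_t, size_t> occurrences; // per-token counts used by penalties
        size_t generated_len = 0;

        void register_new_generated_token(int64_t token) {
            ++occurrences[token];
            ++generated_len;
        }
        // Mirrors decrease_generated_token_occurance + update_generated_len above.
        void rollback_token(int64_t token) {
            assert(occurrences[token] > 0 && generated_len > 0);
            --occurrences[token];
            --generated_len;
        }
    };

    int main() {
        ToyLogitProcessor p;
        p.register_new_generated_token(42);
        p.register_new_generated_token(7);
        p.rollback_token(7); // token 7 was rejected during validation
        assert(p.occurrences[7] == 0 && p.generated_len == 1);
    }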