diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp
index 37a55bbff..fb58da771 100755
--- a/xllm/core/framework/batch/batch_input_builder.cpp
+++ b/xllm/core/framework/batch/batch_input_builder.cpp
@@ -72,6 +72,10 @@ BatchInputBuilder::BatchInputBuilder(
 ForwardInput BatchInputBuilder::build_forward_input(
     uint32_t num_decoding_tokens,
     uint32_t min_decoding_batch_size) {
+  need_unique_tokens_ =
+      std::any_of(sequences_.begin(), sequences_.end(), [](Sequence* seq) {
+        return seq->check_need_unique_tokens();
+      });
   process_sequences();
   padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);
 
@@ -79,6 +83,10 @@ ForwardInput BatchInputBuilder::build_forward_input(
 }
 
 RawForwardInput BatchInputBuilder::build_raw_forward_input() {
+  need_unique_tokens_ =
+      std::any_of(sequences_.begin(), sequences_.end(), [](Sequence* seq) {
+        return seq->check_need_unique_tokens();
+      });
   if (!thread_pool_ || num_sequences_ < thread_pool_->size()) {
     process_sequences();
   } else {
@@ -318,14 +326,6 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
   const auto& token_ids = sequence->tokens();
   const uint32_t n_tokens = token_ids.size();
 
-  // Prepare adjusted token counts for sampling
-  std::unordered_map<int32_t, int32_t> adjusted_token_to_count_map;
-  for (uint32_t j = n_kv_cache_tokens; j < seq_len; ++j) {
-    // skip prompt tokens except the last one
-    if (j + 1 < n_tokens) continue;
-    ++adjusted_token_to_count_map[token_ids[j]];
-  }
-
   // Handle MRope positions
   if (use_mrope_) {
     const auto& args = *args_;
@@ -344,8 +344,7 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
     // Handle sampling for last tokens
     if (j + 1 < n_tokens) continue;
 
-    handle_sampling_parameters(
-        sequence, j, seq_len, adjusted_token_to_count_map, state_ptr);
+    handle_sampling_parameters(sequence, j, seq_len, state_ptr);
   }
 
   // Add extra token id
@@ -359,52 +358,36 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
   }
 }
 
-void BatchInputBuilder::handle_sampling_parameters(
-    Sequence* sequence,
-    uint32_t token_position,
-    uint32_t seq_len,
-    std::unordered_map<int32_t, int32_t>& adjusted_token_to_count_map,
-    BuilderState* state_ptr) {
+void BatchInputBuilder::handle_sampling_parameters(Sequence* sequence,
+                                                   uint32_t token_position,
+                                                   uint32_t seq_len,
+                                                   BuilderState* state_ptr) {
   BuilderState& state = state_ptr ? *state_ptr : state_;
   // const auto token_ids = sequence->token_ids();
   const auto token_id = sequence->tokens()[token_position];
 
-  // Adjust token count
-  --adjusted_token_to_count_map[token_id];
-
   // Select token for sampling
   state.selected_token_idxes.push_back(state.flatten_tokens_vec.size() - 1);
   state.sampling_params.push_back(sequence->sampling_param());
+  state.sample_idxes.push_back(state.selected_token_idxes.size() - 1);
 
   // Process unique tokens
-  const auto& seq_token_counts = sequence->token_to_count_map();
-  auto& ids = state.unique_token_ids_vec.emplace_back();
-  auto& counts = state.unique_token_counts_vec.emplace_back();
-
-  ids.reserve(seq_token_counts.size());
-  counts.reserve(seq_token_counts.size());
+  if (need_unique_tokens_) {
+    const auto& seq_token_counts = sequence->token_to_count_map();
+    auto& ids = state.unique_token_ids_vec.emplace_back();
+    auto& counts = state.unique_token_counts_vec.emplace_back();
 
-  for (const auto& [token_id, count] : seq_token_counts) {
-    const auto it = adjusted_token_to_count_map.find(token_id);
-    const auto adjust_count =
-        (it != adjusted_token_to_count_map.end()) ? it->second : 0;
+    ids.reserve(seq_token_counts.size());
+    counts.reserve(seq_token_counts.size());
 
-    if (count > adjust_count) {
+    for (const auto& [token_id, count] : seq_token_counts) {
+      CHECK(count >= 0) << "token count should be non-negative";
       ids.push_back(token_id);
-      counts.push_back(count - adjust_count);
+      counts.push_back(count);
     }
-  }
-
-  state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));
 
-  // Mark sample token if it's the last token
-  // TODO add test
-  // in chunked prefill condition, if allowed_max_token = 128, n_tokens=1000,
-  // n_kv_cache_tokens=256, q_seq_len = 128, seq_len=384
-  if (token_position == seq_len - 1) {
-    state.sample_idxes.push_back(
-        static_cast<int32_t>(state.selected_token_idxes.size() - 1));
+    state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));
   }
 }
diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h
index 8f18d6ac3..d1a9a9974 100644
--- a/xllm/core/framework/batch/batch_input_builder.h
+++ b/xllm/core/framework/batch/batch_input_builder.h
@@ -121,12 +121,10 @@ class BatchInputBuilder {
       uint32_t n_kv_cache_tokens,
       uint32_t seq_len,
       BuilderState* state_ptr = nullptr);
-  void handle_sampling_parameters(
-      Sequence* sequence,
-      uint32_t token_position,
-      uint32_t seq_len,
-      std::unordered_map<int32_t, int32_t>& adjusted_counts,
-      BuilderState* state_ptr = nullptr);
+  void handle_sampling_parameters(Sequence* sequence,
+                                  uint32_t token_position,
+                                  uint32_t seq_len,
+                                  BuilderState* state_ptr = nullptr);
   void setup_kv_cache_info(
       Sequence* sequence,
       uint32_t n_kv_cache_tokens,
@@ -153,6 +151,7 @@ class BatchInputBuilder {
   // Configuration
   bool use_mrope_ = false;
   uint32_t num_sequences_ = 0;
+  bool need_unique_tokens_ = true;
 
   // copy in and out cache contents
   std::unordered_set<int32_t> write_block_ids_;
diff --git a/xllm/core/framework/batch/batch_test.cpp b/xllm/core/framework/batch/batch_test.cpp
index c3905997c..28c58782c 100644
--- a/xllm/core/framework/batch/batch_test.cpp
+++ b/xllm/core/framework/batch/batch_test.cpp
@@ -192,13 +192,13 @@ TEST(BatchTest, Basic) {
   // seq4 has no sampling parameters
   EXPECT_TRUE(equal(sampling_params.unique_token_ids, unique_ids));
 
-  const std::vector<int32_t> unique_counts = {
-      /*seq1*/ 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-      /*seq2*/ 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
-      /*seq3*/ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
-  // seq4 has no sampling parameters
-  EXPECT_TRUE(equal(sampling_params.unique_token_counts, unique_counts));
+  // const std::vector<int32_t> unique_counts = {
+  //     /*seq1*/ 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  //     /*seq2*/ 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+  //     /*seq3*/ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
+  // // seq4 has no sampling parameters
+  // EXPECT_TRUE(equal(sampling_params.unique_token_counts, unique_counts));
 
   const std::vector<int32_t> token_ids_lens = {6, 5, 16};
   EXPECT_TRUE(equal(sampling_params.unique_token_ids_lens, token_ids_lens));
diff --git a/xllm/core/framework/request/sequence.cpp b/xllm/core/framework/request/sequence.cpp
index 297986f47..24cc47f8d 100644
--- a/xllm/core/framework/request/sequence.cpp
+++ b/xllm/core/framework/request/sequence.cpp
@@ -56,11 +56,21 @@ Sequence::Sequence(size_t index,
   // init logprob state
   logprob_state_ = std::make_unique<LogprobState>(num_prompt_tokens_, capacity);
 
+  if (sequence_params_.sampling_param->frequency_penalty != 0 ||
+      sequence_params_.sampling_param->presence_penalty != 0 ||
+      sequence_params_.sampling_param->repetition_penalty != 1) {
+    need_unique_tokens_ = true;
+  }
+
   // add the prompt tokens
   for (const auto token_id : prompt_token_ids) {
     tokens_[num_tokens_++] = token_id;
-    token_to_count_map_[token_id]++;
+    if (need_unique_tokens_) {
+      token_to_count_map_[token_id] = 0;  // NOTE(review): presence only — prompt hits not counted; confirm intended
+    }
   }
+  // keep one entry for padding even when counts are unused (assumes non-empty prompt)
+  token_to_count_map_[prompt_token_ids.back()] = 0;
   input_embedding_ = input_embedding;
   cur_generated_token_idx_ = num_prompt_tokens_;
 }
@@ -133,7 +143,9 @@ void Sequence::append_token(const Token& token) {
     return;
   }
 
-  token_to_count_map_[token_id]++;
+  if (need_unique_tokens_) {
+    token_to_count_map_[token_id]++;
+  }
   // update logprobs if needed
   if (sequence_params_.sampling_param->logprobs) {
     logprob_state_->update_logprob(
@@ -174,7 +186,9 @@ void Sequence::update_last_step_token(const Token& token, size_t token_offset) {
   const int32_t token_id = static_cast<int32_t>(token.id);
   tokens_[cur_generated_token_idx_] = token_id;
 
-  token_to_count_map_[token_id]++;
+  if (need_unique_tokens_) {
+    token_to_count_map_[token_id]++;
+  }
   // update logprobs if needed
   if (sequence_params_.sampling_param->logprobs) {
     logprob_state_->update_logprob(
@@ -202,8 +216,10 @@ void Sequence::update_token(size_t index, const Token& token) {
   const int32_t origin_token_id = tokens_[index];
   const int32_t token_id = static_cast<int32_t>(token.id);
   tokens_[index] = token_id;
-  --token_to_count_map_[origin_token_id];
-  ++token_to_count_map_[token_id];
+  if (need_unique_tokens_) {
+    --token_to_count_map_[origin_token_id];
+    ++token_to_count_map_[token_id];
+  }
   // update logprobs if needed
   if (sequence_params_.sampling_param->logprobs) {
     logprob_state_->update_logprob(
diff --git a/xllm/core/framework/request/sequence.h b/xllm/core/framework/request/sequence.h
index 3fad73a1a..c98e47076 100644
--- a/xllm/core/framework/request/sequence.h
+++ b/xllm/core/framework/request/sequence.h
@@ -255,6 +255,8 @@ class Sequence final {
     return sequence_params_.sampling_param->beam_width > 1;
   }
 
+  bool check_need_unique_tokens() const { return need_unique_tokens_; }
+
   LogprobState* logprob_state() { return logprob_state_.get(); }
 
   // set sequence id
@@ -310,6 +312,7 @@ class Sequence final {
 
   // the count of each token id
   std::unordered_map<int32_t, int32_t> token_to_count_map_;
+  bool need_unique_tokens_ = false;
 
   // the length of the prompt tokens
   size_t num_prompt_tokens_ = 0;
diff --git a/xllm/core/framework/sampling/sampling_params.cpp b/xllm/core/framework/sampling/sampling_params.cpp
index 4ed8cb777..205e9fc1b 100644
--- a/xllm/core/framework/sampling/sampling_params.cpp
+++ b/xllm/core/framework/sampling/sampling_params.cpp
@@ -35,9 +35,6 @@ void SamplingParameters::init(
     const std::vector<int32_t>& unique_token_lens_vec) {
   CHECK_EQ(req_sampling_params.size(), selected_token_idxes.size());
   CHECK_GE(req_sampling_params.size(), sample_idxes.size());
-  CHECK_EQ(req_sampling_params.size(), unique_token_ids_vec.size());
-  CHECK_EQ(req_sampling_params.size(), unique_token_counts_vec.size());
-  CHECK_EQ(req_sampling_params.size(), unique_token_lens_vec.size());
 
   std::vector<float> frequency_penalties;
   std::vector<float> presence_penalties;
@@ -118,6 +115,9 @@ void SamplingParameters::init(
   this->selected_token_idxes =
       torch::tensor(selected_token_idxes, int_tensor_options);
   if (need_token_stats) {
+    CHECK_EQ(req_sampling_params.size(), unique_token_ids_vec.size());
+    CHECK_EQ(req_sampling_params.size(), unique_token_counts_vec.size());
+    CHECK_EQ(req_sampling_params.size(), unique_token_lens_vec.size());
     this->unique_token_ids =
         create_2d_tensor(unique_token_ids_vec, torch::kInt64);
     this->unique_token_counts =