65 changes: 24 additions & 41 deletions xllm/core/framework/batch/batch_input_builder.cpp
@@ -72,13 +72,21 @@ BatchInputBuilder::BatchInputBuilder(
ForwardInput BatchInputBuilder::build_forward_input(
uint32_t num_decoding_tokens,
uint32_t min_decoding_batch_size) {
need_unique_tokens_ =
std::any_of(sequences_.begin(), sequences_.end(), [](Sequence* seq) {
return seq->check_need_unique_tokens();
});
process_sequences();
padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);

return state_to_forward_input();
}

RawForwardInput BatchInputBuilder::build_raw_forward_input() {
need_unique_tokens_ =
std::any_of(sequences_.begin(), sequences_.end(), [](Sequence* seq) {
return seq->check_need_unique_tokens();
});
if (!thread_pool_ || num_sequences_ < thread_pool_->size()) {
process_sequences();
} else {
@@ -318,14 +326,6 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
const auto& token_ids = sequence->tokens();
const uint32_t n_tokens = token_ids.size();

// Prepare adjusted token counts for sampling
std::unordered_map<int32_t, int32_t> adjusted_token_to_count_map;
for (uint32_t j = n_kv_cache_tokens; j < seq_len; ++j) {
// skip prompt tokens except the last one
if (j + 1 < n_tokens) continue;
++adjusted_token_to_count_map[token_ids[j]];
}

// Handle MRope positions
if (use_mrope_) {
const auto& args = *args_;
@@ -344,8 +344,7 @@
// Handle sampling for last tokens
if (j + 1 < n_tokens) continue;

handle_sampling_parameters(
sequence, j, seq_len, adjusted_token_to_count_map, state_ptr);
handle_sampling_parameters(sequence, j, seq_len, state_ptr);
}

// Add extra token id
@@ -359,52 +358,36 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
}
}

void BatchInputBuilder::handle_sampling_parameters(
Sequence* sequence,
uint32_t token_position,
uint32_t seq_len,
std::unordered_map<int32_t, int32_t>& adjusted_token_to_count_map,
BuilderState* state_ptr) {
void BatchInputBuilder::handle_sampling_parameters(Sequence* sequence,
uint32_t token_position,
uint32_t seq_len,
BuilderState* state_ptr) {
BuilderState& state = state_ptr ? *state_ptr : state_;

// const auto token_ids = sequence->token_ids();
const auto token_id = sequence->tokens()[token_position];

// Adjust token count
--adjusted_token_to_count_map[token_id];

// Select token for sampling
state.selected_token_idxes.push_back(state.flatten_tokens_vec.size() - 1);
state.sampling_params.push_back(sequence->sampling_param());
state.sample_idxes.push_back(state.selected_token_idxes.size() - 1);

// Process unique tokens
const auto& seq_token_counts = sequence->token_to_count_map();
auto& ids = state.unique_token_ids_vec.emplace_back();
auto& counts = state.unique_token_counts_vec.emplace_back();

ids.reserve(seq_token_counts.size());
counts.reserve(seq_token_counts.size());
if (need_unique_tokens_) {
const auto& seq_token_counts = sequence->token_to_count_map();
auto& ids = state.unique_token_ids_vec.emplace_back();
auto& counts = state.unique_token_counts_vec.emplace_back();

for (const auto& [token_id, count] : seq_token_counts) {
const auto it = adjusted_token_to_count_map.find(token_id);
const auto adjust_count =
(it != adjusted_token_to_count_map.end()) ? it->second : 0;
ids.reserve(seq_token_counts.size());
counts.reserve(seq_token_counts.size());

if (count > adjust_count) {
for (const auto& [token_id, count] : seq_token_counts) {
CHECK(count >= 0) << "token count should be non-negative";
ids.push_back(token_id);
counts.push_back(count - adjust_count);
counts.push_back(count);
}
}

state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));

// Mark sample token if it's the last token
// TODO: add a test for the chunked-prefill case, e.g.
// allowed_max_token = 128, n_tokens = 1000, n_kv_cache_tokens = 256,
// q_seq_len = 128, seq_len = 384
if (token_position == seq_len - 1) {
state.sample_idxes.push_back(
static_cast<int32_t>(state.selected_token_idxes.size() - 1));
state.unique_token_lens_vec.push_back(static_cast<int32_t>(ids.size()));
}
}

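For context on why the builder can now skip these stats: the unique-token ids and counts exist only to drive frequency, presence, and repetition penalties at sampling time. Below is a minimal sketch of how such stats are typically consumed — the function name, signature, and exact penalty formulas are illustrative assumptions, not xLLM's actual sampler. When all three penalties sit at their defaults (0, 0, 1), the adjustment is a no-op, which is what makes collecting the stats wasted work for such sequences.

```cpp
#include <cstdint>
#include <vector>

// Illustrative penalty application over one sequence's unique-token stats.
// Assumed semantics: frequency/presence penalties follow the additive
// convention, repetition penalty the multiplicative (HF-style) one.
void apply_penalties(std::vector<float>& logits,
                     const std::vector<int32_t>& unique_token_ids,
                     const std::vector<int32_t>& unique_token_counts,
                     float frequency_penalty,    // default 0
                     float presence_penalty,     // default 0
                     float repetition_penalty) { // default 1
  for (size_t i = 0; i < unique_token_ids.size(); ++i) {
    float& logit = logits[unique_token_ids[i]];
    const int32_t count = unique_token_counts[i];
    if (count > 0) {
      // Frequency penalty scales with the count; presence penalty is flat.
      logit -= frequency_penalty * static_cast<float>(count);
      logit -= presence_penalty;
    }
    // Repetition penalty only needs presence, which would be consistent with
    // registering prompt tokens at a count of 0 (see the Sequence changes).
    logit = logit > 0.0f ? logit / repetition_penalty
                         : logit * repetition_penalty;
  }
}
```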
11 changes: 5 additions & 6 deletions xllm/core/framework/batch/batch_input_builder.h
@@ -121,12 +121,10 @@ class BatchInputBuilder {
uint32_t n_kv_cache_tokens,
uint32_t seq_len,
BuilderState* state_ptr = nullptr);
void handle_sampling_parameters(
Sequence* sequence,
uint32_t token_position,
uint32_t seq_len,
std::unordered_map<int32_t, int32_t>& adjusted_counts,
BuilderState* state_ptr = nullptr);
void handle_sampling_parameters(Sequence* sequence,
uint32_t token_position,
uint32_t seq_len,
BuilderState* state_ptr = nullptr);
void setup_kv_cache_info(
Sequence* sequence,
uint32_t n_kv_cache_tokens,
@@ -153,6 +151,7 @@
// Configuration
bool use_mrope_ = false;
uint32_t num_sequences_ = 0;
bool need_unique_tokens_ = true;

// copy in and out cache contents
std::unordered_set<int32_t> write_block_ids_;
14 changes: 7 additions & 7 deletions xllm/core/framework/batch/batch_test.cpp
@@ -192,13 +192,13 @@ TEST(BatchTest, Basic) {
// seq4 has no sampling parameters
EXPECT_TRUE(equal(sampling_params.unique_token_ids, unique_ids));

const std::vector<int32_t> unique_counts = {
/*seq1*/ 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
/*seq2*/ 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
/*seq3*/ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
};
// seq4 has no sampling parameters
EXPECT_TRUE(equal(sampling_params.unique_token_counts, unique_counts));
// const std::vector<int32_t> unique_counts = {
// /*seq1*/ 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// /*seq2*/ 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
// /*seq3*/ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
// };
// // seq4 has no sampling parameters
// EXPECT_TRUE(equal(sampling_params.unique_token_counts, unique_counts));

const std::vector<int32_t> token_ids_lens = {6, 5, 16};
EXPECT_TRUE(equal(sampling_params.unique_token_ids_lens, token_ids_lens));
26 changes: 21 additions & 5 deletions xllm/core/framework/request/sequence.cpp
@@ -56,11 +56,21 @@ Sequence::Sequence(size_t index,
// init logprob state
logprob_state_ = std::make_unique<LogprobState>(num_prompt_tokens_, capacity);

if (sequence_params_.sampling_param->frequency_penalty != 0 ||
sequence_params_.sampling_param->presence_penalty != 0 ||
sequence_params_.sampling_param->repetition_penalty != 1) {
need_unique_tokens_ = true;
}

// add the prompt tokens
for (const auto token_id : prompt_token_ids) {
tokens_[num_tokens_++] = token_id;
token_to_count_map_[token_id]++;
if (need_unique_tokens_) {
token_to_count_map_[token_id] = 0;
}
}
// keep one entry for padding, even when token counts are not needed
token_to_count_map_[prompt_token_ids.back()] = 0;
input_embedding_ = input_embedding;
cur_generated_token_idx_ = num_prompt_tokens_;
}
@@ -133,7 +143,9 @@ void Sequence::append_token(const Token& token) {
return;
}

token_to_count_map_[token_id]++;
if (need_unique_tokens_) {
token_to_count_map_[token_id]++;
}
// update logprobs if needed
if (sequence_params_.sampling_param->logprobs) {
logprob_state_->update_logprob(
@@ -174,7 +186,9 @@ void Sequence::update_last_step_token(const Token& token, size_t token_offset) {

const int32_t token_id = static_cast<int32_t>(token.id);
tokens_[cur_generated_token_idx_] = token_id;
token_to_count_map_[token_id]++;
if (need_unique_tokens_) {
token_to_count_map_[token_id]++;
}
// update logprobs if needed
if (sequence_params_.sampling_param->logprobs) {
logprob_state_->update_logprob(
@@ -202,8 +216,10 @@ void Sequence::update_token(size_t index, const Token& token) {
const int32_t origin_token_id = tokens_[index];
const int32_t token_id = static_cast<int32_t>(token.id);
tokens_[index] = token_id;
--token_to_count_map_[origin_token_id];
++token_to_count_map_[token_id];
if (need_unique_tokens_) {
--token_to_count_map_[origin_token_id];
++token_to_count_map_[token_id];
}
// update logprobs if needed
if (sequence_params_.sampling_param->logprobs) {
logprob_state_->update_logprob(
3 changes: 3 additions & 0 deletions xllm/core/framework/request/sequence.h
@@ -255,6 +255,8 @@ class Sequence final {
return sequence_params_.sampling_param->beam_width > 1;
}

bool check_need_unique_tokens() { return need_unique_tokens_; }

LogprobState* logprob_state() { return logprob_state_.get(); }

// set sequence id
@@ -310,6 +312,7 @@

// the count of each token id
std::unordered_map<int32_t, int32_t> token_to_count_map_;
bool need_unique_tokens_ = false;

// the length of the prompt tokens
size_t num_prompt_tokens_ = 0;
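Taken together, the Sequence and builder changes reduce to one rule: per-token counts are tracked only when a penalty is active. A standalone restatement under assumed types (the real code reads these fields through sequence_params_.sampling_param):

```cpp
#include <algorithm>
#include <vector>

// Illustrative stand-ins for the fields checked in Sequence's constructor.
struct SamplingParam {
  float frequency_penalty = 0.0f;
  float presence_penalty = 0.0f;
  float repetition_penalty = 1.0f;
};

// Mirrors the new constructor logic: counts matter only if some penalty
// deviates from its default.
bool needs_unique_tokens(const SamplingParam& p) {
  return p.frequency_penalty != 0.0f || p.presence_penalty != 0.0f ||
         p.repetition_penalty != 1.0f;
}

// Mirrors the std::any_of gate in build_forward_input(): one sequence with
// penalties enabled forces unique-token stats for the whole batch.
bool batch_needs_unique_tokens(const std::vector<SamplingParam>& params) {
  return std::any_of(params.begin(), params.end(), needs_unique_tokens);
}
```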
6 changes: 3 additions & 3 deletions xllm/core/framework/sampling/sampling_params.cpp
@@ -35,9 +35,6 @@ void SamplingParameters::init(
const std::vector<int32_t>& unique_token_lens_vec) {
CHECK_EQ(req_sampling_params.size(), selected_token_idxes.size());
CHECK_GE(req_sampling_params.size(), sample_idxes.size());
CHECK_EQ(req_sampling_params.size(), unique_token_ids_vec.size());
CHECK_EQ(req_sampling_params.size(), unique_token_counts_vec.size());
CHECK_EQ(req_sampling_params.size(), unique_token_lens_vec.size());

std::vector<float> frequency_penalties;
std::vector<float> presence_penalties;
@@ -118,6 +115,9 @@
this->selected_token_idxes =
torch::tensor(selected_token_idxes, int_tensor_options);
if (need_token_stats) {
CHECK_EQ(req_sampling_params.size(), unique_token_ids_vec.size());
CHECK_EQ(req_sampling_params.size(), unique_token_counts_vec.size());
CHECK_EQ(req_sampling_params.size(), unique_token_lens_vec.size());
this->unique_token_ids =
create_2d_tensor(unique_token_ids_vec, torch::kInt64);
this->unique_token_counts =