69 changes: 46 additions & 23 deletions custom_ops/gpu_ops/rebuild_padding.cu
@@ -92,25 +92,28 @@ __global__ void RebuildAppendPaddingKernel(T *output_data,

template <paddle::DataType D>
std::vector<paddle::Tensor> rebuild_padding(
const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cu_seqlens_q, // [bsz+1, 1]
const paddle::Tensor &tmp_out,
const paddle::Tensor &cu_seqlens_q,
const paddle::Tensor &seq_len_this_time,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_encoder,
const paddle::optional<paddle::Tensor> &output_padding_offset,
const paddle::optional<paddle::Tensor> &first_token_out,
int max_input_length,
bool enable_logprob) {

typedef PDTraits<D> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;

#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
auto dev_ctx = static_cast<const phi::CustomContext*>(
paddle::experimental::DeviceContextPool::Instance().Get(tmp_out.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = tmp_out.stream();
#endif

std::vector<int64_t> tmp_out_shape = tmp_out.shape();
const int token_num = tmp_out_shape[0];
const int dim_embed = tmp_out_shape[1];
@@ -128,20 +131,20 @@ std::vector<paddle::Tensor> rebuild_padding(
}
}
out = paddle::full({token_num - need_delete_token_num, dim_embed},
0,
D,
tmp_out.place());
0, D, tmp_out.place());
} else {
out =
paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
}

constexpr int PackSize = VEC_16B / sizeof(DataType_);
int elem_nums = out.numel();
int pack_num = elem_nums / PackSize;
const int elem_nums = out.numel();
const int blocksize = 128;
const int grid_size = (pack_num + blocksize - 1) / blocksize;

if (output_padding_offset) {
// Speculative decoding branch
int pack_num = (elem_nums + PackSize - 1) / PackSize;
int grid_size = std::max(1, (pack_num + blocksize - 1) / blocksize);

RebuildAppendPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
@@ -161,22 +164,42 @@ std::vector<paddle::Tensor> rebuild_padding(
bsz,
enable_logprob);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(tmp_out.data<data_t>())),
cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
max_input_length,
dim_embed,
elem_nums);
const int actual_pack_size = (dim_embed < PackSize) ? 1 : PackSize;
const int pack_num = (elem_nums + actual_pack_size - 1) / actual_pack_size;
const int grid_size = std::max(1, (pack_num + blocksize - 1) / blocksize);

if (actual_pack_size == 1) {
RebuildPaddingKernel<DataType_, 1>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(tmp_out.data<data_t>())),
cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
max_input_length,
dim_embed,
elem_nums);
} else {
RebuildPaddingKernel<DataType_, PackSize>
<<<grid_size, blocksize, 0, cu_stream>>>(
reinterpret_cast<DataType_ *>(out.data<data_t>()),
reinterpret_cast<DataType_ *>(
const_cast<data_t *>(tmp_out.data<data_t>())),
cu_seqlens_q.data<int>(),
seq_len_this_time.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_encoder.data<int>(),
max_input_length,
dim_embed,
elem_nums);
}
}
return {out};
}


paddle::Tensor RebuildPaddingFunc(
const paddle::Tensor &tmp_out, // [token_num, dim_embed]
const paddle::Tensor &cu_seqlens_q, // [bsz+1, 1]
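The key change in the non-speculative branch above is that the vector width is now chosen at runtime: when dim_embed is smaller than PackSize, the kernel is launched with a pack size of 1 so vectorized loads cannot cross a row boundary, and the grid size is clamped to at least one block. A minimal Python sketch of that launch arithmetic, assuming VEC_16B denotes a 16-byte vector (so PackSize = 8 for half precision) and the 128-thread block used in the code:

def launch_config(elem_nums: int, dim_embed: int, pack_size: int = 8, blocksize: int = 128):
    # Fall back to scalar access when one row is narrower than a full vector.
    actual_pack_size = 1 if dim_embed < pack_size else pack_size
    pack_num = (elem_nums + actual_pack_size - 1) // actual_pack_size
    # Never launch an empty grid, even for tiny outputs.
    grid_size = max(1, (pack_num + blocksize - 1) // blocksize)
    return actual_pack_size, grid_size

print(launch_config(elem_nums=4 * 5, dim_embed=5))        # (1, 1): narrow rows use pack size 1
print(launch_config(elem_nums=4 * 4096, dim_embed=4096))  # (8, 16): wide rows keep vectorized loads
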
7 changes: 5 additions & 2 deletions fastdeploy/config.py
@@ -292,7 +292,7 @@ def override_name_from_config(self):
self.tensor_parallel_size = self.infer_model_mp_num
del self.infer_model_mp_num

if hasattr(self, "num_hidden_layers"):
if hasattr(self, "num_hidden_layers") and self.runner != "pooling":
if hasattr(self, "remove_tail_layer"):
if self.remove_tail_layer is True:
self.num_hidden_layers -= 1
@@ -1564,18 +1564,20 @@ def __init__(
self.max_long_partial_prefills = max_long_partial_prefills
self.long_prefill_token_threshold = long_prefill_token_threshold


if envs.FD_FOR_TORCH_MODEL_FORMAT:
self.model_config.model_format = "torch"

# TODO
if not envs.FD_ENABLE_MAX_PREFILL:
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "3"))
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_NUM", "10"))
if current_platform.is_xpu():
self.max_prefill_batch = 1
if self.model_config is not None and self.model_config.enable_mm and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self.max_prefill_batch = 1  # TODO: multimodal prefill currently only supports a parallelism of 1; to be optimized
else:
self.max_prefill_batch = self.scheduler_config.max_num_seqs
# print("self.max_prefill_batch",self.max_prefill_batch)

num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
@@ -1627,6 +1629,7 @@ def postprocess(self):
self.scheduler_config.max_num_batched_tokens = 2048
else:
self.scheduler_config.max_num_batched_tokens = self.model_config.max_model_len
# print("self.scheduler_config.max_num_bathed_tokens",self.scheduler_config.max_num_batched_tokens)

if self.long_prefill_token_threshold == 0:
self.long_prefill_token_threshold = int(self.model_config.max_model_len * 0.04)
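For readability, a condensed Python sketch of how max_prefill_batch is resolved after this change (the environment default moves from 3 to 10; the helper name below is illustrative, not part of the codebase):

import os

def resolve_max_prefill_batch(enable_max_prefill: bool, is_xpu: bool,
                              enable_mm: bool, v1_kvcache_scheduler: bool,
                              max_num_seqs: int) -> int:
    if not enable_max_prefill:
        batch = int(os.getenv("MAX_PREFILL_NUM", "10"))  # default raised from 3 to 10
        if is_xpu:
            batch = 1
        if enable_mm and not v1_kvcache_scheduler:
            batch = 1  # multimodal prefill still runs with a parallelism of 1
        return batch
    return max_num_seqs
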
3 changes: 3 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -543,6 +543,9 @@ def __post_init__(self):
self.enable_prefix_caching = False
self.max_encoder_cache = 0

if self.runner == "pooling" and self.enable_prefix_caching:
self.enable_prefix_caching = False

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""
3 changes: 2 additions & 1 deletion fastdeploy/engine/common_engine.py
@@ -693,14 +693,15 @@ def _fetch_request():
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len

tasks = self.scheduler.get_requests(
available_blocks=self.cfg.cache_config.max_block_num_per_seq,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)

for task in tasks:
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))

2 changes: 1 addition & 1 deletion fastdeploy/engine/pooling_params.py
@@ -164,7 +164,7 @@ def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
self.softmax = True
elif self.task == "reward":
if self.normalize is None:
self.normalize = True
self.normalize = False
else:
raise ValueError(f"Unknown pooling task: {self.task}")

12 changes: 9 additions & 3 deletions fastdeploy/engine/request.py
@@ -24,7 +24,6 @@
import numpy as np
from typing_extensions import TypeVar

from fastdeploy import envs
from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
@@ -145,7 +144,10 @@ def __init__(
self.multimodal_data = multimodal_data
self.multimodal_img_boundaries = None

self.enable_thinking = enable_thinking
if pooling_params is not None:
self.enable_thinking = False
else:
self.enable_thinking = True
self.reasoning_max_tokens = reasoning_max_tokens
self.trace_carrier = trace_carrier

@@ -190,6 +192,10 @@ def from_dict(cls, d: dict):
pooling_params = PoolingParams.from_dict(d["pooling_params"])
else:
sampling_params = SamplingParams.from_dict(d)

enable_thinking = d.get("enable_thinking", None)
if pooling_params is not None:
enable_thinking = False
return cls(
request_id=d["request_id"],
prompt=d.get("prompt"),
@@ -216,7 +222,7 @@ def from_dict(cls, d: dict):
guided_grammar=d.get("guided_grammar", None),
structural_tag=d.get("structural_tag", None),
guided_json_object=d.get("guided_json_object", None),
enable_thinking=d.get("enable_thinking", None),
enable_thinking=enable_thinking,
reasoning_max_tokens=d.get("reasoning_max_tokens", None),
trace_carrier=d.get("trace_carrier", {}),
chat_template=d.get("chat_template", None),
2 changes: 1 addition & 1 deletion fastdeploy/engine/sched/resource_manager_v1.py
@@ -649,7 +649,7 @@ def _allocate_decode_and_extend():

request = self.waiting[0]
if (self._is_mm_request(request) and self.exist_mm_prefill(scheduled_reqs)) or (
paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)
paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)
):
break
if request.status == RequestStatus.WAITING:
10 changes: 0 additions & 10 deletions fastdeploy/entrypoints/openai/protocol.py
@@ -889,16 +889,6 @@ class EmbeddingChatRequest(BaseModel):
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

# --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)

add_special_tokens: bool = Field(
default=False,
description=(
1 change: 0 additions & 1 deletion fastdeploy/entrypoints/openai/serving_engine.py
@@ -256,7 +256,6 @@ def _process_chat_template_kwargs(self, request_dict):
chat_template_kwargs.update(
{
"chat_template": request_dict.get("chat_template"),
"add_generation_prompt": request_dict.get("add_generation_prompt"),
"add_stop_sequences": request_dict.get("add_stop_sequences"),
}
)
16 changes: 9 additions & 7 deletions fastdeploy/model_executor/layers/pool/metadata.py
@@ -15,12 +15,14 @@
"""

from dataclasses import dataclass
from typing import Optional
from typing import Optional, Union

import paddle

from fastdeploy.engine.pooling_params import PoolingParams

Device = Union[paddle.CPUPlace, paddle.CUDAPlace, paddle.XPUPlace]


@dataclass
class PoolingCursor:
@@ -60,21 +62,21 @@ def __getitem__(self, indices: slice):
pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices],
)

def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: str):
def build_pooling_cursor(self, num_scheduled_tokens: list[int], device: Device):
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens, self.prompt_lens, device)


def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: str):
def build_pooling_cursor(num_scheduled_tokens: list[int], prompt_lens: paddle.Tensor, device: Device):
assert len(prompt_lens) == len(num_scheduled_tokens)

n_seq = len(num_scheduled_tokens)
index = list(range(n_seq))
num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens)
cumsum = paddle.zeros([n_seq + 1], dtype="int64")
num_scheduled_tokens = paddle.to_tensor(num_scheduled_tokens, dtype="int64", place=paddle.CPUPlace())
cumsum = paddle.zeros([n_seq + 1], dtype="int64", device=paddle.CPUPlace())

paddle.cumsum(num_scheduled_tokens, axis=0, out=cumsum[1:])
if device == "gpu":
cumsum_device = cumsum.cuda()
if isinstance(device, paddle.CUDAPlace):
cumsum_device = paddle.assign(cumsum).cuda(device.get_device_id())
else:
cumsum_device = cumsum
return PoolingCursor(
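A short usage sketch of the updated build_pooling_cursor signature, assuming a CUDA build: the point of the change is that device is now a paddle Place object (CPUPlace / CUDAPlace / XPUPlace) instead of the string "gpu"/"cpu", and the cumulative-sum bookkeeping is built on the CPU before being copied to that place. The example values below are illustrative.

import paddle
from fastdeploy.model_executor.layers.pool.metadata import build_pooling_cursor

num_scheduled_tokens = [3, 5, 2]                          # tokens scheduled per sequence this step
prompt_lens = paddle.to_tensor([3, 5, 2], dtype="int64")  # one prompt length per sequence

# Pass a Place object rather than a device string.
device = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
cursor = build_pooling_cursor(num_scheduled_tokens, prompt_lens, device)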