diff --git a/fastdeploy/cache_manager/cache_data.py b/fastdeploy/cache_manager/cache_data.py
index 631f5efb05a..0708178c9af 100644
--- a/fastdeploy/cache_manager/cache_data.py
+++ b/fastdeploy/cache_manager/cache_data.py
@@ -21,18 +21,6 @@
 
 logger = get_logger("prefix_cache_manager", "cache_manager.log")
 
-DISABLE_PREFIX_CACHE_MM_MODEL: set[str] = {
-    "Ernie5ForCausalLM",
-}
-
-
-def is_mm_model_disable_prefix_cache(model_config):
-    """
-    check if the model architecture is in DISABLE_PREFIX_CACHE_MM_MODEL
-    """
-    return model_config._architecture in DISABLE_PREFIX_CACHE_MM_MODEL
-
-
 class CacheStatus(Enum):
     """
     cache status enum class
diff --git a/fastdeploy/cache_manager/multimodal_cache_manager.py b/fastdeploy/cache_manager/multimodal_cache_manager.py
index febce1bc203..379340c5d0b 100644
--- a/fastdeploy/cache_manager/multimodal_cache_manager.py
+++ b/fastdeploy/cache_manager/multimodal_cache_manager.py
@@ -53,9 +53,6 @@ def apply_cache(self, mm_hashes: list[str], mm_items: list[Any]) -> list[str]:
             else:
                 item_size = self.get_item_size(mm_items[idx])
                 if self.current_cache_size + item_size >= self.max_cache_size:
-                    if item_size > self.max_cache_size:
-                        # cannot be inserted even if we clear all cached data, skip it directly
-                        continue
                     needed = item_size - (self.max_cache_size - self.current_cache_size)
                     evicted_hashes.extend(self.evict_cache(needed))
                 self.cache[mm_hashes[idx]] = mm_items[idx]
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 58e7c4f3144..cbfd71e5f2f 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -1591,7 +1591,7 @@ def __init__(
             and self.model_config is not None
             and self.model_config.enable_mm
         ):
-            self.max_prefill_batch = 1  # TODO: currently the multimodal prefill stage only supports a parallelism of 1; to be optimized
+            self.max_prefill_batch = 1  # TODO: currently the V0 multimodal prefill stage only supports a parallelism of 1; to be optimized
         else:
             self.max_prefill_batch = self.scheduler_config.max_num_seqs
 
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index f8cf662ef73..a75edd8b00e 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -552,8 +552,6 @@ def __post_init__(self):
 
         if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
             envs.FD_ENABLE_MAX_PREFILL = 1
-            self.enable_prefix_caching = False
-            self.max_encoder_cache = 0
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 2494881da04..65b335d56ad 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -333,26 +333,6 @@ def _update_mm_hashes(self, request):
             inputs["mm_positions"] = []
             inputs["mm_hashes"] = []
 
-    def _is_mm_request(self, request):
-        inputs = request.multimodal_inputs
-        if inputs is None or len(inputs) == 0:
-            return False
-
-        if (
-            (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
-            or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
-            or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
-        ):
-            return True
-        elif (
-            inputs.get("images", None) is not None
-            and inputs.get("image_patch_id", None) is not None
-            and inputs.get("grid_thw", None) is not None
-        ):
-            return True
-
-        return False
-
     def _get_num_new_tokens(self, request, token_budget):
         # TODO: set condition to new _get_num_new_tokens
         num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
@@ -484,20 +464,14 @@
def _get_num_new_tokens(self, request, token_budget): request.image_start = np.sum(np.prod(grid_thw[: request.num_image_start], axis=1)) request.image_end = np.sum(np.prod(grid_thw[: request.num_image_end], axis=1)) - cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end] - cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end] if self.encoder_cache: + cur_mm_hashes = inputs["mm_hashes"][request.num_image_start : request.num_image_end] + cur_mm_positions = inputs["mm_positions"][request.num_image_start : request.num_image_end] request.evict_mm_hashes = self.encoder_cache.apply_cache(cur_mm_hashes, cur_mm_positions) # Compatible with scenarios without images and videos. return num_new_tokens - def exist_mm_prefill(self, scheduled_reqs): - for request in scheduled_reqs: - if request.task_type == RequestType.PREFILL and self._is_mm_request(request): - return True - return False - def exist_prefill(self, scheduled_reqs): for request in scheduled_reqs: if request.task_type == RequestType.PREFILL: @@ -654,11 +628,7 @@ def _allocate_decode_and_extend(): break request = self.waiting[0] - if ( - not envs.FD_ENABLE_MAX_PREFILL - and self._is_mm_request(request) - and self.exist_mm_prefill(scheduled_reqs) - ) or (paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)): + if paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs): break if request.status == RequestStatus.WAITING: result = self._waiting_async_process(request) diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py index 7d387acc609..f7d58ea22cd 100644 --- a/fastdeploy/entrypoints/engine_client.py +++ b/fastdeploy/entrypoints/engine_client.py @@ -82,13 +82,6 @@ def __init__(self, pid: int | str, port: int | str, fd_config: FDConfig, workers self.enable_splitwise = self.fd_config.scheduler_config.splitwise_role != "mixed" self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8 - if self.enable_mm and self.enable_prefix_caching: - from fastdeploy.cache_manager.cache_data import ( - is_mm_model_disable_prefix_cache, - ) - - self.disable_prefix_mm = is_mm_model_disable_prefix_cache(self.fd_config.model_config) - if self.tensor_parallel_size <= self.max_chips_per_node: self.is_master = True else: @@ -265,16 +258,6 @@ async def format_and_add_data(self, prompts: dict): await self.add_requests(prompts) return prompts["prompt_token_ids"] - def _check_mm_disable_prefix_cache(self, task): - is_multimodal_data = False - if self.disable_prefix_mm: - multimodal_inputs = task.get("multimodal_inputs", []) - if multimodal_inputs: - token_type_ids = multimodal_inputs.get("token_type_ids", []) - if token_type_ids: - is_multimodal_data = np.sum(token_type_ids) > 0 - return is_multimodal_data - async def add_requests(self, task): """ Add a new request to the queue. @@ -298,16 +281,6 @@ async def add_requests(self, task): else: self.data_processor.process_request_dict(task, self.max_model_len) - if self.enable_mm and self.enable_prefix_caching: - if self._check_mm_disable_prefix_cache(task): - api_server_logger.error( - "The current service does not support processing requests containing multimodal data when prefix cache is enabled. Please send only text-based requests or disable prefix cache" - ) - raise EngineError( - "The current service does not support processing requests containing multimodal data when prefix cache is enabled. 
Please send only text-based requests or disable prefix cache", - error_code=400, - ) - task["prompt_token_ids_len"] = len(task["prompt_token_ids"]) input_ids_len = task["prompt_token_ids_len"] diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py index dfc0644e556..552be13373b 100644 --- a/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/modeling_resampler.py @@ -218,15 +218,11 @@ def spatial_conv_reshape(self, x, spatial_conv_size): x = x.reshape([-1, C * (spatial_conv_size**2)]) return x - def forward(self, x, image_mask, token_type_ids, image_type_ids, grid_thw): + def forward(self, x, grid_thw): """ x: image_features - image_mask: [B] - token_types_ids: [B] - image_type_ids: [B_image] grid_thw: [B_image, 3] """ - assert image_type_ids is not None def fwd_spatial(x): """ diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a91611524ac..f186f5e366e 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -14,6 +14,7 @@ # limitations under the License. """ +import copy import os import queue import time @@ -28,7 +29,7 @@ from fastdeploy.config import FDConfig from fastdeploy.engine.pooling_params import PoolingParams -from fastdeploy.engine.request import Request, RequestType +from fastdeploy.engine.request import ImagePosition, Request, RequestType from fastdeploy.model_executor.graph_optimization.utils import ( GPUMemoryChecker, profile_run_guard, @@ -367,188 +368,242 @@ def _init_logits_processor(self, request) -> tuple[Future[LogitsProcessorBase],] schemata_key, ) - def get_chunked_inputs(self, req: Request): + def _process_mm_features(self, request_list: List[Request]): """ - Get inputs in current chunk + Process and cache vision features from model + - add image_features, extract and cache vision features from model + - add rope_emb, rotate position embeddings """ - prefill_start_index = req.prefill_start_index - prefill_end_index = req.prefill_end_index - inputs = req.multimodal_inputs - input_ids = inputs["input_ids"][prefill_start_index:prefill_end_index] - token_type_ids = inputs["token_type_ids"][prefill_start_index:prefill_end_index] - image_type_ids = inputs["image_type_ids"][req.image_type_ids_start : req.image_type_ids_end] - images = inputs["images"][req.image_start : req.image_end] - grid_thw = inputs["grid_thw"][req.num_image_start : req.num_image_end] - mm_hashes = inputs["mm_hashes"][req.num_image_start : req.num_image_end] + if not self.enable_mm: + return - return ( - input_ids, - token_type_ids, - image_type_ids, - images, - grid_thw, - mm_hashes, - ) + self.share_inputs["image_features"] = None + multi_vision_inputs = { + "images_lst": [], + "grid_thw_lst": [], + "vit_position_ids_lst": [], + "cu_seqlens": [0], + "encoder_cache_info": [], + "feature_position_list": [], + } + rope_3d_position_ids = { + "position_ids_idx": [], + "position_ids_lst": [], + "position_ids_offset": [0], + "max_tokens_lst": [], + } - def batch_uncached_inputs(self, req: Request): - """ - Batch uncached multimodal inputs - """ - (input_ids, token_type_ids, image_type_ids, images, grid_thw, mm_hashes) = self.get_chunked_inputs(req) - - image_type_ids_size = grid_thw[:, 0] - image_type_ids_split = np.cumsum(image_type_ids_size)[:-1] - image_type_ids_lst = np.array_split(image_type_ids, image_type_ids_split, axis=0) - - images_size = 
np.prod(grid_thw, axis=1) - images_split = np.cumsum(images_size)[:-1] - images_lst = np.array_split(images, images_split, axis=0) - - assert len(image_type_ids_lst) == len( - mm_hashes - ), f"image_type_ids_lst length {len(image_type_ids_lst)} != mm_hashes length {len(mm_hashes)}" - assert len(images_lst) == len( - mm_hashes - ), f"images_lst length {len(images_lst)} != mm_hashes length {len(mm_hashes)}" - - uncached_image_type_ids = [] - uncached_images = [] - uncached_grid_thw = [] - uncached_mm_hashes = [] - for i, mm_hash in enumerate(mm_hashes): - if mm_hash in self.encoder_cache: + for request in request_list: + if request.task_type.value != RequestType.PREFILL.value: continue - uncached_image_type_ids.append(image_type_ids_lst[i]) - uncached_images.append(images_lst[i]) - uncached_grid_thw.append(grid_thw[i]) - uncached_mm_hashes.append(mm_hash) - - uncached_input_ids = paddle.to_tensor(input_ids, dtype=paddle.int64) - uncached_token_type_ids = paddle.to_tensor(token_type_ids, dtype=paddle.int64) - if len(uncached_mm_hashes) > 0: - uncached_image_type_ids = paddle.to_tensor(np.hstack(uncached_image_type_ids), dtype=paddle.int64) - uncached_images = paddle.to_tensor( - np.vstack(uncached_images), dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16" - ) - uncached_grid_thw = paddle.to_tensor(uncached_grid_thw, dtype=paddle.int64) - return ( - uncached_input_ids, - uncached_token_type_ids, - uncached_image_type_ids, - uncached_images, - uncached_grid_thw, - uncached_mm_hashes, - ) + if self.encoder_cache is not None: + evict_mm_hashes = request.get("evict_mm_hashes", None) + if evict_mm_hashes: + for mm_hash in evict_mm_hashes: + self.encoder_cache.pop(mm_hash, None) + + position_ids = request.multimodal_inputs["position_ids"] + rope_3d_position_ids["position_ids_idx"].append(request.idx) + rope_3d_position_ids["position_ids_lst"].append(position_ids) + rope_3d_position_ids["position_ids_offset"].append( + position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1] + ) - def scatter_and_cache_features(self, image_features, inputs): - """ - Split batched image features and cache them - """ - merge_size = 2 - grid_thw = inputs["grid_thw"] - mm_hashes = inputs["mm_hashes"] - image_features_size = (paddle.prod(grid_thw[:, 1:], axis=1) // (merge_size**2)).tolist() - image_features_lst = paddle.split(image_features, image_features_size, axis=0) - - assert len(image_features_lst) == len( - mm_hashes - ), f"image_features_lst length {len(image_features_lst)} != mm_hashes length {len(mm_hashes)}" - for i, mm_hash in enumerate(mm_hashes): - self.encoder_cache[mm_hash] = image_features_lst[i].cpu() - - def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_position_ids: dict): - """ - Apply multimodal inputs to share_inputs - - add image_features, extract and cache vision features from model - - add rope_emb, rotate position embeddings - """ - if self.encoder_cache: - evict_mm_hashes = request.get("evict_mm_hashes", None) - if evict_mm_hashes: - for mm_hash in evict_mm_hashes: - self.encoder_cache.pop(mm_hash, None) - - inputs = request.multimodal_inputs - if request.with_image: - if envs.FD_ENABLE_MAX_PREFILL: - multi_vision_inputs["images_lst"].append( - inputs["images"][request.image_start : request.image_end].cuda() - ) - multi_vision_inputs["grid_thw_lst"].extend( - inputs["grid_thw"][request.num_image_start : request.num_image_end] - ) - if "vit_seqlen" in inputs: - multi_vision_inputs["cu_seqlens"].extend( - 
inputs["vit_seqlen"][request.num_image_start : request.num_image_end] + if self.is_pooling_model: + rope_3d_position_ids["max_tokens_lst"].append(0) + else: + rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) + + if request.with_image: + inputs = request.multimodal_inputs + if self.encoder_cache is not None: + if envs.FD_ENABLE_MAX_PREFILL: + if "vit_seqlen" in inputs: + vit_seqlen_list = inputs["vit_seqlen"][request.num_image_start : request.num_image_end] + if "vit_position_ids" in inputs: + vit_position_ids_list = inputs["vit_position_ids"][ + request.num_image_start : request.num_image_end + ] + grid_thw_list = inputs["grid_thw"][request.num_image_start : request.num_image_end] + mm_hashes_list = inputs["mm_hashes"][request.num_image_start : request.num_image_end] + feature_positions = self._get_feature_positions( + mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end], + prefill_start_index=request.prefill_start_index, + prefill_end_index=request.prefill_end_index, ) - if "vit_position_ids" in inputs: - multi_vision_inputs["vit_position_ids_lst"].extend( - inputs["vit_position_ids"][request.num_image_start : request.num_image_end] + image_start_idx = request.num_image_start + + logger.debug( + f"request {request.request_id} start process encoder info, image_start_idx: {image_start_idx} " + f"grid_thw_list: {grid_thw_list}, feature_positions: {feature_positions}, mm_hashes_list: {mm_hashes_list}" ) - else: - vision_inputs = inputs - if self.encoder_cache: - ( - vision_inputs["input_ids"], - vision_inputs["token_type_ids"], - vision_inputs["image_type_ids"], - vision_inputs["images"], - vision_inputs["grid_thw"], - vision_inputs["mm_hashes"], - ) = self.batch_uncached_inputs(request) - if len(vision_inputs["mm_hashes"]) > 0: - # uncached multimodal inputs exist - image_features = self.extract_vision_features(vision_inputs) - self.scatter_and_cache_features(image_features, vision_inputs) - - full_image_features_lst = [] - for mm_hash in inputs["mm_hashes"][request.num_image_start : request.num_image_end]: - feature = self.encoder_cache[mm_hash].cuda() - full_image_features_lst.append(feature) - image_features = paddle.concat(full_image_features_lst, axis=0) + for i, mm_hash in enumerate(mm_hashes_list): + image_offset = np.prod(grid_thw_list[i]) + logger.debug( + f"run idx {i} with mm_hash {mm_hash} image_offset: {image_offset} grid_thw: {grid_thw_list[i]}" + ) + if mm_hash in self.encoder_cache: + multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], True)) + continue + + multi_vision_inputs["encoder_cache_info"].append((mm_hash, feature_positions[i], False)) + if envs.FD_ENABLE_MAX_PREFILL: + multi_vision_inputs["images_lst"].append( + inputs["images"][image_start_idx : image_start_idx + image_offset].cuda() + ) + multi_vision_inputs["grid_thw_lst"].append(grid_thw_list[i]) + multi_vision_inputs["cu_seqlens"].append(vit_seqlen_list[i]) + multi_vision_inputs["vit_position_ids_lst"].append(vit_position_ids_list[i]) + else: + multi_vision_inputs["images_lst"].append( + paddle.to_tensor( + inputs["images"][image_start_idx : image_start_idx + image_offset], + dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", + ) + ) + multi_vision_inputs["grid_thw_lst"].append( + paddle.to_tensor(grid_thw_list[i], dtype=paddle.int64) + ) + image_start_idx += image_offset else: - ( - input_ids, - token_type_ids, - image_type_ids, - images, - grid_thw, - mm_hashes, - ) = 
self.get_chunked_inputs(request) - vision_inputs["input_ids"] = paddle.to_tensor(input_ids, dtype=paddle.int64) - vision_inputs["token_type_ids"] = paddle.to_tensor(token_type_ids, dtype=paddle.int64) - vision_inputs["image_type_ids"] = paddle.to_tensor(image_type_ids, dtype=paddle.int64) - vision_inputs["images"] = paddle.to_tensor( - images, dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16" + if envs.FD_ENABLE_MAX_PREFILL: + multi_vision_inputs["images_lst"].append( + inputs["images"][request.image_start : request.image_end].cuda() + ) + multi_vision_inputs["grid_thw_lst"].extend( + inputs["grid_thw"][request.num_image_start : request.num_image_end] + ) + multi_vision_inputs["cu_seqlens"].extend( + inputs["vit_seqlen"][request.num_image_start : request.num_image_end] + ) + multi_vision_inputs["vit_position_ids_lst"].extend( + inputs["vit_position_ids"][request.num_image_start : request.num_image_end] + ) + else: + multi_vision_inputs["images_lst"].append( + paddle.to_tensor( + inputs["images"][request.image_start : request.image_end], + dtype="uint8" if "ernie" in self.model_config.model_type else "bfloat16", + ) + ) + multi_vision_inputs["grid_thw_lst"].extend( + paddle.to_tensor( + inputs["grid_thw"][request.num_image_start : request.num_image_end], + dtype=paddle.int64, + ) + ) + + multi_vision_inputs["feature_position_list"].extend( + self._get_feature_positions( + mm_positions=inputs["mm_positions"][request.num_image_start : request.num_image_end], + prefill_start_index=request.prefill_start_index, + prefill_end_index=request.prefill_end_index, + ) ) - vision_inputs["grid_thw"] = paddle.to_tensor(grid_thw, dtype=paddle.int64) - vision_inputs["mm_hashes"] = mm_hashes - - image_features = self.extract_vision_features(vision_inputs) - - # part of the first image may be already cached - if "ernie" in self.model_config.model_type: - actual_image_token_num = paddle.sum(vision_inputs["input_ids"] == self.model_config.im_patch_id) - elif "qwen" in self.model_config.model_type: - actual_image_token_num = paddle.sum( - vision_inputs["input_ids"] == vision_inputs["image_patch_id"] - ) + paddle.sum(vision_inputs["input_ids"] == vision_inputs["video_patch_id"]) - else: - raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") - self.share_inputs["image_features"] = image_features[-actual_image_token_num:] - - position_ids = request.multimodal_inputs["position_ids"] - rope_3d_position_ids["position_ids_idx"].append(request.idx) - rope_3d_position_ids["position_ids_lst"].append(position_ids) - rope_3d_position_ids["position_ids_offset"].append( - position_ids.shape[0] + rope_3d_position_ids["position_ids_offset"][-1] - ) - if self.is_pooling_model: - rope_3d_position_ids["max_tokens_lst"].append(0) - else: - rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048)) + if self.encoder_cache is not None: + if len(multi_vision_inputs["images_lst"]) > 0 or len(multi_vision_inputs["encoder_cache_info"]) > 0: + image_features_output = None + if len(multi_vision_inputs["images_lst"]) > 0: + image_features_output = self.extract_vision_features(multi_vision_inputs) + + logger.debug(f"encoder_cache_info: {multi_vision_inputs['encoder_cache_info']}") + merge_image_features, feature_idx, thw_idx = [], 0, 0 + for mm_hash, feature_position, use_cache in multi_vision_inputs["encoder_cache_info"]: + if use_cache: + assert mm_hash in self.encoder_cache, f"{mm_hash} not in encoder cache" + mm_feature = 
self.encoder_cache[mm_hash].cuda()
+                    else:
+                        assert (
+                            image_features_output is not None
+                        ), f"image_features_output is None, images_lst length: {len(multi_vision_inputs['images_lst'])}"
+                        mm_token_length = paddle.prod(multi_vision_inputs["grid_thw_lst"][thw_idx]) // 4
+                        mm_feature = image_features_output[feature_idx : feature_idx + mm_token_length]
+
+                        # add feature to encoder cache
+                        self.encoder_cache[mm_hash] = mm_feature.detach().cpu()
+                        feature_idx += mm_token_length
+                        thw_idx += 1
+
+                    feature_start = feature_position.offset
+                    feature_end = feature_position.offset + feature_position.length
+                    merge_image_features.append(mm_feature[feature_start:feature_end])
+
+                self.share_inputs["image_features"] = paddle.concat(merge_image_features, axis=0)
+                logger.debug(
+                    f"merge_image_features length: {len(merge_image_features)}, features shape: {self.share_inputs['image_features'].shape}"
+                )
+        elif len(multi_vision_inputs["images_lst"]) > 0:
+            assert len(multi_vision_inputs["feature_position_list"]) == len(
+                multi_vision_inputs["grid_thw_lst"]
+            ), f"{multi_vision_inputs['feature_position_list']} != {multi_vision_inputs['grid_thw_lst']}"
+
+            merge_image_features, feature_idx, thw_idx = [], 0, 0
+            image_features_output = self.extract_vision_features(multi_vision_inputs)
+            for feature_position in multi_vision_inputs["feature_position_list"]:
+                mm_token_length = paddle.prod(multi_vision_inputs["grid_thw_lst"][thw_idx]) // 4
+                mm_feature = image_features_output[feature_idx : feature_idx + mm_token_length]
+
+                feature_start = feature_position.offset
+                feature_end = feature_position.offset + feature_position.length
+                merge_image_features.append(mm_feature[feature_start:feature_end])
+                feature_idx += mm_token_length
+                thw_idx += 1
+            self.share_inputs["image_features"] = paddle.concat(merge_image_features, axis=0)
+
+        if len(rope_3d_position_ids["position_ids_idx"]) > 0:
+            packed_position_ids = paddle.to_tensor(
+                np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64"
+            )
+            rope_3d_lst = self.prepare_rope3d(
+                packed_position_ids,
+                rope_3d_position_ids["max_tokens_lst"],
+                rope_3d_position_ids["position_ids_offset"],
+            )
+            for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]):
+                self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i]
+
+    def _get_feature_positions(
+        self, mm_positions: List[ImagePosition], prefill_start_index: int, prefill_end_index: int
+    ):
+        """
+        Filter and adjust ImagePosition objects that fall within the specified prefill range.
+ + Args: + mm_positions: List of ImagePosition objects to filter + prefill_start_index: Start index of the prefill range + prefill_end_index: End index of the prefill range + + Returns: + List of ImagePosition objects that are within or intersect with the prefill range + """ + feature_positions = [] + for position in mm_positions: + position_start = position.offset + position_end = position.offset + position.length + if position_end <= prefill_start_index or position_start >= prefill_end_index: + continue + elif position_start >= prefill_start_index and position_end <= prefill_end_index: + new_position = copy.deepcopy(position) + new_position.offset = 0 + feature_positions.append(new_position) + else: + new_position = copy.deepcopy(position) + # Adjust offset if it starts before prefill_start_index + if position_start < prefill_start_index: + new_position.offset = prefill_start_index - position_start + new_position.length = min(position_end, prefill_end_index) - prefill_start_index + # Adjust length if it extends beyond prefill_end_index + elif position_end > prefill_end_index: + new_position.offset = 0 + new_position.length = prefill_end_index - position_start + feature_positions.append(new_position) + + logger.debug( + f"get feature_positions, original positions: {mm_positions}, filtered positions: {feature_positions}" + ) + return feature_positions def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None): """ @@ -565,15 +620,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = has_decode_task = False batch_pooling_params = [] - self.share_inputs["image_features"] = None - multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]} - rope_3d_position_ids = { - "position_ids_idx": [], - "position_ids_lst": [], - "position_ids_offset": [0], - "max_tokens_lst": [], - } - for i in range(req_len): request = req_dicts[i] idx = request.idx @@ -606,9 +652,6 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = prefill_start_index = request.prefill_start_index prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index - if self.enable_mm: - self._apply_mm_inputs(request, multi_vision_inputs, rope_3d_position_ids) - if not self.is_pooling_model: if request.get("enable_thinking", False) and request.get("reasoning_max_tokens", None) is not None: # Enable thinking @@ -739,21 +782,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) - if len(multi_vision_inputs["images_lst"]) > 0: - self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs) - - if len(rope_3d_position_ids["position_ids_idx"]) > 0: - packed_position_ids = paddle.to_tensor( - np.concatenate(rope_3d_position_ids["position_ids_lst"]), dtype="int64" - ) - rope_3d_lst = self.prepare_rope3d( - packed_position_ids, - rope_3d_position_ids["max_tokens_lst"], - rope_3d_position_ids["position_ids_offset"], - ) - for i, idx in enumerate(rope_3d_position_ids["position_ids_idx"]): - self.share_inputs["rope_emb"][idx : idx + 1, :] = rope_3d_lst[i] - + self._process_mm_features(req_dicts) if has_prefill_task or has_decode_task: self.share_inputs["not_need_stop"][0] = True self.share_inputs["seq_lens_this_time"] = self.seq_lens_this_time_buffer[:num_running_requests] @@ -2700,21 +2729,19 @@ def _preprocess_mm_task(self, one: dict) -> None: ) return 
result
 
-    def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
-        assert inputs["images"] is not None
-        grid_thw = inputs["grid_thw"]
+    def extract_vision_features_ernie(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
+        """
+        vision feature extractor for ernie-vl
+        """
+        assert len(vision_inputs["images_lst"]) > 0, "at least one image needed"
+
+        grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64)
         # ernie-vl has images norm
-        images = inputs["images"].cast("float32")
+        images = paddle.concat(vision_inputs["images_lst"]).cast("float32")
         images = self.image_preprocess.rescale_factor * images - self.image_preprocess.image_mean_tensor
         images = images / self.image_preprocess.image_std_tensor
         images = images.cast("bfloat16")
-        token_type_ids = inputs["token_type_ids"]
-        token_type_ids_w_video = token_type_ids
-        input_ids = inputs["input_ids"]
-        # convert to img patch id
-        image_mask = input_ids == self.model_config.im_patch_id
-        image_type_ids = inputs["image_type_ids"]
         with paddle.amp.auto_cast(
             True,
             custom_black_list=self.amp_black,
@@ -2731,21 +2758,15 @@ def extract_vision_features_ernie(self, inputs: list[paddle.Tensor]) -> paddle.T
             # ernie-vl has resampler_model
             image_features = self.model.resampler_model(
                 image_features,
-                image_mask,
-                token_type_ids_w_video,
-                image_type_ids,
                 grid_thw,
             )
         return image_features
 
-    def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
-        if envs.FD_ENABLE_MAX_PREFILL:
-            images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
-            grid_thw = paddle.to_tensor(inputs["grid_thw_lst"], dtype="int64")
-        else:
-            assert inputs["images"] is not None
-            grid_thw = inputs["grid_thw"]
-            images = inputs["images"]
+    def extract_vision_features_qwen(self, vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
+        assert len(vision_inputs["images_lst"]) > 0, "at least one image needed"
+
+        grid_thw = paddle.to_tensor(vision_inputs["grid_thw_lst"], dtype=paddle.int64)
+        images = paddle.concat(vision_inputs["images_lst"]).cast("bfloat16")
         with paddle.amp.auto_cast(
             True,
             custom_black_list=self.amp_black,
@@ -2757,7 +2778,7 @@ def extract_vision_features_qwen(self, inputs: list[paddle.Tensor]) -> paddle.Te
 
         return image_features
 
-    def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
+    def extract_vision_features_paddleocr(self, inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
         if envs.FD_ENABLE_MAX_PREFILL:
             inputs["vit_position_ids_lst"] = np.concatenate(inputs["vit_position_ids_lst"])
             images = paddle.concat(inputs["images_lst"]).cast("bfloat16")
@@ -2801,14 +2822,14 @@ def extract_vision_features_paddleocr(self, inputs: list[paddle.Tensor]) -> padd
 
         return image_features
 
     @paddle.no_grad()
-    def extract_vision_features(self, inputs: list[paddle.Tensor]) -> paddle.Tensor:
+    def extract_vision_features(self, multi_vision_inputs: dict[str, list[paddle.Tensor]]) -> paddle.Tensor:
         """extract_vision_features"""
         if "ernie" in self.model_config.model_type:
-            return self.extract_vision_features_ernie(inputs)
+            return self.extract_vision_features_ernie(multi_vision_inputs)
         elif "qwen" in self.model_config.model_type:
-            return self.extract_vision_features_qwen(inputs)
+            return self.extract_vision_features_qwen(multi_vision_inputs)
         elif "paddleocr" in self.model_config.model_type:
-            return self.extract_vision_features_paddleocr(inputs)
+            return self.extract_vision_features_paddleocr(multi_vision_inputs)
         else:
             raise 
ValueError(f"multiple modalities model {self.model_config.model_type} is not supported") diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 02d66f4bc53..7e7d4f08c75 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -1011,8 +1011,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: architecture = fd_config.model_config.architectures[0] if "PaddleOCR" in architecture: envs.FD_ENABLE_MAX_PREFILL = 1 - fd_config.cache_config.enable_prefix_caching = False - fd_config.cache_config.max_encoder_cache = 0 return fd_config diff --git a/tests/worker/test_gpu_model_runner.py b/tests/worker/test_gpu_model_runner.py new file mode 100644 index 00000000000..1d1cabddb30 --- /dev/null +++ b/tests/worker/test_gpu_model_runner.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from dataclasses import dataclass +from unittest.mock import Mock + +from fastdeploy.engine.request import ImagePosition +from fastdeploy.worker.gpu_model_runner import GPUModelRunner + + +@dataclass +class TestRequest: + multimodal_inputs: dict = None + + +class TestFeaturePositions(unittest.TestCase): + + def setUp(self): + # Create a mock GPUModelRunner instance for testing + self.mock_fd_config = Mock() + self.mock_model_config = Mock() + self.mock_model_config.enable_mm = True + self.mock_fd_config.model_config = self.mock_model_config + + # Mock other necessary configurations + self.mock_fd_config.scheduler_config = Mock() + self.mock_fd_config.scheduler_config.max_num_seqs = 10 + self.mock_fd_config.parallel_config = Mock() + self.mock_fd_config.parallel_config.tensor_parallel_size = 1 + + self.runner = GPUModelRunner.__new__(GPUModelRunner) + self.runner.fd_config = self.mock_fd_config + self.runner.model_config = self.mock_model_config + + def test_completely_within_range(self): + """Test positions that are completely within the prefill range""" + mm_positions = [ + ImagePosition(offset=10, length=5), # [10, 14] + ImagePosition(offset=15, length=5), # [15, 19] + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 2) + self.assertEqual(result[0].offset, 0) + self.assertEqual(result[0].length, 5) + self.assertEqual(result[1].offset, 0) + self.assertEqual(result[1].length, 5) + + def test_completely_outside_range(self): + """Test positions that are completely outside the prefill range""" + mm_positions = [ + ImagePosition(offset=5, length=3), # [5, 7] - before range + ImagePosition(offset=25, length=5), # [25, 29] - after range + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 0) + + def test_partial_overlap_start(self): + """Test positions 
that partially overlap at the start of the range""" + mm_positions = [ + ImagePosition(offset=8, length=5), # [8, 12] overlaps with [10, 20] + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].offset, 2) # Adjusted to start at prefill_start_index + self.assertEqual(result[0].length, 3) # Length reduced to fit within range + + def test_partial_overlap_end(self): + """Test positions that partially overlap at the end of the range""" + mm_positions = [ + ImagePosition(offset=8, length=50), # [8, 58] overlaps with [10, 20] + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].offset, 2) # Offset remains the same + self.assertEqual(result[0].length, 10) # Length reduced to fit within range + + def test_exact_range_boundary(self): + """Test positions that exactly match the range boundaries""" + mm_positions = [ + ImagePosition(offset=10, length=10), # Exactly matches [10, 20] + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 1) + self.assertEqual(result[0].offset, 0) + self.assertEqual(result[0].length, 10) + + def test_edge_overlap(self): + """Test positions that exactly touch the range boundaries""" + mm_positions = [ + ImagePosition(offset=20, length=5), # Starts exactly at end boundary but should be excluded + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 0) # Should be excluded - ends at boundary means outside + + def test_multiple_overlapping_positions(self): + """Test mixed positions with different overlap scenarios""" + mm_positions = [ + ImagePosition(offset=5, length=3), # [5, 8] - before range + ImagePosition(offset=8, length=5), # [8, 13] - overlaps start + ImagePosition(offset=13, length=6), # [13, 19] - completely within + ImagePosition(offset=19, length=5), # [19, 24] - overlaps end + ImagePosition(offset=24, length=3), # [24, 27] - after range + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + self.assertEqual(len(result), 3) + + # First position (overlapping start) + self.assertEqual(result[0].offset, 2) + self.assertEqual(result[0].length, 3) + + # Second position (completely within) + self.assertEqual(result[1].offset, 0) + self.assertEqual(result[1].length, 6) + + # Third position (overlapping end) + self.assertEqual(result[2].offset, 0) + self.assertEqual(result[2].length, 1) + + def test_zero_length_range(self): + """Test with zero-length prefill range""" + mm_positions = [ + ImagePosition(offset=10, length=5), + ] + prefill_start_index = 15 + prefill_end_index = 15 # Zero-length range + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 0) + + def test_empty_positions_list(self): + """Test with an empty positions list""" + mm_positions = [] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, 
prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 0) + + def test_identical_positions_copy(self): + """Test that positions within range are correctly deep copied""" + mm_positions = [ + ImagePosition(offset=12, length=5), + ] + prefill_start_index = 10 + prefill_end_index = 20 + + result = self.runner._get_feature_positions(mm_positions, prefill_start_index, prefill_end_index) + + self.assertEqual(len(result), 1) + # Verify it's a copy, not the same object + self.assertIsNot(result[0], mm_positions[0]) + # But has the same values + self.assertEqual(result[0].offset, 0) + self.assertEqual(result[0].length, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/woker/test_gpu_prompt_logprobs.py b/tests/worker/test_gpu_prompt_logprobs.py similarity index 100% rename from tests/woker/test_gpu_prompt_logprobs.py rename to tests/worker/test_gpu_prompt_logprobs.py diff --git a/tests/woker/test_logprobs_output.py b/tests/worker/test_logprobs_output.py similarity index 100% rename from tests/woker/test_logprobs_output.py rename to tests/worker/test_logprobs_output.py
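For reference, a minimal standalone sketch (not part of the patch; plain Python with a stand-in ImagePosition dataclass) of the clipping rule that GPUModelRunner._get_feature_positions implements, written in a condensed form that the tests above also exercise:

from dataclasses import dataclass


@dataclass
class ImagePosition:  # stand-in for fastdeploy.engine.request.ImagePosition
    offset: int
    length: int


def clip_to_prefill(positions, start, end):
    """Keep the part of each [offset, offset + length) span that overlaps [start, end);
    the returned offset is relative to the image's own feature sequence."""
    clipped = []
    for pos in positions:
        lo, hi = pos.offset, pos.offset + pos.length
        if hi <= start or lo >= end:
            continue  # no overlap with the current prefill chunk
        clipped.append(ImagePosition(offset=max(start - lo, 0), length=min(hi, end) - max(lo, start)))
    return clipped


# An image spanning tokens [8, 58) clipped to the prefill window [10, 20)
# yields offset=2, length=10, matching test_partial_overlap_end above.
print(clip_to_prefill([ImagePosition(8, 50)], 10, 20))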