From 54f0ed4164ce8be66836c51733ee088c008eaeed Mon Sep 17 00:00:00 2001
From: rainyfly <1435317881@qq.com>
Date: Tue, 11 Nov 2025 14:21:48 +0800
Subject: [PATCH 1/4] [Optimize] Improve perf for fd response token with internal adapter

---
 fastdeploy/engine/common_engine.py          | 12 +++-
 fastdeploy/envs.py                          |  2 +-
 fastdeploy/inter_communicator/zmq_server.py | 67 +++++++++++++++++----
 fastdeploy/scheduler/local_scheduler.py     | 17 ++++--
 4 files changed, 79 insertions(+), 19 deletions(-)

diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 8c22a761e06..cfb02cc63fd 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -802,7 +802,10 @@ def _insert_zmq_task_to_scheduler(self):
                 )
                 # Since the request is not in scheduler
                 # Send result by zmq directly
-                self.send_response_server.send_response(request_id, [error_result])
+                if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                    self.send_response_server.send_response(None, [[error_result]])
+                else:
+                    self.send_response_server.send_response(request_id, [error_result])
             except Exception as e:
                 self.llm_logger.error(
                     f"Error happened while receiving new request from zmq, details={e}, "
@@ -819,8 +822,11 @@ def _zmq_send_generated_tokens(self):
                 if len(results) == 0:
                     time.sleep(0.005)
                     continue
-                for request_id, contents in results.items():
-                    self.send_response_server.send_response(request_id, contents)
+                if envs.FD_ENABLE_INTERNAL_ADAPTER:
+                    self.send_response_server.send_response(None, results)
+                else:
+                    for request_id, contents in results.items():
+                        self.send_response_server.send_response(request_id, contents)
             except Exception as e:
                 self.llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index 01bf0f2d9bc..e3051a4e94a 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -44,7 +44,7 @@
     # Whether to use HuggingFace tokenizer.
     "FD_USE_HF_TOKENIZER": lambda: bool(int(os.getenv("FD_USE_HF_TOKENIZER", "0"))),
     # Set the high watermark (HWM) for receiving data during ZMQ initialization
-    "FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 64000),
+    "FD_ZMQ_SNDHWM": lambda: os.getenv("FD_ZMQ_SNDHWM", 0),
     # cache kv quant params directory
     "FD_CACHE_PARAMS": lambda: os.getenv("FD_CACHE_PARAMS", "none"),
     # Set attention backend.
"NATIVE_ATTN", "APPEND_ATTN" diff --git a/fastdeploy/inter_communicator/zmq_server.py b/fastdeploy/inter_communicator/zmq_server.py index 72eb734c6bd..8d0795d5fc2 100644 --- a/fastdeploy/inter_communicator/zmq_server.py +++ b/fastdeploy/inter_communicator/zmq_server.py @@ -35,6 +35,9 @@ class ZmqServerBase(ABC): def __init__(self): self.cached_results = defaultdict(list) self.response_token_lock = threading.Lock() + self.response_handle_per_step = None + self.response_handle_name_per_step = None + self.batch_id_per_step = 0 @abstractmethod def _create_socket(self): @@ -125,16 +128,20 @@ def recv_result_handle(self): with self.response_token_lock: client, _, request_id = self.socket.recv_multipart(flags=zmq.NOBLOCK) req_id_str = request_id.decode("utf-8") - need_send_after_finished_inference = False - with self.mutex: - self.req_dict[req_id_str] = client - if req_id_str in self.cached_results: - if self.cached_results[req_id_str][-1][-1].finished: - need_send_after_finished_inference = True - if need_send_after_finished_inference: - self.send_response(req_id_str, []) - llm_logger.info(f"send_multipart finished, req_id: {req_id_str}") - self.req_dict.pop(req_id_str, None) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + with self.mutex: + self.response_handle_per_step = client + else: + need_send_after_finished_inference = False + with self.mutex: + self.req_dict[req_id_str] = client + if req_id_str in self.cached_results: + if self.cached_results[req_id_str][-1][-1].finished: + need_send_after_finished_inference = True + if need_send_after_finished_inference: + self.send_response(req_id_str, []) + llm_logger.info(f"send_multipart finished, req_id: {req_id_str}") + self.req_dict.pop(req_id_str, None) except zmq.Again: time.sleep(0.001) @@ -143,7 +150,39 @@ def recv_result_handle(self): llm_logger.error(f"recv_result_handle get unknown exception: {e}") continue - def send_response(self, req_id, data): + def _send_response_per_step(self, batch_id, data): + """ + Send generated token result to client. + """ + self._ensure_socket() + if self.socket is None: + raise RuntimeError("Router socket not created. Call create_router() first.") + need_send_data = [] + with self.mutex: + if self.response_handle_per_step is None: + self.cached_results["data"].extend(data) + else: + need_send_data = self.cached_results["data"] + self.cached_results["data"] = [] + if self.response_handle_per_step is not None: + try: + if data: + need_send_data.extend(data) + start_send = time.time() + result = msgpack.packb( + [[response.to_dict() for response in send_data_list] for send_data_list in need_send_data] + ) + with self.response_token_lock: + self.socket.send_multipart([self.response_handle_per_step, b"", result]) + llm_logger.info( + f"send_multipart result: step {self.batch_id_per_step} lens {len(need_send_data)} elapse: {time.time()-start_send}" + ) + self.batch_id_per_step += 1 + + except Exception as e: + llm_logger.error(f"Send result to zmq client failed: {e}") + + def _send_response_per_query(self, req_id, data): """ Send generated token result to client. 
""" @@ -187,6 +226,12 @@ def send_response(self, req_id, data): llm_logger.info(f"send_multipart finished, req_id: {req_id}") self.req_dict.pop(req_id, None) + def send_response(self, req_id, data): + if envs.FD_ENABLE_INTERNAL_ADAPTER: + self._send_response_per_step(req_id, data) + else: + self._send_response_per_query(req_id, data) + @abstractmethod def close(self): pass diff --git a/fastdeploy/scheduler/local_scheduler.py b/fastdeploy/scheduler/local_scheduler.py index fd4ec37572e..f96d3cf9541 100644 --- a/fastdeploy/scheduler/local_scheduler.py +++ b/fastdeploy/scheduler/local_scheduler.py @@ -20,7 +20,7 @@ from fastdeploy.engine.request import Request, RequestOutput from fastdeploy.scheduler.data import ScheduledRequest, ScheduledResponse -from fastdeploy.utils import scheduler_logger +from fastdeploy.utils import envs, scheduler_logger class LocalScheduler: @@ -79,6 +79,7 @@ def __init__( self.requests: Dict[str, ScheduledRequest] = dict() self.responses: Dict[str, List[ScheduledResponse]] = dict() + self.batch_responses_per_step: List[List[ScheduledResponse]] = list() self.wait_request_timeout = 10 self.wait_response_timeout = 0.001 @@ -298,6 +299,7 @@ def put_results(self, results: List[RequestOutput]): scheduler_logger.info(f"Scheduler has received some finished responses: {finished_responses}") with self.mutex: + self.batch_responses_per_step.append([response.raw for response in responses]) for response in responses: if response.request_id not in self.requests: scheduler_logger.warning(f"Scheduler has received a expired response: {[response.request_id]}") @@ -336,11 +338,15 @@ def get_results(self) -> Dict[str, List[RequestOutput]]: def _get_results(): responses = self.responses + batch_responses_per_step = self.batch_responses_per_step self.responses = dict() - return responses + self.batch_responses_per_step = list() + return responses, batch_responses_per_step with self.responses_not_empty: - responses = self.responses_not_empty.wait_for(_get_results, self.wait_response_timeout) + responses, batch_responses_per_step = self.responses_not_empty.wait_for( + _get_results, self.wait_response_timeout + ) results = dict() for request_id, resps in responses.items(): @@ -353,4 +359,7 @@ def _get_results(): if finished: self._recycle(request_id) scheduler_logger.info(f"Scheduler has pulled a finished response: {[request_id]}") - return results + if envs.FD_ENABLE_INTERNAL_ADAPTER: + return batch_responses_per_step + else: + return results From 02023e5f07d0e44f8674aaa867300985094f42dc Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Mon, 1 Dec 2025 21:14:39 +0800 Subject: [PATCH 2/4] [Optimize] Robust stability for PD deployment --- fastdeploy/engine/common_engine.py | 124 ++++++++++++++---- .../engine/sched/resource_manager_v1.py | 40 +++++- fastdeploy/envs.py | 4 + fastdeploy/output/token_processor.py | 4 +- fastdeploy/splitwise/splitwise_connector.py | 31 +++-- 5 files changed, 167 insertions(+), 36 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 7464e2d9598..944594c6314 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -304,16 +304,6 @@ def insert_tasks(self, tasks, current_id=-1, allocated=False): status, msg = self.split_connector.check_decode_allocated(task) if not status: self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") - self.scheduler.put_results( - [ - RequestOutput( - request_id=task.request_id, - finished=True, - error_code=500, - 
error_msg=msg, - ) - ] - ) need_delete_tasks.append(task) continue for tmp_task in need_delete_tasks: @@ -603,36 +593,77 @@ def _fetch_request(): task.schedule_start_time = time.time() self.llm_logger.debug(f"get tasks from scheduler: {tasks}") + if self.cfg.scheduler_config.splitwise_role != "mixed": need_delete_tasks = [] if envs.FD_OFFLINE_PERF_TEST_FOR_PD: for task in tasks: + if self.resource_manager.has_existed_request(task.request_id): + self.llm_logger.error( + f"request_id: {task.request_id} has been added to scheduler, recieved requests with same request_id." + ) + need_delete_tasks.append(task) + continue # assure can allocate block ids in P while not self.resource_manager.preallocate_resource_in_p(task): time.sleep(0.005) self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") while True: - self.split_connector.send_splitwise_tasks([task], task.idx) + is_successful = self.split_connector.send_splitwise_tasks([task], task.idx) + if not is_successful: # Send request for block ids to D failed + self.llm_logger.error(f"{task.request_id} send request for block ids to D failed.") + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg="send request for block ids to D failed", + ) + ] + ) + need_delete_tasks.append(task) + break status, msg = self.split_connector.check_decode_allocated(task) if not status: - self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.") + self.llm_logger.error( + f"{task.request_id} ask D resource failed, due to: {msg}, try again." + ) time.sleep(0.05) else: break else: for task in tasks: + if self.resource_manager.has_existed_request(task.request_id): + self.llm_logger.error( + f"request_id: {task.request_id} has been added to scheduler, recieved requests with same request_id." + ) + need_delete_tasks.append(task) + continue # assure can allocate block ids in P while not self.resource_manager.preallocate_resource_in_p(task): time.sleep(0.005) self.llm_logger.info(f"ask D resource for req_id: {task.request_id}") - self.split_connector.send_splitwise_tasks([task], task.idx) - - for task in tasks: - if self.cfg.scheduler_config.splitwise_role != "mixed": - # assure fetch block ids from D - status, msg = self.split_connector.check_decode_allocated(task) - if not status: - self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") + is_successful = self.split_connector.send_splitwise_tasks([task], task.idx) + if not is_successful: # Send request for block ids to D failed + self.llm_logger.error(f"{task.request_id} send request for block ids to D failed.") + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg=msg, + ) + ] + ) + need_delete_tasks.append(task) + continue + # assure fetch block ids from D + status, msg = self.split_connector.check_decode_allocated(task) + if not status: + self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.") + if msg != "Add task repeated": # if request repeated in D, do not need report. 
self.scheduler.put_results( [ RequestOutput( @@ -643,8 +674,8 @@ def _fetch_request(): ) ] ) - need_delete_tasks.append(task) - continue + need_delete_tasks.append(task) + continue for tmp_task in need_delete_tasks: tasks.remove(tmp_task) # release resource in P @@ -936,8 +967,35 @@ def receiver_loop(): for task in tasks: can_allocate_resource = False if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if self.resource_manager.has_existed_request(task.request_id): + self.llm_logger.error( + f"request_id: {task.request_id} has been added to scheduler, can not add it again." + ) + task.error_msg = "Add task repeated" + task.error_code = 501 + new_waiting.append(task) + continue if self.resource_manager.preallocate_resource_in_d(task): - self.split_connector.send_cache_infos([task], -1) + is_successful = self.split_connector.send_cache_infos([task], -1) + if is_successful is False: + cur_task = self.resource_manager.requests[task.request_id] + self.resource_manager.prerelease_resource(cur_task) + if cur_task.request_id in self.token_processor.tokens_counter: + del self.token_processor.tokens_counter[task.request_id] + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg="failed to send block ids back to prefill instance", + ) + ] + ) + self.llm_logger.error( + f"request {task.request_id} failed to send block_ids back to Prefill instance." + ) + continue can_allocate_resource = True else: if self.resource_manager.is_resource_sufficient( @@ -952,7 +1010,25 @@ def receiver_loop(): if new_waiting: if not self.enable_decode_cache_task: - self.split_connector.send_cache_infos(new_waiting, -1) + for task in new_waiting: + is_successful = self.split_connector.send_cache_infos([task], -1) + if ( + is_successful is False + ): # not enough block ids, D not allocated yet, due to communication failed, just report + if ( + task.error_code != 501 + ): # if repeated request, do not need to report again + self.scheduler.put_results( + [ + RequestOutput( + request_id=task.request_id, + finished=True, + error_code=500, + error_msg="failed to send not enough blocks msg to prefill instance", + ) + ] + ) + else: self.waiting_requests.extend(new_waiting) self.llm_logger.info( diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index a5dc9027af0..673a83ce9e3 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -30,7 +30,7 @@ from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType from fastdeploy.engine.resource_manager import ResourceManager from fastdeploy.metrics.metrics import main_process_metrics -from fastdeploy.utils import llm_logger +from fastdeploy.utils import envs, llm_logger @dataclass @@ -94,6 +94,8 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l main_process_metrics.set_value("max_batch_size", max_num_seqs) self.using_extend_tables_req_id = set() + if self.config.scheduler_config.splitwise_role == "decode": + threading.Thread(target=self._monitor_recycle_block_ids_in_D, daemon=True).start() def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size @@ -607,6 +609,26 @@ def add_request_in_p(self, requests: list[Request]): request.inference_start_time = time.time() self.running.append(request) + def _monitor_recycle_block_ids_in_D(self): + while True: + try: + with self.lock: + need_recycle_request_ids = 
[] + for request_id, timestamp in self.preallocated_requests_timestamp.items(): + if time.time() - timestamp >= envs.FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT: + need_recycle_request_ids.append(request_id) + for request_id in need_recycle_request_ids: + llm_logger.error( + f"Recycle block ids for request {request_id} forcefully, due to get first token from P timeout." + ) + del self.preallocated_requests_timestamp[request_id] + request = self.requests[request_id] + self.prerelease_resource(request) + time.sleep(10) + except Exception as e: + llm_logger.error(f"Monitor recycle block ids in D error: {e}, {str(traceback.format_exc())}") + time.sleep(10) + def preallocate_resource_in_p(self, request: Request): """ In P/D aggregated deployment, preallocate resource for P. @@ -689,17 +711,33 @@ def preallocate_resource_in_d(self, request: Request): self.stop_flags[request.idx] = False self.requests[request.request_id] = request self.req_dict[request.request_id] = allocated_position + self.preallocated_requests_timestamp[request.request_id] = time.time() return True return False + def has_existed_request(self, request_id): + """ + Whether a request with the given request_id has been added to the scheduler. + """ + if request_id in self.requests: + return True + return False + def insert_task_for_decoding(self, request_output_in_p: RequestOutput): """ In P/D aggregated deployment, D should continue to decode after recieving first token and cache from P. """ assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method" with self.lock: + if request_output_in_p.request_id not in self.requests: + llm_logger.error( + f"request {request_output_in_p.request_id} with first token from P not found in preallocated resource, please check whether recycled due to timeout." 
+ ) + return request = self.requests[request_output_in_p.request_id] request.output_token_ids.append(request_output_in_p.outputs.token_ids[0]) + if request.request_id: + del self.preallocated_requests_timestamp[request.request_id] request.num_cached_tokens = request_output_in_p.num_cached_tokens if ( self.config.speculative_config.method in ["mtp"] diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 339e84f810d..02dce298627 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -128,6 +128,10 @@ "FD_DEFAULT_METRIC_LABEL_VALUES": lambda: os.getenv("FD_DEFAULT_METRIC_LABEL_VALUES", "{}"), # Enable offline perf test mode for PD disaggregation "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), + # Timout for D response in PD disaggregation + "FD_GET_RESPONSE_FROM_D_TIMEOUT": lambda: int(os.getenv("FD_GET_RESPONSE_FROM_D_TIMEOUT", "5")), + # Timeout for first token from P in PD disaggregation + "FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT": lambda: int(os.getenv("FD_GET_RESPONSE_FROM_D_TIMEOUT", "300")), } diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 3315992f7e9..d5a5fb2b67e 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -179,12 +179,12 @@ def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: llm_logger.info( f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}" ) - llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: self._compute_speculative_status() if not is_prefill: self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, batch_id, task, result, is_prefill) + llm_logger.info(f"{self.resource_manager.info()}") break return result @@ -581,12 +581,12 @@ def _process_batch_output(self): llm_logger.info( f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}" ) - llm_logger.info(f"{self.resource_manager.info()}") if self.cfg.speculative_config.method: self._compute_speculative_status() if not is_prefill: self._record_completion_metrics(task, current_time) self._recycle_resources(task_id, i, task, result, is_prefill) + llm_logger.info(f"{self.resource_manager.info()}") break if ( not is_prefill diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index a4a76ca9f14..35bfaed111c 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -125,6 +125,8 @@ def _get_push_socket(self, addr): sock.setsockopt(zmq.SNDHWM, 1000) sock.setsockopt(zmq.RECONNECT_IVL, 1000) sock.setsockopt(zmq.RECONNECT_IVL_MAX, 5000) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + sock.setsockopt(zmq.IMMEDIATE, 1) sock.setsockopt(zmq.TCP_KEEPALIVE, 1) sock.setsockopt(zmq.TCP_KEEPALIVE_IDLE, 60) @@ -143,29 +145,34 @@ def _get_push_socket(self, addr): def _send_message(self, addr, msg_type: str, payload): if not addr: return - + is_successful = True try: self.logger.info(f"Sent {msg_type} to {addr}") message = self._serialize_message(msg_type, payload) try: - sock = self._get_push_socket(addr) - sock.send_multipart([b"", message]) + if envs.FD_ENABLE_INTERNAL_ADAPTER: + sock.send_multipart([b"", message], flags=zmq.NOBLOCK) + else: + sock.send_multipart([b"", message]) self.logger.info(f"Sent {msg_type} to {addr}") except ConnectionError: self.logger.warning(f"Connection to {addr} not established") + 
is_successful = False except zmq.Again: self.logger.warning(f"Send queue full for {addr}") + is_successful = False except Exception as e: self.logger.error(f"Send to {addr} failed: {e}, {str(traceback.format_exc())}") main_process_metrics.inc_value("send_cache_failed_num") + is_successful = False self._close_connection(addr) - except Exception as e: self.logger.error(f"Message preparation failed: {e}") + return is_successful def _close_connection(self, addr): """ @@ -255,6 +262,7 @@ def send_splitwise_tasks(self, tasks, current_id): return addr = None decode_diagg = None + splitwise_task_send_status = {} for task in tasks: if task.disaggregate_info is None: continue @@ -276,9 +284,11 @@ def send_splitwise_tasks(self, tasks, current_id): task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id task.disaggregate_info["role"] = "decode" - self._send_message(addr, "prefill", [task]) + is_successful = self._send_message(addr, "prefill", [task]) + splitwise_task_send_status[task.request_id] = is_successful task.disaggregate_info["cache_info"] = decode_diagg task.disaggregate_info["role"] = "prefill" + return splitwise_task_send_status def send_splitwise_tasks_innode(self, tasks, port): """ @@ -309,6 +319,7 @@ def send_first_token(self, prefill_msg, tasks_list): """ send first token to specific port """ + is_successful = True if not isinstance(tasks_list, list): tasks_list = [tasks_list] self.logger.info("send first token to port decode") @@ -320,7 +331,8 @@ def send_first_token(self, prefill_msg, tasks_list): else: node = f"{prefill_msg['cache_info']['rdma']['ip']}:{prefill_msg['cache_info']['rdma']['port']}" self.logger.info(f"send first token to port {node} decode") - self._send_message(node, "decode", tasks_list) + is_successful = self._send_message(node, "decode", tasks_list) + return is_successful def create_connection(self, port): """ @@ -345,7 +357,7 @@ def check_decode_allocated(self, task): return True, "" while self.current_request_ids[task.request_id] == "init": time.sleep(0.001) - if time.time() - start_time > 30: + if time.time() - start_time > envs.FD_GET_RESPONSE_FROM_D_TIMEOUT: del self.current_request_ids[task.request_id] return False, "timeout" msg = self.current_request_ids[task.request_id] @@ -368,6 +380,7 @@ def send_cache_infos(self, tasks, current_id): """ is_decode = False temp_cache_info = dict() + is_successful = True for i in range(len(tasks)): if tasks[i].disaggregate_info is None: continue @@ -438,13 +451,13 @@ def send_cache_infos(self, tasks, current_id): for k, v in temp_cache_info.items(): self.logger.info(f"{k} {v}") if ":" in str(k): - self._send_message(k, "cache_sync", v) + is_successful = self._send_message(k, "cache_sync", v) else: if k not in self.connect_innode_instances: self.create_connection(k) self.connect_innode_instances[k].put_cache_info(v) - return is_decode + return is_successful def _serialize_message(self, msg_type: str, payload) -> bytes: # TODO 压缩 From 24287181db77baf12a975a5a20de599fe90ca66a Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Tue, 2 Dec 2025 11:43:08 +0800 Subject: [PATCH 3/4] robust health check --- fastdeploy/envs.py | 2 ++ fastdeploy/output/token_processor.py | 23 +++++++++++++++++++ .../splitwise/internal_adapter_utils.py | 12 +++++++++- 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 02dce298627..aeaec8c9a73 100644 --- a/fastdeploy/envs.py +++ 
b/fastdeploy/envs.py
@@ -132,6 +132,8 @@
     "FD_GET_RESPONSE_FROM_D_TIMEOUT": lambda: int(os.getenv("FD_GET_RESPONSE_FROM_D_TIMEOUT", "5")),
     # Timeout for first token from P in PD disaggregation
     "FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT": lambda: int(os.getenv("FD_GET_RESPONSE_FROM_D_TIMEOUT", "300")),
+    # Timeout for token processor health check
+    "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT": lambda: int(os.getenv("FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", "120")),
 }
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index d5a5fb2b67e..39a7834fcb8 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -106,6 +106,10 @@ def __init__(self, cfg, cached_generated_tokens, engine_worker_queue, split_conn
         self.executor = ThreadPoolExecutor(max_workers=1)
         self.prefill_result_status = dict()
         self._finalizer = weakref.finalize(self, self._cleanup_resources)
+        # health monitor
+        self.timestamp_for_alive_before_handle_batch = None
+        self.timestamp_for_alive_after_handle_batch = None
+        self.health_lock = threading.Lock()

     def _cleanup_resources(self):
         """Cleaning up shared memory resources"""
@@ -273,6 +277,18 @@ def process_sampling_results_use_zmq(self):
                 llm_logger.error(f"Recieve message error: {e}")
                 continue

+    def healthy(self):
+        """
+        whether token processor is healthy
+        """
+        with self.health_lock:
+            if self.timestamp_for_alive_after_handle_batch is None:  # has entered handle batch
+                if time.time() - self.timestamp_for_alive_before_handle_batch > envs.FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT:
+                    return False
+                else:
+                    return True
+            return True
+
     def process_sampling_results(self):
         """
         read tokens from paddle inference engine and process
@@ -329,7 +345,14 @@ def process_sampling_results(self):
                     continue
                 llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
                 # self._process_prefill_metrics()
+                with self.health_lock:
+                    self.timestamp_for_alive_before_handle_batch = time.time()
+                    self.timestamp_for_alive_after_handle_batch = None
                 self._process_batch_output()
+                with self.health_lock:
+                    self.timestamp_for_alive_before_handle_batch = None
+                    self.timestamp_for_alive_after_handle_batch = time.time()
+
             except Exception as e:
                 llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
diff --git a/fastdeploy/splitwise/internal_adapter_utils.py b/fastdeploy/splitwise/internal_adapter_utils.py
index 010df6e5b7a..841895659db 100644
--- a/fastdeploy/splitwise/internal_adapter_utils.py
+++ b/fastdeploy/splitwise/internal_adapter_utils.py
@@ -22,7 +22,11 @@
 import zmq

 from fastdeploy.inter_communicator import ZmqTcpServer
-from fastdeploy.metrics.metrics import EXCLUDE_LABELS, get_filtered_metrics, main_process_metrics
+from fastdeploy.metrics.metrics import (
+    EXCLUDE_LABELS,
+    get_filtered_metrics,
+    main_process_metrics,
+)
 from fastdeploy.utils import envs, get_logger

 logger = get_logger("internal_adapter_utils", "internal_adapter_utils.log")
@@ -98,6 +102,12 @@ def _recv_external_module_control_instruct(self):
                         self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
                     elif task["cmd"] == "connect_rdma":
                         self.engine.engine_worker_queue.put_connect_rdma_task(task)
+                    elif task["cmd"] == "check_health":
+                        is_health = self.engine.token_processor.healthy()
+                        result = {"task_id": task_id_str, "result": is_health}
+                        logger.debug(f"Response for task: {task_id_str}: is_health {is_health}")
+                        with self.response_lock:
+                            self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
             except Exception as e:
                 logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")

From 5410ecabafb4a9160511292efd7875d67106d2ee Mon Sep 17 00:00:00 2001
From: rainyfly <1435317881@qq.com>
Date: Wed, 3 Dec 2025 11:56:43 +0800
Subject: [PATCH 4/4] fix

---
 fastdeploy/engine/sched/resource_manager_v1.py | 1 +
 fastdeploy/envs.py                             | 2 +-
 fastdeploy/worker/gpu_model_runner.py          | 6 +++---
 3 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 673a83ce9e3..cc93ccac40d 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -95,6 +95,7 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l
         self.using_extend_tables_req_id = set()

         if self.config.scheduler_config.splitwise_role == "decode":
+            self.preallocated_requests_timestamp = {}
             threading.Thread(target=self._monitor_recycle_block_ids_in_D, daemon=True).start()

     def allocated_slots(self, request: Request):
diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py
index aeaec8c9a73..0e10edcd48d 100644
--- a/fastdeploy/envs.py
+++ b/fastdeploy/envs.py
@@ -131,7 +131,7 @@
     # Timout for D response in PD disaggregation
     "FD_GET_RESPONSE_FROM_D_TIMEOUT": lambda: int(os.getenv("FD_GET_RESPONSE_FROM_D_TIMEOUT", "5")),
     # Timeout for first token from P in PD disaggregation
-    "FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT": lambda: int(os.getenv("FD_GET_RESPONSE_FROM_D_TIMEOUT", "300")),
+    "FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT": lambda: int(os.getenv("FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT", "300")),
     # Timeout for token processor health check
     "FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT": lambda: int(os.getenv("FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", "120")),
 }
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 233d78eba48..c600f6fcb35 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -328,12 +328,12 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             if request.get("enable_thinking", False) and request.get("reasoning_max_tokens") is not None:
                 # Enable thinking
-                self.share_inputs["enable_thinking"][:] = True
+                self.share_inputs["enable_thinking"][idx : idx + 1] = True
                 self.share_inputs["need_think_end"][idx : idx + 1, :] = 1
                 self.share_inputs["reasoning_index"][idx : idx + 1, :] = request.get("reasoning_max_tokens")
             else:
                 # Disable thinking
-                self.share_inputs["enable_thinking"][:] = False
+                self.share_inputs["enable_thinking"][idx : idx + 1] = False
                 self.share_inputs["need_think_end"][idx : idx + 1, :] = 0
                 self.share_inputs["reasoning_index"][idx : idx + 1, :] = 0

@@ -867,7 +867,7 @@ def _init_share_inputs(self, max_num_seqs: int):
         # Initialize thinking related buffers
         self.share_inputs["need_think_end"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
-        self.share_inputs["enable_thinking"] = paddle.full(shape=[1], fill_value=False, dtype="bool")
+        self.share_inputs["enable_thinking"] = paddle.full(shape=[max_num_seqs, 1], fill_value=False, dtype="bool")
         self.share_inputs["reasoning_index"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

         # TODO(gongshaotian): move to models
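
Note on the per-step response path introduced in PATCH 1/4: when FD_ENABLE_INTERNAL_ADAPTER is set, send_response() packs all outputs produced in one engine step into a single msgpack frame and pushes it over the ROUTER socket to whichever client handle last registered through recv_result_handle(). The sketch below is a minimal consumer of that framing; it assumes a pyzmq DEALER client, a hypothetical connect address, and that each dict mirrors RequestOutput.to_dict(). It is an illustration only, not part of the patches.

import msgpack
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.connect("ipc:///tmp/fd_response.ipc")  # hypothetical address; point at the engine's response socket

# Register this handle once. In per-step mode the server stores the sender as
# response_handle_per_step and ignores the request-id frame.
sock.send_multipart([b"", b"internal_adapter_client"])

while True:
    frames = sock.recv_multipart()            # [b"", msgpack payload]
    step_batch = msgpack.unpackb(frames[-1])  # list of per-request lists of response dicts
    for per_request_outputs in step_batch:
        for output in per_request_outputs:
            print(output)                     # each dict is assumed to mirror RequestOutput.to_dict()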