diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0ac1ea206..7ddeb4c8d 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -316,6 +316,8 @@ async def generate( nixl_pd_upload_websocket: ClientConnection = None, # 用于等待 pd_master 下发的交换信息 nixl_pd_event: asyncio.Event = None, + # 用于在 generate 内部检测请求是否已被提前 abort + pending_aborts: set = None, ) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]: start_time = time.time() @@ -371,6 +373,10 @@ async def generate( ) if nixl_pd_upload_websocket is not None and self.pd_mode.is_NP(): + if pending_aborts is not None and group_request_id in pending_aborts: + pending_aborts.discard(group_request_id) + raise Exception(f"req_id {group_request_id} aborted before nixl upload") + # 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后 # 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程 logger.info( @@ -385,6 +391,10 @@ async def generate( logger.error(f"nixl np node wait nixl_pd_event 36s time out, group_req_id {group_request_id}") raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out") + if pending_aborts is not None and group_request_id in pending_aborts: + pending_aborts.discard(group_request_id) + raise Exception(f"req_id {group_request_id} aborted while waiting nixl decode node info") + decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info sampling_params.nixl_params.set(pickle.dumps(decode_node_info)) @@ -399,11 +409,19 @@ async def generate( alloc_req_index = await self.shm_req_manager.async_alloc_req_index() sleep_time = 0.1 while alloc_req_index is None: + if pending_aborts is not None and group_request_id in pending_aborts: + pending_aborts.discard(group_request_id) + raise Exception(f"req_id {group_request_id} aborted during SHM allocation wait") + await asyncio.sleep(sleep_time) sleep_time *= 1.1 sleep_time = min(1, sleep_time) alloc_req_index = await self.shm_req_manager.async_alloc_req_index() + if pending_aborts is not None and group_request_id in pending_aborts: + pending_aborts.discard(group_request_id) + await self.shm_req_manager.async_release_req_index(alloc_req_index) + raise Exception(f"req_id {group_request_id} aborted after SHM allocation") alloced_req_indexes.append(alloc_req_index) req_objs: List[Req] = [] for i, req_index in enumerate(alloced_req_indexes): diff --git a/lightllm/server/httpserver/pd_loop.py b/lightllm/server/httpserver/pd_loop.py index e341da2a8..fd9b7f5aa 100644 --- a/lightllm/server/httpserver/pd_loop.py +++ b/lightllm/server/httpserver/pd_loop.py @@ -9,7 +9,7 @@ import os import signal import sys -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Set, Union, List from websockets import ClientConnection from lightllm.server.pd_io_struct import NodeRole, ObjType from lightllm.server.httpserver.async_queue import AsyncQueue @@ -108,6 +108,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket)) group_req_id_to_event: Dict[int, asyncio.Event] = weakref.WeakValueDictionary() + pending_aborts: Set[int] = set() # 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。 while True: recv_bytes = await websocket.recv() @@ -126,20 +127,32 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O forwarding_queue=forwarding_queue, nixl_pd_upload_websocket=websocket, nixl_pd_event=nixl_pd_event, + pending_aborts=pending_aborts, ) ) elif obj[0] == ObjType.ABORT: group_req_id = obj[1] logger.warning(f"recv cmd aborted req id {group_req_id}") if not (await manager.abort(group_req_id)): + pending_aborts.add(group_req_id) - async def delayed_abort_task(group_req_id, retry_count): - for _ in range(retry_count): + async def delayed_abort_task(group_req_id, pending_aborts): + for i in range(60): await asyncio.sleep(5.0) + if group_req_id not in pending_aborts: + return if await manager.abort(group_req_id): - break - - asyncio.create_task(delayed_abort_task(group_req_id=group_req_id, retry_count=4)) + pending_aborts.discard(group_req_id) + return + pending_aborts.discard(group_req_id) + logger.error(f"abort lost for group_req_id {group_req_id} after all retries") + + asyncio.create_task( + delayed_abort_task( + group_req_id=group_req_id, + pending_aborts=pending_aborts, + ) + ) elif obj[0] == ObjType.NIXL_REQ_DECODE_NODE_INFO: _, group_req_id, decode_node_info = obj @@ -211,7 +224,14 @@ async def _pd_process_generate( forwarding_queue: AsyncQueue, nixl_pd_upload_websocket: ClientConnection, nixl_pd_event: asyncio.Event, + pending_aborts: Set[int] = None, ): + group_req_id = sampling_params.group_request_id + if pending_aborts is not None and group_req_id in pending_aborts: + pending_aborts.discard(group_req_id) + logger.info(f"skip pre-aborted request {group_req_id}") + return + try: async for sub_req_id, request_output, metadata, finish_status in manager.generate( prompt=prompt, @@ -220,6 +240,7 @@ async def _pd_process_generate( request=None, nixl_pd_upload_websocket=nixl_pd_upload_websocket, nixl_pd_event=nixl_pd_event, + pending_aborts=pending_aborts, ): metadata["node_mode"] = manager.args.run_mode await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status)) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 307a3d48a..bca16dfac 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -262,7 +262,7 @@ async def fetch_stream( return try: - await asyncio.wait_for(up_status_event.wait(), timeout=60) + await self._wait_event_or_disconnect(up_status_event, 60, request, group_request_id) except asyncio.TimeoutError: logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") raise ServerBusyError() @@ -289,6 +289,25 @@ async def fetch_stream( return + async def _wait_event_or_disconnect(self, event: asyncio.Event, timeout: float, request: Request, group_request_id): + """Wait for an event with periodic client disconnect checks.""" + deadline = time.time() + timeout + while True: + remaining = deadline - time.time() + if remaining <= 0: + raise asyncio.TimeoutError() + try: + await asyncio.wait_for(event.wait(), timeout=min(remaining, 3.0)) + return + except asyncio.TimeoutError: + if event.is_set(): + return + if await request.is_disconnected(): + raise ClientDisconnected( + group_request_id=group_request_id, + reason="client disconnected while waiting", + ) + async def fetch_nixl_stream( self, p_node: PD_Client_Obj, @@ -312,16 +331,11 @@ async def fetch_nixl_stream( await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params)))) try: - await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60) + await self._wait_event_or_disconnect(nixl_np_up_prompt_ids_event, 60, request, group_request_id) except asyncio.TimeoutError: logger.warning(f"group_request_id: {group_request_id} wait np up prompt ids time out") raise ServerBusyError() - if await request.is_disconnected(): - raise ClientDisconnected( - group_request_id=group_request_id, reason="fetch_nixl_stream prefill period check network disconnected" - ) - prompt_ids = nixl_np_up_prompt_ids_event.prompt_ids logger.info(f"group_request_id: {group_request_id} get np up prompt ids len {len(prompt_ids)}") @@ -331,7 +345,7 @@ async def fetch_nixl_stream( ) try: - await asyncio.wait_for(up_status_event.wait(), timeout=60) + await self._wait_event_or_disconnect(up_status_event, 60, request, group_request_id) except asyncio.TimeoutError: logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.") raise ServerBusyError() diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f0ec69b2c..27cf322e7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -534,6 +534,7 @@ def __init__( self.nixl_pd_task_sunccess_num: int = 0 self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 + self.nixl_abort_last_send_time: float = 0.0 # 类似 qwen3.5 这种混合linear att 模型使用的状态,记录申请来用于保存对应的线性att缓存的 buffer id # 当 prefill 阶段结束后, 对应长度的 linear att state 会写入到申请 buffer id 对应的块中, 方便插入到 radix cache中 diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py index 481a3197d..98d23a2c6 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_impl.py @@ -1,3 +1,4 @@ +import time import random import torch.multiprocessing as mp from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, NIXLAbortReq @@ -69,10 +70,16 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: for request_id in req_ids: req_obj: InferReq = g_infer_context.requests_mapping[request_id] - if self.is_master_in_dp and req_obj.infer_aborted and req_obj.nixl_pd_task_num != 0: - self.info_queue.put(NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id)) + has_pending_tasks = req_obj.nixl_pd_task_num > ( + req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num + ) + now = time.time() + if self.is_master_in_dp and req_obj.infer_aborted and has_pending_tasks: + if now - req_obj.nixl_abort_last_send_time > 1.0: + req_obj.nixl_abort_last_send_time = now + self.info_queue.put(NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id)) - if req_obj.nixl_pd_task_num != (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num): + if has_pending_tasks: continue if req_obj.nixl_pd_task_failed_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py index 791386540..8fb6299f3 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/decode_node_impl/decode_trans_process.py @@ -148,22 +148,15 @@ def recv_task_loop(self): assert False, f"recv error obj {obj}" def _abort(self, cmd: NIXLAbortReq): - # check time_out update + aborted_tasks = [] with self.waiting_dict_lock: - keys = list(self.waiting_dict.keys()) + for key in list(self.waiting_dict.keys()): + if self.waiting_dict[key].request_id == cmd.request_id: + aborted_tasks.append(self.waiting_dict.pop(key)) - for key in keys: - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None and trans_task.request_id == cmd.request_id: - trans_task.error_info = "aborted req" - self.failed_queue.put(trans_task) - continue - - if trans_task is not None: - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task + for trans_task in aborted_tasks: + trans_task.error_info = "aborted req" + self.failed_queue.put(trans_task) return @log_exception @@ -261,22 +254,15 @@ def accept_peer_task_loop( self._check_tasks_time_out() def _check_tasks_time_out(self): - # check time_out update + expired_tasks = [] with self.waiting_dict_lock: - keys = list(self.waiting_dict.keys()) - - for key in keys: - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None and trans_task.time_out(): - trans_task.error_info = "time out in accept_peer_task_loop" - self.failed_queue.put(trans_task) - continue + for key in list(self.waiting_dict.keys()): + if self.waiting_dict[key].time_out(): + expired_tasks.append(self.waiting_dict.pop(key)) - if trans_task is not None: - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task + for trans_task in expired_tasks: + trans_task.error_info = "time out in accept_peer_task_loop" + self.failed_queue.put(trans_task) return @log_exception @@ -310,23 +296,34 @@ def read_peer_kv_loop(self): def update_task_status_loop( self, ): - while True: - trans_task: NIXLChunckedTransTask = self.update_status_task_queue.get() + in_flight: List[NIXLChunckedTransTask] = [] + while True: + # Drain all newly submitted tasks + if not in_flight: + in_flight.append(self.update_status_task_queue.get()) while True: + try: + in_flight.append(self.update_status_task_queue.get_nowait()) + except queue.Empty: + break + + remaining = [] + for trans_task in in_flight: ret = self.transporter.check_task_status(trans_task=trans_task) if ret == "DONE": self.ready_page_task_queue.put(trans_task) - break elif ret == "ERR": trans_task.error_info = "xfer error" self.failed_queue.put(trans_task) - break elif trans_task.time_out(): trans_task.error_info = "time out in update_task_status_loop" self.failed_queue.put(trans_task) - break + else: + remaining.append(trans_task) + in_flight = remaining + if in_flight: time.sleep(0.001) @log_exception diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py index 20c487d56..d6fedfb93 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_impl.py @@ -1,8 +1,9 @@ +import time import torch.multiprocessing as mp import random from typing import List, Tuple, Optional from lightllm.server.router.model_infer.infer_batch import InferReq -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLAbortReq from lightllm.utils.log_utils import init_logger from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -33,15 +34,31 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]: for request_id in req_ids: req_obj: InferReq = g_infer_context.requests_mapping[request_id] prefill_finished = req_obj.shm_req.input_len <= req_obj.cur_kv_len + all_tasks_done = req_obj.nixl_pd_task_num == ( + req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num + ) + has_pending_tasks = req_obj.nixl_pd_task_num > ( + req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num + ) + now = time.time() if prefill_finished: - # 等待所有传输任务都已经完成。 - if req_obj.nixl_pd_task_num == (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num): + if self.is_master_in_dp and req_obj.infer_aborted and has_pending_tasks: + if now - req_obj.nixl_abort_last_send_time > 1.0: + req_obj.nixl_abort_last_send_time = now + self.info_queue.put( + NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id) + ) + if all_tasks_done: ans_list.append(req_obj) else: if req_obj.infer_aborted: - if req_obj.nixl_pd_task_num == ( - req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num - ): + if self.is_master_in_dp and has_pending_tasks: + if now - req_obj.nixl_abort_last_send_time > 1.0: + req_obj.nixl_abort_last_send_time = now + self.info_queue.put( + NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id) + ) + if all_tasks_done: ans_list.append(req_obj) else: continue diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py index fb9509115..831af46cd 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_kv_move_manager.py @@ -4,7 +4,7 @@ import time from typing import List, Dict, Union, Callable from lightllm.utils.log_utils import init_logger -from lightllm.server.pd_io_struct import NIXLChunckedTransTask +from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLAbortReq from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs from ..trans_process_obj import KVTransProcess @@ -56,12 +56,20 @@ def __init__(self, args: StartArgs, info_queue: mp.Queue, start_trans_process_fu def task_dispatcher_loop(self): # 获取任务,并分发给相关卡的处理队列 while True: - task: NIXLChunckedTransTask = self.info_queue.get() + task = self.info_queue.get() + + if isinstance(task, NIXLAbortReq): + device_id = task.device_id + elif isinstance(task, NIXLChunckedTransTask): + device_id = task.src_device_id + else: + logger.error(f"unexpected task type: {type(task)}") + continue - device_id = task.src_device_id try: trans_process: KVTransProcess = self.kv_trans_processes[device_id] trans_process.task_in_queue.put(task) - logger.info(f"kv move manager dispatch task {task.to_str()} to device {device_id}") + if isinstance(task, NIXLChunckedTransTask): + logger.info(f"kv move manager dispatch task {task.to_str()} to device {device_id}") except BaseException as e: logger.exception(str(e)) diff --git a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py index 7975a253f..839275cb3 100644 --- a/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py +++ b/lightllm/server/router/model_infer/mode_backend/pd_nixl/prefill_node_impl/prefill_trans_process.py @@ -10,7 +10,7 @@ from typing import List, Dict, Union, Deque, Optional from lightllm.utils.log_utils import init_logger from lightllm.common.kv_cache_mem_manager import MemoryManager -from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskRet +from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskRet, NIXLAbortReq from lightllm.utils.device_utils import kv_trans_use_p2p from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.core.objs import StartArgs @@ -128,8 +128,14 @@ def recv_task_loop(self): torch.cuda.set_device(self.device_id) while True: + obj = self.task_in_queue.get() + + if isinstance(obj, NIXLAbortReq): + self._abort(cmd=obj) + continue + + trans_task: NIXLChunckedTransTask = obj page_index = self.page_index_queue.get() - trans_task: NIXLChunckedTransTask = self.task_in_queue.get() trans_task.nixl_src_page_index = page_index # 初次校验 time out @@ -139,6 +145,18 @@ def recv_task_loop(self): else: self.local_copy_kv_queue.put(trans_task) + def _abort(self, cmd: NIXLAbortReq): + aborted_tasks = [] + with self.waiting_dict_lock: + for key in list(self.waiting_dict.keys()): + if self.waiting_dict[key].request_id == cmd.request_id: + aborted_tasks.append(self.waiting_dict.pop(key)) + + for trans_task in aborted_tasks: + trans_task.error_info = "aborted req" + self.failed_queue.put(trans_task) + return + @log_exception def local_copy_kv_loop(self): torch.cuda.set_device(self.device_id) @@ -236,21 +254,15 @@ def update_task_status_loop( self._check_tasks_time_out() def _check_tasks_time_out(self): + expired_tasks = [] with self.waiting_dict_lock: - keys = list(self.waiting_dict.keys()) + for key in list(self.waiting_dict.keys()): + if self.waiting_dict[key].time_out(): + expired_tasks.append(self.waiting_dict.pop(key)) - for key in keys: - with self.waiting_dict_lock: - trans_task = self.waiting_dict.pop(key, None) - - if trans_task is not None and trans_task.time_out(): - trans_task.error_info = "time out in update_task_status_loop" - self.failed_queue.put(trans_task) - continue - - if trans_task is not None: - with self.waiting_dict_lock: - self.waiting_dict[trans_task.get_key()] = trans_task + for trans_task in expired_tasks: + trans_task.error_info = "time out in update_task_status_loop" + self.failed_queue.put(trans_task) return @log_exception