Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ async def generate(
nixl_pd_upload_websocket: ClientConnection = None,
# 用于等待 pd_master 下发的交换信息
nixl_pd_event: asyncio.Event = None,
# 用于在 generate 内部检测请求是否已被提前 abort
pending_aborts: set = None,
) -> AsyncGenerator[Tuple[int, str, dict, FinishStatus], None]:

start_time = time.time()
Expand Down Expand Up @@ -371,6 +373,10 @@ async def generate(
)

if nixl_pd_upload_websocket is not None and self.pd_mode.is_NP():
if pending_aborts is not None and group_request_id in pending_aborts:
pending_aborts.discard(group_request_id)
raise Exception(f"req_id {group_request_id} aborted before nixl upload")

# 在 nixl pd 模式下的 p 节点, 为了更好的兼容多模态的推理流程,np 节点需要先上报其 encode 好的 prompt ids 信息,然后
# 再等待 pd_master 传输下来的对应的进行 decode 节点的decode信息,然后再执行后续的流程
logger.info(
Expand All @@ -385,6 +391,10 @@ async def generate(
logger.error(f"nixl np node wait nixl_pd_event 36s time out, group_req_id {group_request_id}")
raise Exception(f"group_req_id {group_request_id} wait nixl_pd_event time out")

if pending_aborts is not None and group_request_id in pending_aborts:
pending_aborts.discard(group_request_id)
raise Exception(f"req_id {group_request_id} aborted while waiting nixl decode node info")

decode_node_info: NIXLDecodeNodeInfo = nixl_pd_event.decode_node_info
sampling_params.nixl_params.set(pickle.dumps(decode_node_info))

Expand All @@ -399,11 +409,19 @@ async def generate(
alloc_req_index = await self.shm_req_manager.async_alloc_req_index()
sleep_time = 0.1
while alloc_req_index is None:
if pending_aborts is not None and group_request_id in pending_aborts:
pending_aborts.discard(group_request_id)
raise Exception(f"req_id {group_request_id} aborted during SHM allocation wait")

await asyncio.sleep(sleep_time)
sleep_time *= 1.1
sleep_time = min(1, sleep_time)

alloc_req_index = await self.shm_req_manager.async_alloc_req_index()
if pending_aborts is not None and group_request_id in pending_aborts:
pending_aborts.discard(group_request_id)
await self.shm_req_manager.async_release_req_index(alloc_req_index)
raise Exception(f"req_id {group_request_id} aborted after SHM allocation")
alloced_req_indexes.append(alloc_req_index)
Comment on lines 411 to 425
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

generate 方法中,当 sampling_params.n > 1 时,会循环多次申请 SHM 索引并存入 alloced_req_indexes。如果在后续的循环中触发了提前 abort 检查(第 412-414 行或第 421-424 行),直接抛出异常会导致之前已经成功申请并存放在 alloced_req_indexes 中的 SHM 索引发生泄漏,因为它们没有被释放。

建议在抛出异常前,遍历并释放 alloced_req_indexes 中已申请的所有索引。

                while alloc_req_index is None:
                    if pending_aborts is not None and group_request_id in pending_aborts:
                        pending_aborts.discard(group_request_id)
                        for idx in alloced_req_indexes:
                            await self.shm_req_manager.async_release_req_index(idx)
                        raise Exception(f"req_id {group_request_id} aborted during SHM allocation wait")

                    await asyncio.sleep(sleep_time)
                    sleep_time *= 1.1
                    sleep_time = min(1, sleep_time)

                    alloc_req_index = await self.shm_req_manager.async_alloc_req_index()
                if pending_aborts is not None and group_request_id in pending_aborts:
                    pending_aborts.discard(group_request_id)
                    await self.shm_req_manager.async_release_req_index(alloc_req_index)
                    for idx in alloced_req_indexes:
                        await self.shm_req_manager.async_release_req_index(idx)
                    raise Exception(f"req_id {group_request_id} aborted after SHM allocation")
                alloced_req_indexes.append(alloc_req_index)

req_objs: List[Req] = []
for i, req_index in enumerate(alloced_req_indexes):
Expand Down
33 changes: 27 additions & 6 deletions lightllm/server/httpserver/pd_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import signal
import sys
from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Set, Union, List
from websockets import ClientConnection
from lightllm.server.pd_io_struct import NodeRole, ObjType
from lightllm.server.httpserver.async_queue import AsyncQueue
Expand Down Expand Up @@ -108,6 +108,7 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
forwarding_tokens_task = asyncio.create_task(_up_tokens_to_pd_master(forwarding_queue, websocket))

group_req_id_to_event: Dict[int, asyncio.Event] = weakref.WeakValueDictionary()
pending_aborts: Set[int] = set()
# 接收 pd master 发来的请求,并推理后,将生成的token转发回pd master。
while True:
recv_bytes = await websocket.recv()
Expand All @@ -126,20 +127,32 @@ async def _pd_handle_task(manager: HttpServerManager, pd_master_obj: PD_Master_O
forwarding_queue=forwarding_queue,
nixl_pd_upload_websocket=websocket,
nixl_pd_event=nixl_pd_event,
pending_aborts=pending_aborts,
)
)
elif obj[0] == ObjType.ABORT:
group_req_id = obj[1]
logger.warning(f"recv cmd aborted req id {group_req_id}")
if not (await manager.abort(group_req_id)):
pending_aborts.add(group_req_id)

async def delayed_abort_task(group_req_id, retry_count):
for _ in range(retry_count):
async def delayed_abort_task(group_req_id, pending_aborts):
for i in range(60):
await asyncio.sleep(5.0)
if group_req_id not in pending_aborts:
return
if await manager.abort(group_req_id):
break

asyncio.create_task(delayed_abort_task(group_req_id=group_req_id, retry_count=4))
pending_aborts.discard(group_req_id)
return
pending_aborts.discard(group_req_id)
logger.error(f"abort lost for group_req_id {group_req_id} after all retries")

asyncio.create_task(
delayed_abort_task(
group_req_id=group_req_id,
pending_aborts=pending_aborts,
)
)

elif obj[0] == ObjType.NIXL_REQ_DECODE_NODE_INFO:
_, group_req_id, decode_node_info = obj
Expand Down Expand Up @@ -211,7 +224,14 @@ async def _pd_process_generate(
forwarding_queue: AsyncQueue,
nixl_pd_upload_websocket: ClientConnection,
nixl_pd_event: asyncio.Event,
pending_aborts: Set[int] = None,
):
group_req_id = sampling_params.group_request_id
if pending_aborts is not None and group_req_id in pending_aborts:
pending_aborts.discard(group_req_id)
logger.info(f"skip pre-aborted request {group_req_id}")
return

try:
async for sub_req_id, request_output, metadata, finish_status in manager.generate(
prompt=prompt,
Expand All @@ -220,6 +240,7 @@ async def _pd_process_generate(
request=None,
nixl_pd_upload_websocket=nixl_pd_upload_websocket,
nixl_pd_event=nixl_pd_event,
pending_aborts=pending_aborts,
):
metadata["node_mode"] = manager.args.run_mode
await forwarding_queue.put((sub_req_id, request_output, metadata, finish_status))
Expand Down
30 changes: 22 additions & 8 deletions lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def fetch_stream(
return

try:
await asyncio.wait_for(up_status_event.wait(), timeout=60)
await self._wait_event_or_disconnect(up_status_event, 60, request, group_request_id)
except asyncio.TimeoutError:
logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
raise ServerBusyError()
Expand All @@ -289,6 +289,25 @@ async def fetch_stream(

return

async def _wait_event_or_disconnect(self, event: asyncio.Event, timeout: float, request: Request, group_request_id):
"""Wait for an event with periodic client disconnect checks."""
deadline = time.time() + timeout
while True:
remaining = deadline - time.time()
if remaining <= 0:
raise asyncio.TimeoutError()
try:
await asyncio.wait_for(event.wait(), timeout=min(remaining, 3.0))
return
except asyncio.TimeoutError:
if event.is_set():
return
if await request.is_disconnected():
raise ClientDisconnected(
group_request_id=group_request_id,
reason="client disconnected while waiting",
)
Comment on lines +305 to +309
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

_wait_event_or_disconnect 中,直接调用 await request.is_disconnected() 可能会在 requestNone 时引发 AttributeError。虽然在正常流程中 request 应该存在,但为了防御性编程,建议增加 request is not None 的判断。

Suggested change
if await request.is_disconnected():
raise ClientDisconnected(
group_request_id=group_request_id,
reason="client disconnected while waiting",
)
if request is not None and await request.is_disconnected():
raise ClientDisconnected(
group_request_id=group_request_id,
reason="client disconnected while waiting",
)


async def fetch_nixl_stream(
self,
p_node: PD_Client_Obj,
Expand All @@ -312,16 +331,11 @@ async def fetch_nixl_stream(
await p_node.websocket.send_bytes(pickle.dumps((ObjType.REQ, (prompt, sampling_params, multimodal_params))))

try:
await asyncio.wait_for(nixl_np_up_prompt_ids_event.wait(), timeout=60)
await self._wait_event_or_disconnect(nixl_np_up_prompt_ids_event, 60, request, group_request_id)
except asyncio.TimeoutError:
logger.warning(f"group_request_id: {group_request_id} wait np up prompt ids time out")
raise ServerBusyError()

if await request.is_disconnected():
raise ClientDisconnected(
group_request_id=group_request_id, reason="fetch_nixl_stream prefill period check network disconnected"
)

prompt_ids = nixl_np_up_prompt_ids_event.prompt_ids
logger.info(f"group_request_id: {group_request_id} get np up prompt ids len {len(prompt_ids)}")

Expand All @@ -331,7 +345,7 @@ async def fetch_nixl_stream(
)

try:
await asyncio.wait_for(up_status_event.wait(), timeout=60)
await self._wait_event_or_disconnect(up_status_event, 60, request, group_request_id)
except asyncio.TimeoutError:
logger.warning(f"group_request_id: {group_request_id} kv move time out err, server is busy now.")
raise ServerBusyError()
Expand Down
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/infer_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def __init__(
self.nixl_pd_task_sunccess_num: int = 0
self.nixl_pd_task_failed_num: int = 0
self.nixl_trans_device_id: int = -1
self.nixl_abort_last_send_time: float = 0.0

# 类似 qwen3.5 这种混合linear att 模型使用的状态,记录申请来用于保存对应的线性att缓存的 buffer id
# 当 prefill 阶段结束后, 对应长度的 linear att state 会写入到申请 buffer id 对应的块中, 方便插入到 radix cache中
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
import random
import torch.multiprocessing as mp
from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLChunckedTransTaskGroup, NIXLAbortReq
Expand Down Expand Up @@ -69,10 +70,16 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
for request_id in req_ids:
req_obj: InferReq = g_infer_context.requests_mapping[request_id]

if self.is_master_in_dp and req_obj.infer_aborted and req_obj.nixl_pd_task_num != 0:
self.info_queue.put(NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id))
has_pending_tasks = req_obj.nixl_pd_task_num > (
req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num
)
now = time.time()
if self.is_master_in_dp and req_obj.infer_aborted and has_pending_tasks:
if now - req_obj.nixl_abort_last_send_time > 1.0:
req_obj.nixl_abort_last_send_time = now
self.info_queue.put(NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id))

if req_obj.nixl_pd_task_num != (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num):
if has_pending_tasks:
continue

if req_obj.nixl_pd_task_failed_num > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,22 +148,15 @@ def recv_task_loop(self):
assert False, f"recv error obj {obj}"

def _abort(self, cmd: NIXLAbortReq):
# check time_out update
aborted_tasks = []
with self.waiting_dict_lock:
keys = list(self.waiting_dict.keys())
for key in list(self.waiting_dict.keys()):
if self.waiting_dict[key].request_id == cmd.request_id:
aborted_tasks.append(self.waiting_dict.pop(key))

for key in keys:
with self.waiting_dict_lock:
trans_task = self.waiting_dict.pop(key, None)

if trans_task is not None and trans_task.request_id == cmd.request_id:
trans_task.error_info = "aborted req"
self.failed_queue.put(trans_task)
continue

if trans_task is not None:
with self.waiting_dict_lock:
self.waiting_dict[trans_task.get_key()] = trans_task
for trans_task in aborted_tasks:
trans_task.error_info = "aborted req"
self.failed_queue.put(trans_task)
return

@log_exception
Expand Down Expand Up @@ -261,22 +254,15 @@ def accept_peer_task_loop(
self._check_tasks_time_out()

def _check_tasks_time_out(self):
# check time_out update
expired_tasks = []
with self.waiting_dict_lock:
keys = list(self.waiting_dict.keys())

for key in keys:
with self.waiting_dict_lock:
trans_task = self.waiting_dict.pop(key, None)

if trans_task is not None and trans_task.time_out():
trans_task.error_info = "time out in accept_peer_task_loop"
self.failed_queue.put(trans_task)
continue
for key in list(self.waiting_dict.keys()):
if self.waiting_dict[key].time_out():
expired_tasks.append(self.waiting_dict.pop(key))

if trans_task is not None:
with self.waiting_dict_lock:
self.waiting_dict[trans_task.get_key()] = trans_task
for trans_task in expired_tasks:
trans_task.error_info = "time out in accept_peer_task_loop"
self.failed_queue.put(trans_task)
return

@log_exception
Expand Down Expand Up @@ -310,23 +296,34 @@ def read_peer_kv_loop(self):
def update_task_status_loop(
self,
):
while True:
trans_task: NIXLChunckedTransTask = self.update_status_task_queue.get()
in_flight: List[NIXLChunckedTransTask] = []

while True:
# Drain all newly submitted tasks
if not in_flight:
in_flight.append(self.update_status_task_queue.get())
while True:
try:
in_flight.append(self.update_status_task_queue.get_nowait())
except queue.Empty:
break

remaining = []
for trans_task in in_flight:
ret = self.transporter.check_task_status(trans_task=trans_task)
if ret == "DONE":
self.ready_page_task_queue.put(trans_task)
break
elif ret == "ERR":
trans_task.error_info = "xfer error"
self.failed_queue.put(trans_task)
break
elif trans_task.time_out():
trans_task.error_info = "time out in update_task_status_loop"
self.failed_queue.put(trans_task)
break
else:
remaining.append(trans_task)

in_flight = remaining
if in_flight:
time.sleep(0.001)

@log_exception
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import time
import torch.multiprocessing as mp
import random
from typing import List, Tuple, Optional
from lightllm.server.router.model_infer.infer_batch import InferReq
from lightllm.server.pd_io_struct import NIXLChunckedTransTask
from lightllm.server.pd_io_struct import NIXLChunckedTransTask, NIXLAbortReq
from lightllm.utils.log_utils import init_logger
from lightllm.utils.device_utils import kv_trans_use_p2p
from lightllm.server.router.model_infer.infer_batch import g_infer_context
Expand Down Expand Up @@ -33,15 +34,31 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
for request_id in req_ids:
req_obj: InferReq = g_infer_context.requests_mapping[request_id]
prefill_finished = req_obj.shm_req.input_len <= req_obj.cur_kv_len
all_tasks_done = req_obj.nixl_pd_task_num == (
req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num
)
has_pending_tasks = req_obj.nixl_pd_task_num > (
req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num
)
now = time.time()
if prefill_finished:
# 等待所有传输任务都已经完成。
if req_obj.nixl_pd_task_num == (req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num):
if self.is_master_in_dp and req_obj.infer_aborted and has_pending_tasks:
if now - req_obj.nixl_abort_last_send_time > 1.0:
req_obj.nixl_abort_last_send_time = now
self.info_queue.put(
NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id)
)
if all_tasks_done:
ans_list.append(req_obj)
else:
if req_obj.infer_aborted:
if req_obj.nixl_pd_task_num == (
req_obj.nixl_pd_task_failed_num + req_obj.nixl_pd_task_sunccess_num
):
if self.is_master_in_dp and has_pending_tasks:
if now - req_obj.nixl_abort_last_send_time > 1.0:
req_obj.nixl_abort_last_send_time = now
self.info_queue.put(
NIXLAbortReq(request_id=req_obj.req_id, device_id=req_obj.nixl_trans_device_id)
)
if all_tasks_done:
ans_list.append(req_obj)
else:
continue
Expand Down
Loading
Loading