diff --git a/lmdeploy/pytorch/backends/cuda/graph_runner.py b/lmdeploy/pytorch/backends/cuda/graph_runner.py
index 15759c3252..58d093cf9b 100644
--- a/lmdeploy/pytorch/backends/cuda/graph_runner.py
+++ b/lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -281,6 +281,14 @@ def prepare_inputs_for_generation(
     def reset(self):
         """Remove all graphs to prevent hanging on exit."""
         self._runner_map.clear()
+        if get_moe_backend().use_deepep_moe_backend():
+            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer
+
+            if hasattr(DeepEPBuffer, 'destroy'):
+                from torch import distributed as dist
+
+                DeepEPBuffer.destroy()
+                dist.barrier()
 
     def update_inputs(self, inputs):
         """Update inputs."""
diff --git a/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py b/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py
index 00f1036b0d..8810f57a7e 100644
--- a/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py
+++ b/lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py
@@ -117,6 +117,8 @@ def __init__(self,
         try:
             from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401
             get_moe_backend().set_deepep_moe_backend()
+            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
+                DeepEPBuffer.set_explicitly_destroy()
         except ImportError:
             logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')
 
diff --git a/lmdeploy/pytorch/backends/cuda/moe/default.py b/lmdeploy/pytorch/backends/cuda/moe/default.py
index 9042a72aad..ceef74a2b1 100644
--- a/lmdeploy/pytorch/backends/cuda/moe/default.py
+++ b/lmdeploy/pytorch/backends/cuda/moe/default.py
@@ -387,6 +387,8 @@ def __init__(
         try:
             from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401
             get_moe_backend().set_deepep_moe_backend()
+            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
+                DeepEPBuffer.set_explicitly_destroy()
         except ImportError:
             logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')
 
diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py
index f2730b2005..b5d560e5a1 100644
--- a/lmdeploy/pytorch/engine/executor/base.py
+++ b/lmdeploy/pytorch/engine/executor/base.py
@@ -74,7 +74,7 @@ def warmup(self):
         """warmup."""
         raise NotImplementedError('Not Implemented.')
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
         raise NotImplementedError('Not Implemented.')
 
diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py
index a82feb75a0..40ace7defc 100644
--- a/lmdeploy/pytorch/engine/executor/base_worker.py
+++ b/lmdeploy/pytorch/engine/executor/base_worker.py
@@ -119,9 +119,9 @@ def warmup(self):
         """warmup."""
         self.model_agent.warmup()
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
-        self.model_agent.sleep(level)
+        await self.model_agent.sleep(level)
 
     def wakeup(self, tags: Optional[List[str]] = None):
         """Wakeup."""
diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index f8c8cd226c..b687e8097f 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -2,7 +2,7 @@
 import asyncio
 import time
 from contextlib import contextmanager
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, field, fields
 from multiprocessing.reduction import ForkingPickler
 from os import getenv
 from typing import Any, Dict, List, Optional
@@ -38,6 +38,13 @@
 logger = get_logger('lmdeploy')
 
 
+@dataclass
+class SleepWakeupState:
+    to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
+    to_wakeup: asyncio.Event = field(default_factory=asyncio.Event)
+    is_sleeping: bool = False
+
+
 @dataclass
 class BatchedLogProbs:
     vals: torch.Tensor
@@ -398,6 +405,8 @@ def __init__(
             self.inputs_strategy,
             self.agent_strategy,
             device=device)
+        # sleep wakeup state
+        self.state: SleepWakeupState = SleepWakeupState()
         # decoding inputs
         self.step_inputs = StepInputs()
 
@@ -565,7 +574,8 @@ async def _prepare_dp_v1(self, inputs: ModelInputs):
 
         # gather dp forward metadata
         batch_size = inputs.seq_length.numel()
-        dp_forward_meta = [int(is_decoding), int(is_dummy), num_tokens]
+        is_sleeping = self.state.is_sleeping
+        dp_forward_meta = [int(is_decoding), int(is_dummy), num_tokens, int(is_sleeping)]
         # check enable_microbatch
         if self.enable_microbatch:
             tokens_num = inputs.input_ids.numel()
@@ -589,8 +599,9 @@ async def _prepare_dp_v1(self, inputs: ModelInputs):
 
         # check if all inputs are dummy inputs
         is_all_dummy = gathered_meta[:, 1].all().item()
+        is_all_sleeping = gathered_meta[:, 3].all().item()
         if is_all_dummy:
-            return None
+            return None, is_all_sleeping
 
         # pad batch size for decoding
         all_num_tokens = gathered_meta[:, 2].tolist()
@@ -607,7 +618,7 @@ async def _prepare_dp_v1(self, inputs: ModelInputs):
         # update dp meta
         inputs.build_dp_meta(all_num_tokens)
         inputs = self.patched_model.update_inputs(inputs)
-        return inputs
+        return inputs, is_all_sleeping
 
     def _get_inputs_from_delta(
         self,
@@ -792,9 +803,15 @@ def __update_inputs(
         is_decoding = inputs.is_decoding
         if dp > 1:
             # update inputs for dp
-            inputs = await self._prepare_dp_v1(inputs)
+            inputs, is_all_sleeping = await self._prepare_dp_v1(inputs)
             # skip dummy forward.
             if inputs is None:
+                if is_all_sleeping:
+                    self.state.to_sleep.set()
+                    await self.state.to_wakeup.wait()
+                    self.state.to_wakeup.clear()
+                    # sync after wakeup
+                    dist.barrier()
                 logger.debug(f' rank[{rank}]: all inputs are dummy, skip forward.')
                 await asyncio.sleep(0.01)
                 return
@@ -1176,14 +1193,17 @@ def _construct(item):
         torch.cuda.empty_cache()
 
     @torch.inference_mode()
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
+        self.state.is_sleeping = True
+        await self.state.to_sleep.wait()
         self.cache_engine = None
         self.reset_graph_runner()
         device = 'cpu' if level == 1 else 'meta'
         self.patched_model.get_model().to(device=device, non_blocking=True)
         torch.cuda.synchronize()
         torch.cuda.empty_cache()
+        self.state.to_sleep.clear()
 
     @torch.inference_mode()
     def wakeup(self, tags: Optional[List[str]] = None):
@@ -1202,8 +1222,12 @@ def wakeup(self, tags: Optional[List[str]] = None):
         self.build_model()
         self.build_graph_runner()
         self.misc_config.empty_init = old_empty_init
+
         if 'kv_cache' in tags:
             self.build_cache_engine()
+        # wake up signal
+        self.state.is_sleeping = False
+        self.state.to_wakeup.set()
 
     def release(self):
         """release."""
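
The coordination pattern in the agent.py changes can be summarized with a minimal, self-contained sketch of the asyncio.Event handshake that SleepWakeupState enables: the forward loop signals to_sleep once all DP ranks report sleeping, sleep() waits on that signal before offloading, and wakeup() releases the loop via to_wakeup. Apart from SleepWakeupState and its fields, every name below (forward_loop, sleep, wakeup, main) is illustrative and not part of the lmdeploy API.

# Sketch of the sleep/wakeup handshake; illustrative names, not lmdeploy code.
import asyncio
from dataclasses import dataclass, field


@dataclass
class SleepWakeupState:
    to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
    to_wakeup: asyncio.Event = field(default_factory=asyncio.Event)
    is_sleeping: bool = False


async def forward_loop(state: SleepWakeupState):
    # Stand-in for the model-agent forward loop.
    while True:
        if state.is_sleeping:
            state.to_sleep.set()          # tell sleep() it is safe to offload
            await state.to_wakeup.wait()  # park until wakeup() fires
            state.to_wakeup.clear()
        await asyncio.sleep(0.01)


async def sleep(state: SleepWakeupState):
    state.is_sleeping = True
    await state.to_sleep.wait()  # wait for the loop to reach a safe point
    state.to_sleep.clear()
    print('offload weights / release cache here')


def wakeup(state: SleepWakeupState):
    state.is_sleeping = False
    state.to_wakeup.set()


async def main():
    state = SleepWakeupState()
    loop_task = asyncio.create_task(forward_loop(state))
    await sleep(state)
    wakeup(state)
    await asyncio.sleep(0.05)
    loop_task.cancel()


if __name__ == '__main__':
    asyncio.run(main())

This mirrors why the executor's sleep() becomes async in the diff: the caller must await until every rank's forward loop has parked before weights are moved off the GPU.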