8 changes: 8 additions & 0 deletions lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -281,6 +281,14 @@ def prepare_inputs_for_generation(
     def reset(self):
         """Remove all graphs to prevent hanging on exit."""
         self._runner_map.clear()
+        if get_moe_backend().use_deepep_moe_backend():
+            from dlblas.layers.moe.token_dispatcher import DeepEPBuffer
+
+            if hasattr(DeepEPBuffer, 'destroy'):
+                from torch import distributed as dist
+
+                DeepEPBuffer.destroy()
+                dist.barrier()
 
     def update_inputs(self, inputs):
         """Update inputs."""
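The new teardown in reset() only runs when the DeepEP MoE backend is active, guards DeepEPBuffer.destroy() behind a hasattr check so older dlblas releases without that API keep working, and follows the destroy with dist.barrier() so no rank exits while a peer still holds the shared communication buffers. A minimal standalone sketch of the same guarded-teardown pattern, assuming an initialized default process group (teardown_shared_buffers is a hypothetical helper, not lmdeploy code):

import torch.distributed as dist


def teardown_shared_buffers(buffer_cls) -> None:
    """Release collectively owned buffers, then rendezvous all ranks.

    buffer_cls is any class that may expose a destroy() classmethod
    (e.g. dlblas' DeepEPBuffer); the hasattr guard keeps this a no-op
    on versions that do not provide it.
    """
    if hasattr(buffer_cls, 'destroy'):
        buffer_cls.destroy()
    if dist.is_available() and dist.is_initialized():
        # No rank returns until every rank has released the buffer.
        dist.barrier()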
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe/blocked_fp8.py
@@ -117,6 +117,8 @@ def __init__(self,
         try:
             from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401
             get_moe_backend().set_deepep_moe_backend()
+            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
+                DeepEPBuffer.set_explicitly_destroy()
         except ImportError:
             logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')
 
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/cuda/moe/default.py
@@ -387,6 +387,8 @@ def __init__(
         try:
             from dlblas.layers.moe.token_dispatcher import DeepEPBuffer, DeepEPMode, use_deepep  # noqa: F401
             get_moe_backend().set_deepep_moe_backend()
+            if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
+                DeepEPBuffer.set_explicitly_destroy()
         except ImportError:
             logger.warning('For higher performance, please install DeepEP https://github.com/deepseek-ai/DeepEP')
 
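Both MoE backends (blocked_fp8.py and default.py) register DeepEP with the same guarded pattern: the import itself is the capability probe, and set_explicitly_destroy() is only called when the installed dlblas exposes it, so explicit buffer teardown is opt-in per library version. A hedged sketch of that probe-and-configure pattern in isolation (module and attribute names follow the diff; try_enable_deepep and the moe_backend parameter are illustrative):

import logging

logger = logging.getLogger('lmdeploy')


def try_enable_deepep(moe_backend) -> bool:
    """Enable the DeepEP dispatcher if dlblas (and DeepEP) are installed."""
    try:
        from dlblas.layers.moe.token_dispatcher import DeepEPBuffer  # noqa: F401
    except ImportError:
        logger.warning('For higher performance, please install DeepEP '
                       'https://github.com/deepseek-ai/DeepEP')
        return False

    moe_backend.set_deepep_moe_backend()
    # Feature-detect newer dlblas: defer buffer teardown to an explicit
    # destroy() call instead of relying on interpreter shutdown.
    if hasattr(DeepEPBuffer, 'set_explicitly_destroy'):
        DeepEPBuffer.set_explicitly_destroy()
    return True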
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/base.py
@@ -74,7 +74,7 @@ def warmup(self):
         """warmup."""
         raise NotImplementedError('Not Implemented.')
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
         raise NotImplementedError('Not Implemented.')

Review comment from the PR author on the async sleep signature: TODO: uni and mp executor may implement
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/executor/base_worker.py
@@ -119,9 +119,9 @@ def warmup(self):
         """warmup."""
         self.model_agent.warmup()
 
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
-        self.model_agent.sleep(level)
+        await self.model_agent.sleep(level)
 
     def wakeup(self, tags: Optional[List[str]] = None):
         """Wakeup."""
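Making Executor.sleep() and the worker's sleep() coroutines means every caller must now await them; the worker simply forwards the await to the model agent, which lets the agent block on its to_sleep event without stalling the event loop. A minimal usage sketch under that assumption (DummyExecutor and the call site are illustrative, not the actual lmdeploy code path):

import asyncio


class DummyExecutor:
    """Illustrative executor whose sleep() is now a coroutine."""

    async def sleep(self, level: int = 1):
        # Real executors forward this to their worker / model agent,
        # e.g. `await self.model_agent.sleep(level)`.
        await asyncio.sleep(0)
        print(f'offloaded weights (level={level})')


async def main():
    executor = DummyExecutor()
    # The call site must await now; calling sleep() without await only
    # creates the coroutine object and never runs the offload.
    await executor.sleep(level=1)


asyncio.run(main())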
36 changes: 30 additions & 6 deletions lmdeploy/pytorch/engine/model_agent/agent.py
@@ -2,7 +2,7 @@
 import asyncio
 import time
 from contextlib import contextmanager
-from dataclasses import dataclass, fields
+from dataclasses import dataclass, field, fields
 from multiprocessing.reduction import ForkingPickler
 from os import getenv
 from typing import Any, Dict, List, Optional
@@ -38,6 +38,13 @@
 logger = get_logger('lmdeploy')
 
 
+@dataclass
+class SleepWakeupState:
+    to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
+    to_wakeup: asyncio.Event = field(default_factory=asyncio.Event)
+    is_sleeping: bool = False
+
+
 @dataclass
 class BatchedLogProbs:
     vals: torch.Tensor
@@ -398,6 +405,8 @@ def __init__(
             self.inputs_strategy,
             self.agent_strategy,
             device=device)
+        # sleep wakeup state
+        self.state: SleepWakeupState = SleepWakeupState()
 
         # decoding inputs
         self.step_inputs = StepInputs()
@@ -565,7 +574,8 @@ async def _prepare_dp_v1(self, inputs: ModelInputs):
 
         # gather dp forward metadata
         batch_size = inputs.seq_length.numel()
-        dp_forward_meta = [int(is_decoding), int(is_dummy), num_tokens]
+        is_sleeping = self.state.is_sleeping
+        dp_forward_meta = [int(is_decoding), int(is_dummy), num_tokens, int(is_sleeping)]
         # check enable_microbatch
         if self.enable_microbatch:
             tokens_num = inputs.input_ids.numel()
@@ -589,8 +599,9 @@ async def _prepare_dp_v1(self, inputs: ModelInputs):
 
         # check if all inputs are dummy inputs
         is_all_dummy = gathered_meta[:, 1].all().item()
+        is_all_sleeping = gathered_meta[:, 3].all().item()
         if is_all_dummy:
-            return None
+            return None, is_all_sleeping
 
         # pad batch size for decoding
         all_num_tokens = gathered_meta[:, 2].tolist()
@@ -607,7 +618,7 @@ async def _prepare_dp_v1(self, inputs: ModelInputs):
         # update dp meta
         inputs.build_dp_meta(all_num_tokens)
         inputs = self.patched_model.update_inputs(inputs)
-        return inputs
+        return inputs, is_all_sleeping
 
     def _get_inputs_from_delta(
         self,
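In _prepare_dp_v1 each data-parallel rank now contributes a fourth field, int(is_sleeping), to its dp_forward_meta; after the gather across ranks, column 3 of gathered_meta is reduced with all(), so is_all_sleeping only becomes true once every rank has requested sleep, and the (inputs, is_all_sleeping) pair is what the forward task inspects. A hedged, single-process sketch of that gather-and-reduce step (the list argument stands in for what the real all-gather across DP ranks would produce):

from typing import List

import torch


def reduce_dp_meta(per_rank_meta: List[List[int]]):
    """per_rank_meta[i] = [is_decoding, is_dummy, num_tokens, is_sleeping]
    for DP rank i, i.e. the rows the real all-gather would have produced."""
    gathered_meta = torch.tensor(per_rank_meta)
    is_all_dummy = bool(gathered_meta[:, 1].all().item())
    is_all_sleeping = bool(gathered_meta[:, 3].all().item())
    all_num_tokens = gathered_meta[:, 2].tolist()
    return is_all_dummy, is_all_sleeping, all_num_tokens


# Rank 1 is idle (dummy) and asks to sleep, rank 0 still has real work:
print(reduce_dp_meta([[1, 0, 128, 0], [1, 1, 0, 1]]))
# -> (False, False, [128, 0]): nobody sleeps until every rank agrees.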
@@ -792,9 +803,15 @@ def __update_inputs(
         is_decoding = inputs.is_decoding
         if dp > 1:
             # update inputs for dp
-            inputs = await self._prepare_dp_v1(inputs)
+            inputs, is_all_sleeping = await self._prepare_dp_v1(inputs)
             # skip dummy forward.
             if inputs is None:
+                if is_all_sleeping:
+                    self.state.to_sleep.set()
+                    await self.state.to_wakeup.wait()
+                    self.state.to_wakeup.clear()
+                    # sync after wakeup
+                    dist.barrier()
                 logger.debug(f'<ForwardTask> rank[{rank}]: all inputs are dummy, skip forward.')
                 await asyncio.sleep(0.01)
                 return
@@ -1176,14 +1193,17 @@ def _construct(item):
         torch.cuda.empty_cache()
 
     @torch.inference_mode()
-    def sleep(self, level: int = 1):
+    async def sleep(self, level: int = 1):
         """Sleep."""
+        self.state.is_sleeping = True
+        await self.state.to_sleep.wait()
         self.cache_engine = None
         self.reset_graph_runner()
         device = 'cpu' if level == 1 else 'meta'
         self.patched_model.get_model().to(device=device, non_blocking=True)
         torch.cuda.synchronize()
         torch.cuda.empty_cache()
+        self.state.to_sleep.clear()
 
     @torch.inference_mode()
     def wakeup(self, tags: Optional[List[str]] = None):
@@ -1202,8 +1222,12 @@ def wakeup(self, tags: Optional[List[str]] = None):
             self.build_model()
             self.build_graph_runner()
             self.misc_config.empty_init = old_empty_init
+
         if 'kv_cache' in tags:
             self.build_cache_engine()
+        # wake up signal
+        self.state.is_sleeping = False
+        self.state.to_wakeup.set()
 
     def release(self):
         """release."""
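Putting the agent-side pieces together: sleep() flips is_sleeping and then blocks on to_sleep, which the forward task only sets once the DP gather shows every rank is both dummy and sleeping; the forward task then parks on to_wakeup, and wakeup() rebuilds the model, clears is_sleeping and sets to_wakeup so the loop resumes (followed by a barrier in the real code). A condensed asyncio sketch of that handshake, with the distributed gather and the actual weight offload replaced by placeholders:

import asyncio
from dataclasses import dataclass, field


@dataclass
class SleepWakeupState:
    to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
    to_wakeup: asyncio.Event = field(default_factory=asyncio.Event)
    is_sleeping: bool = False


async def forward_task(state: SleepWakeupState):
    while True:
        if state.is_sleeping:          # stand-in for the all-ranks gather
            state.to_sleep.set()       # let sleep() proceed with the offload
            await state.to_wakeup.wait()
            state.to_wakeup.clear()    # real code also calls dist.barrier() here
        await asyncio.sleep(0.01)      # stand-in for one forward step


async def sleep(state: SleepWakeupState):
    state.is_sleeping = True
    await state.to_sleep.wait()        # wait until the forward loop is parked
    print('offloading weights / freeing cache')
    state.to_sleep.clear()


def wakeup(state: SleepWakeupState):
    print('rebuilding model and cache engine')
    state.is_sleeping = False
    state.to_wakeup.set()              # unblock the forward loop


async def main():
    state = SleepWakeupState()
    task = asyncio.create_task(forward_task(state))
    await sleep(state)
    wakeup(state)
    await asyncio.sleep(0.05)
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)


asyncio.run(main())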