11 changes: 10 additions & 1 deletion lmdeploy/pytorch/config.py
@@ -304,6 +304,10 @@ class ModelConfig:
# check env for model-device combination
check_env_func: Callable = _default_check_env

# fp32 lm head
enforce_fp32_head: bool = False
tie_word_embeddings: bool = False

def get_head_size(self):
"""Get head size."""
return self.head_dim
@@ -347,12 +351,17 @@ def from_pretrained(
is_draft_model=is_draft_model,
spec_method=spec_method,
)

enforce_fp32_head = False
if hf_overrides is not None:
logger = get_logger('lmdeploy')
logger.warning(f'Overriding HF config with {hf_overrides}')
enforce_fp32_head = hf_overrides.pop('enforce_fp32_head', False)
override_hf_config(model_config.hf_config, hf_overrides)

# for fp32 head
model_config.enforce_fp32_head = enforce_fp32_head
model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)

# for serialization of transformers modules
maybe_register_config_serialize_by_value(trust_remote_code)

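Note: `enforce_fp32_head` is not an attribute of the HF config itself; `ModelConfig.from_pretrained` pops it out of `hf_overrides` before the remaining overrides are applied, then records it (together with `tie_word_embeddings` read from the HF config) on `ModelConfig`. A minimal usage sketch, assuming the PyTorch backend config exposes `hf_overrides`; the engine field name and model path are illustrative, not confirmed by this diff:

```python
# Hedged sketch: requesting an fp32 lm_head through hf_overrides.
# The key is consumed in ModelConfig.from_pretrained (hunk above) and never
# written into the HF config. Field names and model path are assumptions.
from lmdeploy import PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(hf_overrides={'enforce_fp32_head': True})
pipe = pipeline('internlm/internlm3-8b-instruct', backend_config=backend_config)
print(pipe('Hello, world'))
```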
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1045,6 +1045,8 @@ def _build_model(self):
dllm_config=self.misc_config.dllm_config,
strategy_factory=self.strategy_factory,
enable_return_routed_experts=enable_return_routed_experts,
enforce_fp32_head=self.model_config.enforce_fp32_head,
tie_word_embeddings=self.model_config.tie_word_embeddings,
)
patched_model = build_patched_model(self.model_config,
device=device,
2 changes: 2 additions & 0 deletions lmdeploy/pytorch/model_inputs.py
@@ -389,6 +389,8 @@ class BuildModelContext:
dllm_config: DLLMConfig = None
strategy_factory: 'StrategyFactoryBase' = None
enable_return_routed_experts: bool = False
enforce_fp32_head: bool = False
tie_word_embeddings: bool = False


class StepContextManager(CtxMgrBase[StepContext]):
33 changes: 12 additions & 21 deletions lmdeploy/pytorch/models/gpt_oss.py
@@ -10,11 +10,12 @@

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import build_o_proj, build_qkv_proj, build_rowwise_linear
from lmdeploy.pytorch.nn.linear import build_o_proj, build_qkv_proj
from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class GptOssAttention(nn.Module):
@@ -333,11 +334,14 @@ class GptOssModel(nn.Module):

def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
config.pad_token_id,
dtype=dtype,
device=device)

self.embed_tokens = build_embedding(
config.vocab_size,
config.hidden_size,
config.pad_token_id,
dtype=dtype,
device=device,
)

# build all decode layers
self.layers = nn.ModuleList([
@@ -394,7 +398,7 @@ def get_input_embeddings(self):
return self.embed_tokens


class GptOssForCausalLM(nn.Module, CudaGraphMixin):
class GptOssForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
"""ModelForCausalLM."""

packed_modules_mapping = {
@@ -416,11 +420,7 @@ def __init__(self,
# build model
self.model = GptOssModel(config, dtype=dtype, device=device)
# build lm_head
self.lm_head = build_rowwise_linear(config.hidden_size,
config.vocab_size,
bias=False,
dtype=dtype,
device=device)
self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

def forward(
self,
@@ -441,15 +441,6 @@ def forward(
)
return hidden_states

def get_logits(self, hidden_states: torch.Tensor):
"""Compute logits of the model output."""
return self.lm_head(hidden_states)

def update_weights(self):
"""Update weights."""
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

def get_input_embeddings(self):
"""Get input embeddings."""
return self.model.get_input_embeddings()
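The refactor above replaces each model's hand-rolled `nn.Embedding`, `build_rowwise_linear` lm_head, `get_logits`, and `update_weights` with `build_embedding` and `DeployModelMixinV1.build_lm_head` from `.utils.model`. Those helpers are not part of this diff; the sketch below is one plausible shape for them, assuming `build_lm_head` is where `enforce_fp32_head` and `tie_word_embeddings` from `BuildModelContext` take effect. Illustrative only, not the actual implementation:

```python
# Hedged sketch of the helpers the refactored models now call. The real code
# lives in lmdeploy/pytorch/models/utils/model.py and is not shown in this
# diff; names, flags and behavior are assumptions for illustration.
import torch
import torch.nn as nn


def build_embedding(num_embeddings, embedding_dim, padding_idx=None,
                    dtype=None, device=None):
    # thin wrapper so token embeddings are built in one dtype/device-aware place
    return nn.Embedding(num_embeddings, embedding_dim, padding_idx,
                        dtype=dtype, device=device)


class DeployModelMixinV1Sketch:
    """Assumed shape of DeployModelMixinV1 (illustrative only)."""

    # assumed to be populated from BuildModelContext at build time
    enforce_fp32_head: bool = False
    tie_word_embeddings: bool = False

    def build_lm_head(self, hidden_size, vocab_size, bias=False,
                      dtype=None, device=None):
        # promote the output head to fp32 when requested by the build context
        head_dtype = torch.float32 if self.enforce_fp32_head else dtype
        return nn.Linear(hidden_size, vocab_size, bias=bias,
                         dtype=head_dtype, device=device)

    def get_lm_head(self):
        return self.lm_head

    def get_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        lm_head = self.get_lm_head()
        # cast activations to the head dtype so an fp32 head yields fp32 logits
        return lm_head(hidden_states.to(lm_head.weight.dtype))

    def update_weights(self):
        # weight tying handled once here instead of per model
        if self.tie_word_embeddings:
            self.get_lm_head().weight = self.get_input_embeddings().weight
```

With something of this shape in place, the per-model `get_logits`/`update_weights` overrides deleted in this file (and in internlm2, internlm3, and phi3 below) become redundant.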
25 changes: 11 additions & 14 deletions lmdeploy/pytorch/models/internlm2.py
@@ -8,11 +8,11 @@

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
build_rowwise_linear)
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class InternLM2Attention(nn.Module):
@@ -208,12 +208,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.tok_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
self.padding_idx,
dtype=dtype,
device=device)

self.tok_embeddings = build_embedding(
config.vocab_size,
config.hidden_size,
self.padding_idx,
dtype=dtype,
device=device,
)
# build all decode layers
self.layers = nn.ModuleList([
InternLM2DecoderLayer(config, layer_idx, dtype=dtype, device=device)
@@ -269,7 +270,7 @@ def get_input_embeddings(self):
return self.tok_embeddings


class InternLM2ForCausalLM(nn.Module, CudaGraphMixin):
class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
"""Rewrote model of InternLM2ForCausalLM."""

packed_modules_mapping = {
@@ -290,11 +291,7 @@ def __init__(self,
# build Model
self.model = InternLM2Model(config, dtype=dtype, device=device)
# build lm_head
self.output = build_rowwise_linear(config.hidden_size,
config.vocab_size,
bias=False,
dtype=dtype,
device=device)
self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

def forward(
self,
33 changes: 11 additions & 22 deletions lmdeploy/pytorch/models/internlm3.py
@@ -8,11 +8,11 @@

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
build_rowwise_linear)
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1, build_embedding


class InternLM3Attention(nn.Module):
@@ -210,11 +210,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
self.padding_idx,
dtype=dtype,
device=device)
self.embed_tokens = build_embedding(
config.vocab_size,
config.hidden_size,
self.padding_idx,
dtype=dtype,
device=device,
)

# build all decode layers
self.layers = nn.ModuleList([
@@ -271,7 +273,7 @@ def get_input_embeddings(self):
return self.embed_tokens


class InternLM3ForCausalLM(nn.Module, CudaGraphMixin):
class InternLM3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
"""Rewrote model of InternLM3ForCausalLM."""

packed_modules_mapping = {
@@ -297,11 +299,7 @@ def __init__(self,
# build InternLM3Model
self.model = InternLM3Model(config, dtype=dtype, device=device)
# build lm_head
self.lm_head = build_rowwise_linear(config.hidden_size,
config.vocab_size,
bias=False,
dtype=dtype,
device=device)
self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

def forward(
self,
@@ -322,15 +320,6 @@ def forward(
)
return hidden_states

def update_weights(self):
"""Update weights."""
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

def get_logits(self, hidden_states: torch.Tensor):
"""Compute logits of the model output."""
return self.lm_head(hidden_states)

def get_input_embeddings(self):
"""Get input embeddings."""
return self.model.get_input_embeddings()
10 changes: 5 additions & 5 deletions lmdeploy/pytorch/models/internvl.py
@@ -19,7 +19,7 @@

from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model
from .utils.model import DeployModelMixinV1, vlm_model


class Gating(nn.Module):
@@ -444,7 +444,7 @@ def forward(
return last_hidden_state


class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin):
class InternVLChatModel(nn.Module, DeployModelMixinV1, CudaGraphMixin):

def __init__(self,
config: PretrainedConfig,
@@ -801,9 +801,9 @@ def forward(
position_ids=position_ids,
attn_metadata=attn_metadata)

def get_logits(self, hidden_states: torch.Tensor):
"""Compute logits of the model output."""
return self.language_model.get_logits(hidden_states)
def get_lm_head(self):
"""Get lm_head."""
return self.language_model.get_lm_head()

def get_input_embeddings(self):
"""Get input embeddings."""
10 changes: 5 additions & 5 deletions lmdeploy/pytorch/models/internvl3_hf.py
@@ -20,7 +20,7 @@

from .patch import build_model_from_hf_config
from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixin, vlm_model
from .utils.model import DeployModelMixinV1, vlm_model


@torch.compile(dynamic=True)
@@ -439,7 +439,7 @@ def forward(self, image_features):
return hidden_states


class InternVLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
class InternVLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):

def __init__(self,
config: PretrainedConfig,
@@ -485,9 +485,9 @@ def _mark_dynamic_once(self, pixel_values, dims):
torch._dynamo.mark_dynamic(pixel_values, dims)
self.has_compiled_vit = True

def get_logits(self, hidden_states: torch.Tensor):
"""Compute logits of the model output."""
return self.language_model.get_logits(hidden_states)
def get_lm_head(self):
"""Get lm_head."""
return self.language_model.get_lm_head()

def get_input_embeddings(self):
"""Get input embeddings."""
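Both InternVL wrappers drop their forwarding `get_logits` and instead expose the wrapped language model's head through `get_lm_head`, so logits (and any fp32 promotion) are computed in one shared place. A minimal sketch of that delegation, under the same assumptions as above:

```python
# Hedged sketch: a composite VLM reuses its language model's lm_head by
# overriding get_lm_head only; get_logits stays on the shared mixin.
import torch.nn as nn


class VLMWrapperSketch(nn.Module):

    def __init__(self, language_model: nn.Module):
        super().__init__()
        self.language_model = language_model

    def get_lm_head(self):
        # delegate to the wrapped causal LM (e.g. an InternLM2ForCausalLM)
        return self.language_model.get_lm_head()
```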
21 changes: 4 additions & 17 deletions lmdeploy/pytorch/models/phi3.py
@@ -8,12 +8,12 @@

from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul
from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
build_rowwise_linear)
from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
from lmdeploy.pytorch.nn.rotary_embedding import build_rotary_embedding_from_config
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .utils.cudagraph import CudaGraphMixin
from .utils.model import DeployModelMixinV1


class Phi3Attention(nn.Module):
@@ -272,7 +272,7 @@ def get_input_embeddings(self):
return self.embed_tokens


class Phi3ForCausalLM(nn.Module, CudaGraphMixin):
class Phi3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
"""ModelForCausalLM."""

packed_modules_mapping = {
@@ -293,11 +293,7 @@ def __init__(self,
# build model
self.model = Phi3Model(config, dtype=dtype, device=device)
# build lm_head
self.lm_head = build_rowwise_linear(config.hidden_size,
config.vocab_size,
bias=False,
dtype=dtype,
device=device)
self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

def forward(
self,
@@ -318,15 +314,6 @@ def forward(
)
return hidden_states

def get_logits(self, hidden_states: torch.Tensor):
"""Compute logits of the model output."""
return self.lm_head(hidden_states)

def update_weights(self):
"""Update weights."""
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

def get_input_embeddings(self):
"""Get input embeddings."""
return self.model.get_input_embeddings()
11 changes: 3 additions & 8 deletions lmdeploy/pytorch/models/phi3_v.py
@@ -9,11 +9,10 @@
from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
from lmdeploy.pytorch.nn.linear import build_rowwise_linear
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight

from .phi3 import Phi3ForCausalLM, Phi3Model
from .utils.model import DeployModelMixin, vlm_model
from .utils.model import vlm_model

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,
dropout=0.0,
@@ -264,7 +263,7 @@ def forward(
)


class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin):
class Phi3VForCausalLM(Phi3ForCausalLM):

def __init__(self,
config: PretrainedConfig,
@@ -277,11 +276,7 @@ def __init__(self,
# build model
self.model = Phi3VModel(config, dtype=dtype, device=device)
# build lm_head
self.lm_head = build_rowwise_linear(config.hidden_size,
config.vocab_size,
bias=False,
dtype=dtype,
device=device)
self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)

self.input_processor = Phi3VInputProcessor(config, dtype)
