diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 828899885d..a8692a0d36 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -304,6 +304,10 @@ class ModelConfig:
     # check env for model-device combination
     check_env_func: Callable = _default_check_env
 
+    # fp32 lm head
+    enforce_fp32_head: bool = False
+    tie_word_embeddings: bool = False
+
     def get_head_size(self):
         """Get head size."""
         return self.head_dim
@@ -347,12 +351,17 @@ def from_pretrained(
             is_draft_model=is_draft_model,
             spec_method=spec_method,
         )
-
+        enforce_fp32_head = False
         if hf_overrides is not None:
             logger = get_logger('lmdeploy')
             logger.warning(f'Overriding HF config with {hf_overrides}')
+            enforce_fp32_head = hf_overrides.pop('enforce_fp32_head', False)
             override_hf_config(model_config.hf_config, hf_overrides)
 
+        # for fp32 head
+        model_config.enforce_fp32_head = enforce_fp32_head
+        model_config.tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
+
         # for serialization of transformers modules
         maybe_register_config_serialize_by_value(trust_remote_code)
diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index f8c8cd226c..6753ce61d2 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -1045,6 +1045,8 @@ def _build_model(self):
             dllm_config=self.misc_config.dllm_config,
             strategy_factory=self.strategy_factory,
             enable_return_routed_experts=enable_return_routed_experts,
+            enforce_fp32_head=self.model_config.enforce_fp32_head,
+            tie_word_embeddings=self.model_config.tie_word_embeddings,
         )
         patched_model = build_patched_model(self.model_config,
                                             device=device,
diff --git a/lmdeploy/pytorch/model_inputs.py b/lmdeploy/pytorch/model_inputs.py
index 669fe97225..ebd51f33a1 100644
--- a/lmdeploy/pytorch/model_inputs.py
+++ b/lmdeploy/pytorch/model_inputs.py
@@ -389,6 +389,8 @@ class BuildModelContext:
     dllm_config: DLLMConfig = None
     strategy_factory: 'StrategyFactoryBase' = None
    enable_return_routed_experts: bool = False
+    enforce_fp32_head: bool = False
+    tie_word_embeddings: bool = False
 
 
 class StepContextManager(CtxMgrBase[StepContext]):
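The three hunks above wire the new flag end to end: `ModelConfig.from_pretrained` pops `enforce_fp32_head` out of `hf_overrides` before the remaining keys are applied (so it never reaches the transformers config), and the model agent forwards both fields into `BuildModelContext` for use at model-build time. A minimal usage sketch, assuming the flag is exposed through `PytorchEngineConfig.hf_overrides`; the model id is a placeholder:

```python
# Hypothetical usage sketch: enable the fp32 lm_head via hf_overrides.
# 'enforce_fp32_head' is consumed by ModelConfig.from_pretrained above;
# any remaining hf_overrides keys are applied to the HF config as before.
from lmdeploy import pipeline, PytorchEngineConfig

pipe = pipeline(
    'Qwen/Qwen3-8B',  # placeholder model id
    backend_config=PytorchEngineConfig(hf_overrides={'enforce_fp32_head': True}),
)
print(pipe('Hello, world!'))
```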
diff --git a/lmdeploy/pytorch/models/gpt_oss.py b/lmdeploy/pytorch/models/gpt_oss.py
index 2a6329ce70..726139967c 100644
--- a/lmdeploy/pytorch/models/gpt_oss.py
+++ b/lmdeploy/pytorch/models/gpt_oss.py
@@ -10,11 +10,12 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import build_o_proj, build_qkv_proj, build_rowwise_linear
+from lmdeploy.pytorch.nn.linear import build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.nn.moe import build_fused_moe
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class GptOssAttention(nn.Module):
@@ -333,11 +334,14 @@ class GptOssModel(nn.Module):
 
     def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None):
         super().__init__()
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         config.pad_token_id,
-                                         dtype=dtype,
-                                         device=device)
+
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            config.pad_token_id,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -394,7 +398,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class GptOssForCausalLM(nn.Module, CudaGraphMixin):
+class GptOssForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -416,11 +420,7 @@ def __init__(self,
         # build model
         self.model = GptOssModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -441,15 +441,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index ffe693f9e7..a00943a1ec 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -8,11 +8,11 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class InternLM2Attention(nn.Module):
@@ -208,12 +208,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.tok_embeddings = nn.Embedding(config.vocab_size,
-                                           config.hidden_size,
-                                           self.padding_idx,
-                                           dtype=dtype,
-                                           device=device)
-
+        self.tok_embeddings = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
         # build all decode layers
         self.layers = nn.ModuleList([
             InternLM2DecoderLayer(config, layer_idx, dtype=dtype, device=device)
@@ -269,7 +270,7 @@ def get_input_embeddings(self):
         return self.tok_embeddings
 
 
-class InternLM2ForCausalLM(nn.Module, CudaGraphMixin):
+class InternLM2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """Rewrote model of InternLM2ForCausalLM."""
 
     packed_modules_mapping = {
@@ -290,11 +291,7 @@ def __init__(self,
         # build Model
         self.model = InternLM2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.output = build_rowwise_linear(config.hidden_size,
-                                           config.vocab_size,
-                                           bias=False,
-                                           dtype=dtype,
-                                           device=device)
+        self.output = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
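The per-model conversions above (and the ones that follow) all delete the same `get_logits`/`update_weights` boilerplate, since `DeployModelMixinV1` derives both from `get_lm_head()` and `get_input_embeddings()`. A plain-PyTorch toy sketch (no lmdeploy imports; names illustrative) of why `build_embedding` must also follow the head's dtype once weights are tied:

```python
# Toy sketch: update_weights() aliases lm_head.weight to the embedding weight
# under tie_word_embeddings, so a single fp32 tensor serves both modules --
# an fp32 head therefore forces an fp32 embedding table.
import torch
import torch.nn as nn

embed = nn.Embedding(100, 16, dtype=torch.float32)             # fp32 table
lm_head = nn.Linear(16, 100, bias=False, dtype=torch.float32)  # fp32 head
lm_head.weight = embed.weight   # what DeployModelMixinV1.update_weights() does
assert lm_head.weight.data_ptr() == embed.weight.data_ptr()    # one tensor, two roles
```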
diff --git a/lmdeploy/pytorch/models/internlm3.py b/lmdeploy/pytorch/models/internlm3.py
index d3bbc6830b..3005232c06 100644
--- a/lmdeploy/pytorch/models/internlm3.py
+++ b/lmdeploy/pytorch/models/internlm3.py
@@ -8,11 +8,11 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class InternLM3Attention(nn.Module):
@@ -210,11 +210,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -271,7 +273,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class InternLM3ForCausalLM(nn.Module, CudaGraphMixin):
+class InternLM3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """Rewrote model of InternLM3ForCausalLM."""
 
     packed_modules_mapping = {
@@ -297,11 +299,7 @@ def __init__(self,
         # build InternLM3Model
         self.model = InternLM3Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -322,15 +320,6 @@ def forward(
         )
         return hidden_states
 
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 2dbd9f9f3e..233ebabb99 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -19,7 +19,7 @@
 from .patch import build_model_from_hf_config
 from .utils.cudagraph import CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model
 
 
 class Gating(nn.Module):
@@ -444,7 +444,7 @@ def forward(
         return last_hidden_state
 
 
-class InternVLChatModel(nn.Module, DeployModelMixin, CudaGraphMixin):
+class InternVLChatModel(nn.Module, DeployModelMixinV1, CudaGraphMixin):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -801,9 +801,9 @@ def forward(
                                   position_ids=position_ids,
                                   attn_metadata=attn_metadata)
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.language_model.get_logits(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.language_model.get_lm_head()
 
     def get_input_embeddings(self):
         """Get input embeddings."""
diff --git a/lmdeploy/pytorch/models/internvl3_hf.py b/lmdeploy/pytorch/models/internvl3_hf.py
index 6e760dbeac..0a151d8ed3 100644
--- a/lmdeploy/pytorch/models/internvl3_hf.py
+++ b/lmdeploy/pytorch/models/internvl3_hf.py
@@ -20,7 +20,7 @@
 from .patch import build_model_from_hf_config
 from .utils.cudagraph import CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model
 
 
 @torch.compile(dynamic=True)
@@ -439,7 +439,7 @@ def forward(self, image_features):
         return hidden_states
 
 
-class InternVLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class InternVLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -485,9 +485,9 @@ def _mark_dynamic_once(self, pixel_values, dims):
             torch._dynamo.mark_dynamic(pixel_values, dims)
             self.has_compiled_vit = True
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.language_model.get_logits(hidden_states)
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.language_model.get_lm_head()
 
     def get_input_embeddings(self):
         """Get input embeddings."""
diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py
index 24dd504522..c49b01737d 100644
--- a/lmdeploy/pytorch/models/phi3.py
+++ b/lmdeploy/pytorch/models/phi3.py
@@ -8,12 +8,12 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.nn.rotary_embedding import build_rotary_embedding_from_config
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1
 
 
 class Phi3Attention(nn.Module):
@@ -272,7 +272,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Phi3ForCausalLM(nn.Module, CudaGraphMixin):
+class Phi3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -293,11 +293,7 @@ def __init__(self,
         # build model
         self.model = Phi3Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -318,15 +314,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/phi3_v.py b/lmdeploy/pytorch/models/phi3_v.py
index a70d0bfecc..c6804d5586 100644
--- a/lmdeploy/pytorch/models/phi3_v.py
+++ b/lmdeploy/pytorch/models/phi3_v.py
@@ -9,11 +9,10 @@
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor
-from lmdeploy.pytorch.nn.linear import build_rowwise_linear
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .phi3 import Phi3ForCausalLM, Phi3Model
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import vlm_model
 
 CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(attention_dropout=0.0,
                                                      dropout=0.0,
@@ -264,7 +263,7 @@ def forward(
     )
 
 
-class Phi3VForCausalLM(Phi3ForCausalLM, DeployModelMixin):
+class Phi3VForCausalLM(Phi3ForCausalLM):
 
     def __init__(self,
                  config: PretrainedConfig,
@@ -277,11 +276,7 @@ def __init__(self,
         # build model
         self.model = Phi3VModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
         self.input_processor = Phi3VInputProcessor(config, dtype)
diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py
index c13c805f88..ddd5b4dec9 100644
--- a/lmdeploy/pytorch/models/qwen2.py
+++ b/lmdeploy/pytorch/models/qwen2.py
@@ -8,11 +8,11 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class Qwen2Attention(nn.Module):
@@ -208,11 +208,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -269,7 +271,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Qwen2ForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen2ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -295,11 +297,7 @@ def __init__(self,
         # build model
         self.model = Qwen2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -320,15 +318,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen2_5_vl.py b/lmdeploy/pytorch/models/qwen2_5_vl.py
index 11d2948b7a..b934d4142f 100644
--- a/lmdeploy/pytorch/models/qwen2_5_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_5_vl.py
@@ -18,7 +18,7 @@
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model
 
 
 class Qwen2_5_PatchEmbed(nn.Module):
@@ -366,7 +366,7 @@ def forward(self,
         return hidden_states
 
 
-class Qwen2_5_VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class Qwen2_5_VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -402,11 +402,7 @@ def __init__(self,
         # build model
         self.model = Qwen2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -447,15 +443,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py
index b62e5df4bf..c86608fc12 100644
--- a/lmdeploy/pytorch/models/qwen2_moe.py
+++ b/lmdeploy/pytorch/models/qwen2_moe.py
@@ -16,6 +16,7 @@
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class Qwen2MoeAttention(nn.Module):
@@ -309,11 +310,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -370,7 +373,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Qwen2MoeForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen2MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -396,11 +399,7 @@ def __init__(self,
         # build model
         self.model = Qwen2MoeModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -421,10 +420,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py
index 77e025c638..2a473a893c 100644
--- a/lmdeploy/pytorch/models/qwen2_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_vl.py
@@ -16,7 +16,7 @@
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, build_embedding, vlm_model
 
 
 def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int],
@@ -235,11 +235,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.vocab_size = config.vocab_size
         self.mrope_section = config.rope_scaling['mrope_section']
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -592,7 +594,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
         return self.merger(hidden_states)
 
 
-class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class Qwen2VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -628,11 +630,7 @@ def __init__(self,
         # build model
         self.model = Qwen2Model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -668,15 +666,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3.py b/lmdeploy/pytorch/models/qwen3.py
index 47121af1fe..4fcbaefeba 100644
--- a/lmdeploy/pytorch/models/qwen3.py
+++ b/lmdeploy/pytorch/models/qwen3.py
@@ -8,11 +8,11 @@
 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
 from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, RMSNorm, SiluAndMul, build_rotary_embedding_from_config
-from lmdeploy.pytorch.nn.linear import (build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj,
-                                        build_rowwise_linear)
+from lmdeploy.pytorch.nn.linear import build_down_linear, build_gateup_linear, build_o_proj, build_qkv_proj
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
 
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class Qwen3Attention(nn.Module):
@@ -217,11 +217,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         self.layers = nn.ModuleList([
@@ -278,7 +280,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Qwen3ForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen3ForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -304,11 +306,7 @@ def __init__(self,
        # build model
         self.model = Qwen3model(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -329,15 +327,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py
index 9d50cbb86b..b54fb24c91 100644
--- a/lmdeploy/pytorch/models/qwen3_moe.py
+++ b/lmdeploy/pytorch/models/qwen3_moe.py
@@ -16,6 +16,7 @@
 
 from .patch import get_build_model_context
 from .utils.cudagraph import CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class Qwen3MoeAttention(nn.Module):
@@ -316,12 +317,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         super().__init__()
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         if get_dist_manager().current_context().dist_config.enable_eplb:
             ep_size, _ = get_ep_world_rank()
@@ -388,7 +390,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Qwen3MoeForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen3MoeForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -415,11 +417,7 @@ def __init__(self,
         # build model
         self.model = Qwen3MoeModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
         # for router replay
         bm_ctx = get_build_model_context()
         self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts
@@ -457,10 +455,6 @@ def forward(
             return hidden_states
         return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3_next.py b/lmdeploy/pytorch/models/qwen3_next.py
index 49420bcb4e..62f3fb6472 100644
--- a/lmdeploy/pytorch/models/qwen3_next.py
+++ b/lmdeploy/pytorch/models/qwen3_next.py
@@ -17,6 +17,7 @@
 from lmdeploy.pytorch.weight_loader.model_weight_loader import default_weight_loader, load_weight
 
 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
+from .utils.model import DeployModelMixinV1, build_embedding
 
 
 class GatedDeltaMeta:
@@ -811,12 +812,13 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.config = config
         self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
-
-        self.embed_tokens = nn.Embedding(config.vocab_size,
-                                         config.hidden_size,
-                                         self.padding_idx,
-                                         dtype=dtype,
-                                         device=device)
+        self.embed_tokens = build_embedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            dtype=dtype,
+            device=device,
+        )
 
         # build all decode layers
         # TODO: use full config.num_hidden_layers
@@ -879,7 +881,7 @@ def get_input_embeddings(self):
         return self.embed_tokens
 
 
-class Qwen3NextForCausalLM(nn.Module, CudaGraphMixin):
+class Qwen3NextForCausalLM(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -905,11 +907,7 @@ def __init__(self,
         # build model
         self.model = Qwen3NextModel(config, dtype=dtype, device=device)
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.hidden_size, config.vocab_size, bias=False, dtype=dtype, device=device)
 
     def forward(
         self,
@@ -932,10 +930,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py
index 60c3617ffe..610001f7d7 100644
--- a/lmdeploy/pytorch/models/qwen3_vl.py
+++ b/lmdeploy/pytorch/models/qwen3_vl.py
@@ -20,7 +20,7 @@
 from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention
 from .qwen3 import Qwen3model
 from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin
-from .utils.model import DeployModelMixin, vlm_model
+from .utils.model import DeployModelMixinV1, vlm_model
 
 
 class Qwen3VLTextRotaryEmbedding(nn.Module):
@@ -447,7 +447,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_
         return hidden_states, deepstack_feature_lists
 
 
-class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin):
+class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixinV1, CudaGraphMixin):
     """ModelForCausalLM."""
 
     packed_modules_mapping = {
@@ -485,11 +485,11 @@ def __init__(self,
         self.language_model = Qwen3VLTextModel(config.text_config, dtype=dtype, device=device)
 
         # build lm_head
-        self.lm_head = build_rowwise_linear(config.text_config.hidden_size,
-                                            config.text_config.vocab_size,
-                                            bias=False,
-                                            dtype=dtype,
-                                            device=device)
+        self.lm_head = self.build_lm_head(config.text_config.hidden_size,
+                                          config.text_config.vocab_size,
+                                          bias=False,
+                                          dtype=dtype,
+                                          device=device)
 
     def forward(
         self,
@@ -549,15 +549,6 @@ def forward(
         )
         return hidden_states
 
-    def get_logits(self, hidden_states: torch.Tensor):
-        """Compute logits of the model output."""
-        return self.lm_head(hidden_states)
-
-    def update_weights(self):
-        """Update weights."""
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.language_model.embed_tokens.weight
-
     def get_input_embeddings(self):
         """Get input embeddings."""
         return self.language_model.get_input_embeddings()
diff --git a/lmdeploy/pytorch/models/utils/model.py b/lmdeploy/pytorch/models/utils/model.py
index 0b7938db49..ce4bcbd9c5 100644
--- a/lmdeploy/pytorch/models/utils/model.py
+++ b/lmdeploy/pytorch/models/utils/model.py
@@ -6,6 +6,9 @@
 from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor
 from lmdeploy.pytorch.model_inputs import StepContext
+from lmdeploy.pytorch.models.patch import get_build_model_context
+from lmdeploy.pytorch.nn.embedding import ParallelEmbedding
+from lmdeploy.pytorch.nn.linear import build_rowwise_linear
 
 
 class DeployModelMixin:
@@ -51,13 +54,56 @@ def get_input_processor(self) -> BaseModelInputProcessor:
         return None
 
 
+class DeployModelMixinV1(DeployModelMixin):
+
+    def get_logits(self, hidden_states: torch.Tensor):
+        """Compute logits of the model output."""
+        head_dtype = self.get_lm_head().weight.dtype
+        if hidden_states.dtype != head_dtype:
+            hidden_states = hidden_states.to(dtype=head_dtype)
+        hidden_states = self.get_lm_head()(hidden_states)
+        return hidden_states
+
+    def get_lm_head(self):
+        """Get lm_head."""
+        return self.lm_head
+
+    def get_input_embeddings(self):
+        """Get embeds."""
+        raise NotImplementedError('Not Implemented')
+
+    def update_weights(self):
+        """Update weights."""
+        if getattr(self.config, 'tie_word_embeddings', False):
+            self.get_lm_head().weight = self.get_input_embeddings().weight
+
+    def build_lm_head(self,
+                      hidden_size: int,
+                      vocab_size: int,
+                      bias: bool = False,
+                      dtype: Optional[torch.dtype] = None,
+                      device: Optional[torch.device] = None,
+                      **kwargs):
+        """Build LM Head."""
+        bm_ctx = get_build_model_context()
+        head_dtype = torch.float32 if bm_ctx.enforce_fp32_head else dtype
+        lm_head = build_rowwise_linear(
+            hidden_size,
+            vocab_size,
+            bias,
+            dtype=head_dtype,
+            device=device,
+            **kwargs,
+        )
+        return lm_head
+
+
 def vlm_model(vlm_cls):
     if not issubclass(vlm_cls, torch.nn.Module):
         raise ValueError('Only subclasses of nn.Module can be decorated with @vlm_model.')
 
     @functools.wraps(vlm_cls)
     def wrapper(*args, **kwargs):
-        from lmdeploy.pytorch.models.patch import get_build_model_context
         bm_ctx = get_build_model_context()
         disable_vision_encoder = bm_ctx.disable_vision_encoder
         if disable_vision_encoder:
@@ -68,3 +114,30 @@ def wrapper(*args, **kwargs):
             return vlm_cls(*args, **kwargs)
 
     return wrapper
+
+
+def build_embedding(vocab_size: int,
+                    hidden_size: int,
+                    padding_idx: int,
+                    dtype: torch.dtype = None,
+                    device: torch.device = None,
+                    is_tp: bool = False,
+                    **kwargs):
+    """Build embedding."""
+    bm_ctx = get_build_model_context()
+
+    # run with fp32 only when share weights with lm_head
+    force_dtype = None
+    if bm_ctx.enforce_fp32_head and bm_ctx.tie_word_embeddings:
+        force_dtype = torch.float32
+
+    return ParallelEmbedding(
+        vocab_size,
+        hidden_size,
+        padding_idx,
+        dtype=dtype,
+        device=device,
+        is_tp=is_tp,
+        force_dtype=force_dtype,
+        **kwargs,
+    )
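`DeployModelMixinV1.get_logits` is where the fp32 head actually pays off: hidden states arrive in the model's compute dtype and are cast up to the head's dtype before the matmul. A self-contained, plain-PyTorch sketch of that behavior with toy sizes (no lmdeploy imports):

```python
# Sketch of get_logits under enforce_fp32_head: bf16 hidden states are cast
# to the head's dtype, so the logits matmul and its output are fp32.
import torch

lm_head = torch.nn.Linear(8, 32, bias=False, dtype=torch.float32)  # fp32 head
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)            # bf16 activations

if hidden_states.dtype != lm_head.weight.dtype:
    hidden_states = hidden_states.to(dtype=lm_head.weight.dtype)   # bf16 -> fp32
logits = lm_head(hidden_states)
assert logits.dtype == torch.float32
```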
diff --git a/lmdeploy/pytorch/nn/embedding.py b/lmdeploy/pytorch/nn/embedding.py
index cd9f5b8086..d9a5c17814 100644
--- a/lmdeploy/pytorch/nn/embedding.py
+++ b/lmdeploy/pytorch/nn/embedding.py
@@ -16,15 +16,18 @@ def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) ->
 
 class ParallelEmbedding(nn.Module):
 
-    def __init__(self,
-                 vocab_size: int,
-                 hidden_size: int,
-                 padding_idx: int,
-                 dtype: torch.dtype = None,
-                 device: torch.device = None,
-                 is_tp: bool = False,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 layer_type: str = 'attn'):
+    def __init__(
+        self,
+        vocab_size: int,
+        hidden_size: int,
+        padding_idx: int,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        is_tp: bool = False,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        layer_type: str = 'attn',
+        force_dtype: torch.dtype = None,
+    ):
         self.dist_ctx = get_dist_manager().current_context()
         super().__init__()
 
@@ -52,9 +55,11 @@ def __init__(self,
         else:
             self.vocab_size_padded = self.vocab_size
 
+        self.out_dtype = dtype
         self.start_index = self.rank * self.vocab_size_padded
         self.end_index = (self.rank + 1) * self.vocab_size_padded
-        self.register_parameter('weight', self.create_weight(self.vocab_size_padded, hidden_size, dtype, device))
+        weight_dtype = force_dtype or dtype
+        self.register_parameter('weight', self.create_weight(self.vocab_size_padded, hidden_size, weight_dtype, device))
         self.weight.weight_loader = self.weight_loader
 
         backend = get_backend()
@@ -98,4 +103,7 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor):
             self.weight[self.padding_idx - self.start_index] = 0
 
     def forward(self, x: torch.Tensor):
-        return self.impl.forward(x, self.weight, all_reduce=self.all_reduce, group=self.tp_group)
+        embeddings = self.impl.forward(x, self.weight, all_reduce=self.all_reduce, group=self.tp_group)
+        if self.out_dtype is not None and embeddings.dtype != self.out_dtype:
+            embeddings = embeddings.to(dtype=self.out_dtype)
+        return embeddings
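The `ParallelEmbedding` change completes the loop: the weight table can be stored in `force_dtype` (fp32 when the head is fp32 and weights are tied), while `out_dtype` guarantees downstream layers still see the model's compute dtype. A toy plain-PyTorch sketch of that round trip (no lmdeploy imports; values illustrative):

```python
# Sketch of ParallelEmbedding.forward with force_dtype=torch.float32 and a
# bf16 model dtype: the gather reads the fp32 table (shared with the fp32
# lm_head under tie_word_embeddings), then the result is cast back to bf16.
import torch
import torch.nn.functional as F

weight = torch.randn(100, 16, dtype=torch.float32)  # fp32 table (weight_dtype)
out_dtype = torch.bfloat16                          # self.out_dtype (model dtype)

embeddings = F.embedding(torch.tensor([1, 5, 7]), weight)  # fp32 lookup
if out_dtype is not None and embeddings.dtype != out_dtype:
    embeddings = embeddings.to(dtype=out_dtype)             # back to compute dtype
assert embeddings.dtype == torch.bfloat16
```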