1 change: 1 addition & 0 deletions src/transformers/conversion_mapping.py
@@ -41,6 +41,7 @@
"minimax_m2": "mixtral",
# Qwen2-style MoE
"qwen2_moe": "qwen2_moe",
"afmoe": "qwen2_moe",
"deepseek_v2": "qwen2_moe",
"deepseek_v3": "qwen2_moe",
"dots1": "qwen2_moe",
1 change: 1 addition & 0 deletions src/transformers/models/afmoe/configuration_afmoe.py
@@ -84,6 +84,7 @@ class AfmoeConfig(PreTrainedConfig):
num_experts_per_tok: int | None = 6
num_shared_experts: int | None = 2
route_scale: float | None = 1.0
output_router_logits: bool = False
global_attn_every_n_layers: int | None = 4
sliding_window: int | None = 1024
layer_types: list | None = None
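With the new flag in place, router logits can be requested per call or enabled once on the config. A minimal sketch, assuming `AfmoeConfig` is exported at the package top level:

```python
from transformers import AfmoeConfig  # assumed top-level export

# The new flag defaults to False, so existing configs are unaffected.
config = AfmoeConfig(output_router_logits=True)
assert config.output_router_logits
```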
142 changes: 70 additions & 72 deletions src/transformers/models/afmoe/modeling_afmoe.py
@@ -28,16 +28,21 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
from ...integrations import (
use_experts_implementation,
use_kernel_forward_from_hub,
use_kernel_func_from_hub,
use_kernelized_func,
)
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
from ...utils.generic import maybe_autocast, merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ...utils.output_capturing import OutputRecorder, capture_outputs
from .configuration_afmoe import AfmoeConfig


@@ -162,79 +167,58 @@ def forward(self, hidden_states: torch.Tensor, expert_bias: torch.Tensor):
_, _, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

scores = torch.sigmoid(self.gate(hidden_states).to(torch.float32))
router_logits = self.gate(hidden_states).to(torch.float32)
scores = torch.sigmoid(router_logits)

_, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts)
denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
top_scores = top_scores / denominator
top_scores = top_scores * self.route_scale
return top_scores, selected_experts
return router_logits, top_scores, selected_experts
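The routing math above is compact, so a toy walkthrough may help: `expert_bias` only shifts which experts get selected, while the mixing weights come from the unbiased sigmoid scores. A self-contained sketch under assumed sizes (4 tokens, 8 experts, top_k = 2):

```python
import torch

torch.manual_seed(0)
router_logits = torch.randn(4, 8)   # (num_tokens, num_experts), assumed sizes
expert_bias = torch.zeros(8)        # participates in selection only
scores = torch.sigmoid(router_logits)

# Bias changes *which* experts are picked, not the weights they receive.
_, selected_experts = torch.topk(scores + expert_bias, k=2, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts)
top_scores = top_scores / (top_scores.sum(dim=-1, keepdim=True) + 1e-20)
print(top_scores.sum(dim=-1))       # ~1.0 per token, before route_scale
```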


class AfmoeExperts(nn.ModuleList):
"""
Container holding the routed experts.

This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
"""
@use_experts_implementation
class AfmoeExperts(nn.Module):
"""Collection of expert weights stored as 3D tensors."""

def __init__(self, config: AfmoeConfig):
def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_experts
for _ in range(self.num_experts):
self.append(AfmoeMLP(config, intermediate_size=config.moe_intermediate_size))
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.moe_intermediate_size
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
self.act_fn = ACT2FN[config.hidden_act]

def forward(
self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
"""
Args:
hidden_states: (batch, seq, hidden)
selected_experts: (batch, seq, top_k)
routing_weights: (batch, seq, top_k)
"""
batch_size, seq_len, hidden_dim = hidden_states.shape
if seq_len == 0:
return hidden_states.new_zeros(batch_size, 0, hidden_dim)
hidden_states_flat = hidden_states.view(-1, hidden_dim)
top_k = selected_experts.shape[-1]

# Map every token routing decision to a unique position so we can process expert by expert.
token_indices = torch.arange(
hidden_states_flat.shape[0], device=hidden_states.device, dtype=torch.long
).repeat_interleave(top_k)
expert_indices = selected_experts.reshape(-1)
routing_weights = routing_weights.reshape(-1)

sorting = torch.argsort(expert_indices, stable=True)
token_indices = token_indices[sorting]
expert_indices = expert_indices[sorting]
routing_weights = routing_weights[sorting]

dispatched_tokens = hidden_states_flat.index_select(0, token_indices)
expert_outputs = torch.zeros_like(dispatched_tokens)

unique_experts, counts = torch.unique_consecutive(expert_indices, return_counts=True)
start = 0
for expert_id, count in zip(unique_experts.tolist(), counts.tolist()):
if count == 0:
final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
end = start + count
expert_input = dispatched_tokens[start:end]
expert_output = self[expert_id](expert_input)
expert_outputs[start:end] = expert_output
start = end
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

weighted_outputs = (expert_outputs.to(torch.float32) * routing_weights.unsqueeze(-1)).to(hidden_states.dtype)
aggregated = torch.zeros_like(hidden_states_flat)
scatter_indices = token_indices.unsqueeze(-1).expand_as(weighted_outputs)
aggregated.scatter_add_(0, scatter_indices, weighted_outputs)
return aggregated.view(batch_size, seq_len, hidden_dim)
return final_hidden_states
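The new dispatch works off a one-hot routing mask instead of the previous sort-and-segment bookkeeping. A small sketch of the mask mechanics, with assumed shapes (5 tokens, 4 experts, top_k = 2):

```python
import torch

top_k_index = torch.tensor([[0, 1], [1, 2], [0, 3], [2, 3], [1, 3]])  # assumed routing

expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=4)
expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, num_tokens)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
    # token_idx selects the rows sent to this expert; top_k_pos recovers the
    # slot used for the `top_k_weights[token_idx, top_k_pos]` lookup above.
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx[0]])
    print(int(expert_idx[0]), token_idx.tolist())
```

`index_add_` then scatters each expert's weighted output back onto the right token rows, accumulating across the top-k experts a token selected.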


class AfmoeMoE(nn.Module):
class AfmoeSparseMoeBlock(nn.Module):
"""
Mixture of Experts (MoE) module for AFMoE.

@@ -254,14 +238,14 @@ def forward(self, hidden_states):
batch_size, seq_len, hidden_dim = hidden_states.shape
hidden_states_flat = hidden_states.view(-1, hidden_dim)

# Get routing decisions
top_scores, selected_experts = self.router(hidden_states, self.expert_bias)
top_scores = top_scores.view(batch_size, seq_len, self.config.num_experts_per_tok)
selected_experts = selected_experts.view(batch_size, seq_len, self.config.num_experts_per_tok)
# Get routing decisions (returns flattened top-k)
router_logits, top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

# Process through shared experts
shared_output = self.shared_experts(hidden_states_flat).view(batch_size, seq_len, hidden_dim)
routed_output = self.experts(hidden_states, selected_experts, top_scores)
routed_output = self.experts(hidden_states_flat, selected_experts, top_scores).view(
batch_size, seq_len, hidden_dim
)
return shared_output + routed_output
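Note the contract change in this hunk: the experts now consume the flattened `(num_tokens, hidden)` view, and the block reshapes back afterwards. A toy shape check under assumed sizes:

```python
import torch

batch, seq, hidden = 2, 3, 16                      # assumed toy sizes
shared_output = torch.randn(batch, seq, hidden)    # dense path, every token
routed_output = torch.randn(batch * seq, hidden)   # experts see flattened tokens
out = shared_output + routed_output.view(batch, seq, hidden)
assert out.shape == (batch, seq, hidden)
```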


@@ -452,7 +436,7 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
# MoE or dense FFN
self.moe_enabled = layer_idx >= config.num_dense_layers
if self.moe_enabled:
self.mlp = AfmoeMoE(config)
self.mlp = AfmoeSparseMoeBlock(config)
else:
self.mlp = AfmoeMLP(config)

@@ -503,6 +487,7 @@ class AfmoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["AfmoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_can_record_outputs = {
"router_logits": OutputRecorder(AfmoeTokenChoiceRouter, index=0),
"hidden_states": AfmoeDecoderLayer,
"attentions": AfmoeAttention,
}
@@ -519,15 +504,22 @@ class AfmoePreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
_can_compile_fullgraph = (
is_grouped_mm_available()
) # https://huggingface.co/docs/transformers/experts_interface#torchcompile
_supports_attention_backend = True
supports_gradient_checkpointing = True

def _init_weights(self, module):
"""Initialize the weights"""
super()._init_weights(module)
if isinstance(module, AfmoeTokenChoiceRouter):
std = self.config.initializer_range
if isinstance(module, AfmoeExperts):
init.normal_(module.gate_up_proj, mean=0.0, std=std)
init.normal_(module.down_proj, mean=0.0, std=std)
elif isinstance(module, AfmoeTokenChoiceRouter):
init.zeros_(module.gate.weight)
elif isinstance(module, AfmoeMoE):
elif isinstance(module, AfmoeSparseMoeBlock):
init.zeros_(module.expert_bias)
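The zero initializations are deliberate: with a zero-initialized gate, `sigmoid(0) = 0.5` for every expert, and `expert_bias` starts at zero too, so no expert is preferred before training. A one-liner confirming the uniform starting point:

```python
import torch

# Every token scores every expert identically at initialization.
scores = torch.sigmoid(torch.zeros(3, 8))  # (tokens, experts), assumed sizes
assert torch.all(scores == 0.5)
```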


@@ -645,9 +637,10 @@ def forward(
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
output_router_logits: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
) -> MoeCausalLMOutputWithPast:
r"""
Example:

@@ -665,31 +658,36 @@
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
outputs: BaseModelOutputWithPast = self.model(
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)

outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
**kwargs,
)

hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])

loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

return CausalLMOutputWithPast(
return MoeCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
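End to end, the new plumbing routes logits from the gate, through the `OutputRecorder`, into `MoeCausalLMOutputWithPast`. A hedged usage sketch (the checkpoint path is a placeholder, and the per-layer logits shape is assumed from the gate's flattened input):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/afmoe-checkpoint"  # placeholder, not a real checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
outputs = model(**inputs, output_router_logits=True)
# One (num_tokens, num_experts) tensor per MoE layer:
print(len(outputs.router_logits), outputs.router_logits[0].shape)
```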

