Skip to content
25 changes: 25 additions & 0 deletions src/llmcompressor/modifiers/awq/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,30 @@ class AWQMapping:
),
]

# AFMOE uses dual normalization: pre_mlp_layernorm (not
# post_attention_layernorm) feeds the MLP, and attention has its own
# gate_proj for its gating mechanism.
_afmoe_mappings = [
    # Attention input: input_layernorm is paired with the q/k/v projections
    # and with the attention-side gate_proj.
    AWQMapping(
        "re:.*input_layernorm$",
        [
            "re:.*self_attn.q_proj$",
            "re:.*self_attn.k_proj$",
            "re:.*self_attn.v_proj$",
            "re:.*self_attn.gate_proj$",
        ],
    ),
    # Attention output: v_proj is paired with o_proj.
    AWQMapping(
        "re:.*v_proj$",
        ["re:.*o_proj$"],
    ),
    # MLP input: pre_mlp_layernorm is paired with every gate_proj / up_proj
    # whose module path contains "mlp" (the ".*" allows intermediate path
    # segments between "mlp" and the projection name).
    # NOTE(review): only gate_proj / up_proj are listed here; if any other
    # module also consumes the pre_mlp_layernorm output, confirm that
    # leaving it out of this mapping is intentional.
    AWQMapping(
        "re:.*pre_mlp_layernorm$",
        [
            "re:.*mlp.*gate_proj$",
            "re:.*mlp.*up_proj$",
        ],
    ),
    # MLP output: up_proj is paired with down_proj.
    AWQMapping("re:.*up_proj$", ["re:.*down_proj$"]),
]

# Example mapping for MoE models with parallel transformer blocks, where
# attention and MoE share the same input. This is the only case where
# activation_hook_target is needed. Without it, the hook lands on
Expand All @@ -217,6 +241,7 @@ class AWQMapping:
]

AWQ_MAPPING_REGISTRY: dict[str, list[AWQMapping]] = {
"AfmoeForCausalLM": _afmoe_mappings,
"BloomForCausalLM": _bloom_mappings,
"CohereForCausalLM": _cohere_mappings,
"Cohere2ForCausalLM": _cohere_mappings,
Expand Down
17 changes: 17 additions & 0 deletions src/llmcompressor/modifiers/transform/smoothquant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@
),
]

# SmoothQuant layer maps for AFMoE: one entry for the attention inputs and
# one for the MLP inputs.
AFMOE_SMOOTHQUANT_MAPPINGS: list[LayerMap] = [
    # Attention: input_layernorm is smoothed against q/k/v and the
    # attention-side gate_proj.
    LayerMap(
        smooth_layers="re:.*input_layernorm",
        balance_layers=[
            "re:.*self_attn\\.q_proj",
            "re:.*self_attn\\.k_proj",
            "re:.*self_attn\\.v_proj",
            "re:.*self_attn\\.gate_proj",
        ],
    ),
    # MLP: pre_mlp_layernorm is smoothed against every gate_proj / up_proj
    # whose module path contains "mlp".
    LayerMap(
        smooth_layers="re:.*pre_mlp_layernorm",
        balance_layers=[
            "re:.*mlp.*gate_proj",
            "re:.*mlp.*up_proj",
        ],
    ),
]


# Registry of layer mappings for different architectures
# Add more mappings here
Expand All @@ -85,6 +101,7 @@
"Qwen2ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
"Qwen3ForCausalLM": DEFAULT_SMOOTHQUANT_MAPPINGS,
"WhisperForConditionalGeneration": WHISPER_V2_SMOOTHQUANT_MAPPINGS,
"AfmoeForCausalLM": AFMOE_SMOOTHQUANT_MAPPINGS,
}


Expand Down