
Update AFMoE architecture to use v5-style MoE impl #44063

Open
AutumnAurelium wants to merge 11 commits into huggingface:main from AutumnAurelium:main

Conversation

@AutumnAurelium

What does this PR do?

This brings the Arcee AFMoE architecture in line with other MoE models' implementation patterns since v5. It also adds integration testing using Trinity Nano.
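For context, the shared v5-style pattern routes each token through a token-choice router before dispatching to experts. The following is a minimal illustrative sketch, not the actual AFMoE code: the hidden size (1024) and the gate's 128 outputs mirror the Trinity Nano module printout later in this thread (assuming those 128 outputs correspond to 128 experts), while `top_k=8` is a placeholder value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TokenChoiceRouter(nn.Module):
    """Illustrative token-choice MoE router: each token independently
    selects its top-k experts from the gate's softmax distribution."""

    def __init__(self, hidden_size: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (num_tokens, hidden_size)
        logits = self.gate(hidden_states)            # (tokens, num_experts)
        weights = F.softmax(logits, dim=-1)
        topk_weights, topk_ids = torch.topk(weights, self.top_k, dim=-1)
        # Renormalize so each token's selected expert weights sum to 1.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights, topk_ids


router = TokenChoiceRouter(hidden_size=1024, num_experts=128, top_k=8)
weights, expert_ids = router(torch.randn(4, 1024))
```

Each returned row holds the mixing weights and indices of the 8 experts chosen for that token; the expert outputs are then combined with those weights.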

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

Collaborator

@ArthurZucker left a comment

Sounds good, thanks for updating!

@AutumnAurelium
Author

@ArthurZucker @Cyrilvallez

Any update on getting this merged? I've fixed the problems mentioned above.

@winglian
Collaborator

winglian commented Mar 3, 2026

run-slow: afmoe

Collaborator

@winglian left a comment

lgtm!

@winglian winglian requested a review from ArthurZucker March 3, 2026 17:13
@winglian
Collaborator

winglian commented Mar 4, 2026

Confirmed that the model trains in axolotl and loads experts as expected:

>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> model_id = "arcee-ai/Trinity-Nano-Base"
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)
>>> model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float16, device_map="auto")
Loading weights: 100%|██████████| 1003/1003 [00:01<00:00, 948.02it/s, Materializing param=model.norm.weight]
>>> messages = [{"role":"user","content":"tell me about the culinary holy trinity"}]
>>> inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
>>> outputs = model.generate(**inputs, max_new_tokens=400, temperature=0.6, top_p=0.95, top_k=20, do_sample=True)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
user
tell me about the culinary holy trinity
assistant
The culinary holy trinity refers to a combination of three key ingredients that are essential in many dishes. The trinity varies depending on the cuisine, but a common trinity includes onions, bell peppers, and celery. This trinity is used in dishes like Cajun and Creole cuisine, where it forms the base for many sauces and stews.
>>> model.model.layers[2]
AfmoeDecoderLayer(
	(self_attn): AfmoeAttention(
		(q_proj): Linear(in_features=1024, out_features=1024, bias=False)
		(k_proj): Linear(in_features=1024, out_features=256, bias=False)
		(v_proj): Linear(in_features=1024, out_features=256, bias=False)
		(o_proj): Linear(in_features=1024, out_features=1024, bias=False)
		(q_norm): AfmoeRMSNorm((128,), eps=1e-05)
		(k_norm): AfmoeRMSNorm((128,), eps=1e-05)
		(gate_proj): Linear(in_features=1024, out_features=1024, bias=False)
		(rotary_fn): Func()
	)
	(input_layernorm): AfmoeRMSNorm((1024,), eps=1e-05)
	(post_attention_layernorm): AfmoeRMSNorm((1024,), eps=1e-05)
	(pre_mlp_layernorm): AfmoeRMSNorm((1024,), eps=1e-05)
	(post_mlp_layernorm): AfmoeRMSNorm((1024,), eps=1e-05)
	(mlp): AfmoeMoE(
		(router): AfmoeTokenChoiceRouter(
			(gate): Linear(in_features=1024, out_features=128, bias=False)
		)
		(shared_experts): AfmoeMLP(
			(gate_proj): Linear(in_features=1024, out_features=256, bias=False)
			(up_proj): Linear(in_features=1024, out_features=256, bias=False)
			(down_proj): Linear(in_features=256, out_features=1024, bias=False)
			(act_fn): SiLUActivation()
		)
		(experts): AfmoeExperts(
			(act_fn): SiLUActivation()
		)
	)
)
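Note that `AfmoeExperts` showing only `act_fn` in the printout above is expected with the v5-style layout: the expert weights live in stacked `nn.Parameter` tensors rather than per-expert `nn.Linear` submodules, so they don't appear in the module repr. A rough sketch of that layout (hypothetical names and toy sizes, not the actual implementation):

```python
import torch
import torch.nn as nn


class FusedExperts(nn.Module):
    """Illustrative fused-experts layout: weights for all experts are
    stacked into single Parameter tensors, so repr() lists only act_fn."""

    def __init__(self, num_experts: int, hidden_size: int, intermediate_size: int):
        super().__init__()
        # (num_experts, hidden, 2 * intermediate): gate and up projections fused.
        self.gate_up_proj = nn.Parameter(
            torch.randn(num_experts, hidden_size, 2 * intermediate_size) * 0.02
        )
        self.down_proj = nn.Parameter(
            torch.randn(num_experts, intermediate_size, hidden_size) * 0.02
        )
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # x: (tokens, hidden); run the given tokens through one expert.
        gate_up = x @ self.gate_up_proj[expert_idx]
        gate, up = gate_up.chunk(2, dim=-1)
        return (self.act_fn(gate) * up) @ self.down_proj[expert_idx]


experts = FusedExperts(num_experts=4, hidden_size=8, intermediate_size=16)
out = experts(torch.randn(3, 8), expert_idx=1)
```

Because `gate_up_proj` and `down_proj` are raw Parameters, `print(experts)` shows only the activation, just as in the Trinity Nano dump.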

Collaborator

@ArthurZucker left a comment

LGTM, just let's leverage modular; in that case, the MoE is standard and can be inherited!

return final_hidden_states


class AfmoeMoE(nn.Module):
Collaborator

Pretty sure you can now inherit this from another class! Can you try? 🤗
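The modular pattern suggested here boils down to plain class inheritance: a `modular_*.py` class subclasses an existing model's standard MoE block and overrides nothing (or only the pieces that differ), and the full modeling file is then generated from it. A dependency-free sketch of the shape of it, where `SparseMoeBlock` is a stand-in for a real block from another model, not an actual transformers class:

```python
class SparseMoeBlock:
    """Stand-in for an existing, standard MoE block from another model."""

    def __init__(self, num_experts: int, top_k: int):
        self.num_experts = num_experts
        self.top_k = top_k


class AfmoeSparseMoeBlock(SparseMoeBlock):
    """Inherits the standard behavior wholesale; only model-specific
    differences (if any) would be overridden here."""
    pass


block = AfmoeSparseMoeBlock(num_experts=128, top_k=8)
```

The benefit is that bug fixes in the shared block propagate to every model that inherits it instead of being re-implemented per model.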

Collaborator

@ArthurZucker Should this also be named AfmoeSparseMoeBlock for consistency?

Collaborator

Yep, perfect.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
