[moe] brings batch/sequence-wise load balance loss #2061
rakkit wants to merge 5 commits into pytorch:main
Conversation
…d seq-wise aux loss for load balance
torchtitan/train.py
Outdated
```python
    job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
```
We can add a condition here to wrap the loss or not for MoE. For now, all models in torchtitan only return a single output, so it's OK for now.
If we subsume this MoE loss wrapper into `build_loss_fn`, we can avoid adding the logic here.
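A hedged sketch of that suggestion: `build_loss_fn` could return a wrapper that accepts either a plain prediction or a `(pred, aux_loss)` tuple, so `train.py` needs no MoE-specific branching. The names below (`wrap_with_aux_loss`, the `is_moe` flag) are illustrative, not torchtitan's actual API.

```python
import functools

import torch
import torch.nn.functional as F


def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Standard next-token CE over (batch, seq, vocab) logits.
    return F.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))


def wrap_with_aux_loss(loss_fn):
    """Accept either `pred` or `(pred, aux_loss)` from the model's forward."""
    @functools.wraps(loss_fn)
    def wrapped(output, labels):
        if isinstance(output, tuple):
            pred, aux_loss = output
            return loss_fn(pred, labels) + aux_loss
        return loss_fn(output, labels)
    return wrapped


def build_loss_fn(is_moe: bool):
    # Hypothetical: the real build_loss_fn takes a job config; a bool keeps the sketch small.
    return wrap_with_aux_loss(cross_entropy_loss) if is_moe else cross_entropy_loss
```

Dense models keep the unwrapped loss function, so their call sites are unchanged.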
torchtitan/config/job_config.py
Outdated
```python
@dataclass
class ExtraLosses:
```
This section is specifically for the MoE load-balancing loss for now. Do you foresee any other loss-related params being used in this section? If not, let's make the name more descriptive and specific.
Follow-up here: should we merge these configs into the Model dataclass?
Thanks a lot for the feedback, @wwwjn @shuhuayu (sorry for the late update)! Summary of new changes:
Also be aware that PP with the aux loss still does not work.
torchtitan/models/moe/moe.py
Outdated
```python
        self.load_balance_loss_weight,
    )
else:
    load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)
```
As far as I can see, `out` is not defined in this scope yet.
```python
@staticmethod
def sequence_wise_aux_loss(
    scores: torch.Tensor,
    indices: torch.Tensor,
```
This will use the biased `topk(scores + expert_bias)` instead of the unbiased `topk(scores)` from DSv3 eq. 18.
Nope, that's `top_scores`.
Ah yeah, `scores` is the raw sigmoid output. But isn't `indices` (= `selected_experts_indices`) derived as `topk(scores + expert_bias)`?
Hmm, good question. Need to think about this.
I think you might be right; in eq. 18 the topk doesn't have the bias.
Thanks. I fixed this and reran the two aux-loss types and the no-aux-loss setup in the PR description.
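For illustration, the distinction under discussion can be made concrete. This is a sketch with assumed shapes and names, not the PR's code: routing may dispatch tokens using the biased scores, while the aux-loss statistics in DSv3 eq. 18 come from the unbiased `topk(scores)`.

```python
import torch

torch.manual_seed(0)
num_experts, top_k = 8, 2
# Hypothetical router outputs for 4 tokens: raw sigmoid scores.
scores = torch.rand(4, num_experts)
# Load-balancing bias from the aux-loss-free strategy (illustrative values).
expert_bias = torch.randn(num_experts) * 0.5

# Routing decision: biased topk, used to dispatch tokens to experts.
routed_indices = torch.topk(scores + expert_bias, top_k, dim=-1).indices

# Aux-loss statistics: unbiased topk over the raw scores, per eq. 18.
aux_indices = torch.topk(scores, top_k, dim=-1).indices
```

The two index sets can disagree whenever the bias flips the ranking near the top-k boundary, which is exactly the discrepancy the fix addresses.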
tianyu-l
left a comment
Thanks for the proposal and sorry for the late review.
It looks like this feature requires intrusive changes and is worth some discussion on how to best support it.
```python
loss = loss_fn(pred, labels)
# Add auxiliary loss to the computation graph for gradients in the backward pass,
# but cancel out its numeric value so the forward pass only logs language model task loss.
loss = loss + (load_balance_loss - load_balance_loss.detach())
```
This code looks too hacky. Curious why we don't want to log the full loss?
If one needs to log `load_balance_loss`, a helpful way is to log it via MoE's optimizer pre-step hook (where we log everything about MoE, e.g. bias, expert usage, entropy, LB loss, etc.).
That way we don't need to hack the return of "loss" (which is a single value/tensor for dense/MoE/diffusion model training).
And for people who want to run an ablation study, the "loss" stays the "clean" CE loss for training/validation.
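The numeric effect of the `detach` trick quoted above can be checked in isolation. This is a minimal stand-in, not the PR's code: the forward value equals the task loss alone, while the gradient still includes the auxiliary term.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
task_loss = x * 3.0  # stand-in for the CE loss
aux_loss = x * x     # stand-in for the load-balance loss

# (aux_loss - aux_loss.detach()) is numerically zero in forward,
# but keeps aux_loss in the autograd graph for backward.
total = task_loss + (aux_loss - aux_loss.detach())
assert total.item() == task_loss.item()  # logged value is the task loss only

total.backward()
# Gradient includes d(task)/dx = 3 plus d(aux)/dx = 2x = 4.
assert x.grad.item() == 7.0
```

This is why the forward log shows only the CE loss even though the load-balance loss still shapes the update direction.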
```python
@dataclass
class ExtraLosses:
```
This is part of the "model definition", so it shouldn't be configurable from run to run (given the current limitation of the torchtitan config system). These should be embedded in ModelArgs, or more specifically MoEArgs.
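A possible shape of that suggestion, folding the loss knobs into the model-definition dataclass rather than a run-to-run training config. All field names and defaults here are hypothetical:

```python
from dataclasses import dataclass


@dataclass
class MoEArgs:
    # Existing model-definition fields (illustrative values).
    num_experts: int = 8
    top_k: int = 2
    # Aux-loss settings live with the model definition, per the comment above.
    load_balance_loss_weight: float = 0.0  # 0.0 disables the aux loss
    aux_loss_type: str = "sequence_wise"   # or "batch_wise"
```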
```python
    return loss_per_seq.mean()


@staticmethod
def batch_wise_aux_loss(
```
Do we need the batch one? For DSv3 it seems unnecessary. For Qwen3 it seems insufficient: Qwen3 adopts global-batch load balancing, but here it looks local / per-microbatch (see e.g. https://qwenlm.github.io/blog/global-load-balance/).
For simplicity, let's start with seq_wise?
Sure; ideally we can make the type of aux loss configurable.
I am curious how people deal with the special case of document segmentation (for models trained with document-mask attention, it is literally not "one" sequence on the attention side).
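For reference, a DSv3-style sequence-wise balance loss can be sketched as follows. Tensor shapes, normalization, and the function signature are assumptions for illustration, not the PR's exact implementation: `f` is the scaled per-sequence fraction of tokens routed to each expert, and `p` is the mean normalized affinity.

```python
import torch
import torch.nn.functional as F


def sequence_wise_aux_loss(
    scores: torch.Tensor,   # (batch, seq, num_experts) raw router scores
    indices: torch.Tensor,  # (batch, seq, top_k) selected expert indices
    num_experts: int,
    top_k: int,
    weight: float,
) -> torch.Tensor:
    _, seqlen, _ = scores.shape
    # f_i: scaled fraction of routing slots assigned to expert i, per sequence.
    one_hot = F.one_hot(indices, num_experts).float()
    f = one_hot.sum(dim=(1, 2)) * num_experts / (top_k * seqlen)  # (batch, E)
    # P_i: mean normalized affinity for expert i, per sequence.
    p = (scores / scores.sum(dim=-1, keepdim=True)).mean(dim=1)   # (batch, E)
    # Sum over experts, average over sequences in the batch.
    return weight * (f * p).sum(dim=-1).mean()
```

With perfectly uniform routing and uniform scores, `f_i = 1` and `p_i = 1/E`, so the loss reduces to exactly `weight`, its minimum under the usual balance argument.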
```python
)


@staticmethod
def sequence_wise_aux_loss(
```
IIUC this returns a loss independent of DP, similar to the cross-entropy loss, in that each DP rank computes its own aux loss and does backward, and the gradients eventually get reduced across DP ranks by FSDP.
In general, is there a way to verify the correctness of the implementation?
```python
h, accumulated_load_balance_loss = layer(
    h, self.freqs_cis, accumulated_load_balance_loss, attention_masks
)
```
Hmm, passing this per-layer loss along all the way to the final output sounds unnecessary. It sounds correct, but it causes quite intrusive changes to the entire model code.
Putting PP aside, is it true that we can also achieve this via a buffer in each MoE module, similar to the expert bias? Specifically, putting the per-layer loss in a buffer, and in the loss function fetching the value and adding it to the final loss.
Is this similar to the idea in #1979 when you say
> Caching the per-layer aux_loss (which breaks compile, but not PP)
In what way does it break compile? cc @bdhirsh @xmfan
With PP, I'm not sure what's the best way to do it. @H-Huang any suggestions?
I'd be interested in what about caching the auxiliary loss breaks compile as well.
This is probably not representative, but I have a basic working example of using torch.compile to compile each layer of a transformer model, where I also compute and cache an extra loss onto each layer: https://gist.github.com/bdhirsh/2b59611d3070354af3f6364d9becaa08
Thanks, both.
> In what way does it break compile?
If we register a buffer named "aux_loss", use it to save the aux-loss value, and then access it and sum it into the final loss, the compiler breaks.
> This is probably not representative, but I have a basic working example of using torch.compile to compile each layer of a transformer model, where I also compute and cache an extra loss onto each layer: https://gist.github.com/bdhirsh/2b59611d3070354af3f6364d9becaa08
Here self.aux_loss is not pre-defined as a buffer. This seems to be a good solution. Let me run a few more tests to check.
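The approach from the linked gist can be reduced to a small self-contained sketch: each layer stashes its aux loss on a plain `self.aux_loss` attribute (not a registered buffer), and a collector sums the values before the loss is computed. Names here are illustrative, and the PP concerns discussed below are not addressed.

```python
import torch
import torch.nn as nn


class ToyMoELayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.aux_loss = None  # plain attribute, populated during forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the real load-balance loss computation.
        self.aux_loss = x.float().pow(2).mean()
        return self.proj(x)


def collect_aux_losses(model: nn.Module) -> torch.Tensor:
    # Sum the cached per-layer losses; call this inside the loss function.
    return sum(
        m.aux_loss
        for m in model.modules()
        if isinstance(m, ToyMoELayer) and m.aux_loss is not None
    )


model = nn.Sequential(ToyMoELayer(4), ToyMoELayer(4))
out = model(torch.randn(2, 4))
total_aux = collect_aux_losses(model)
```

Because the attributes are overwritten on every forward, each microbatch's loss sees only its own cached values; this per-microbatch lifetime is exactly what complicates the PP case discussed next.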
OK, to clarify: the purpose of passing the (accumulated) aux loss through forward is to deal with PP.
Without PP, we have lots of clean solutions, including those you proposed.
For PP, if we keep each block's aux loss with the block (via a buffer or other methods), at backward time the last stage will not be able to capture the aux loss of block i. For that, we would either need to add a backward hook where we manually hack the backward gradient (which does not work well with compile, as I have tested), or we need to manually add a communication to gather all stages' aux losses into the last stage (and we also need to think about micro-batching, where we need a queue for the aux loss of block i).
Agreed that this is non-trivial. Let me think about it and get back to you.
Is there any update?
This is a draft PR for:
For now, it only applies to the DeepSeek model, but I can add it to all other MoE models at the end.
(Also, we don't log the aux loss, but I can add an optimizer hook to do this if you want.)
The main concern is that the aux loss does not work well with PP. From what I have tested, it works only with 1F1B, and it is broken for ZBV or interleaved 1F1B.
To test it:

[sequence_wise, by default]

```shell
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --training.extra_losses.load_balance_loss_weight=0.001
```

[batch_wise, need to pick this in ModelArgs]

```shell
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --training.extra_losses.load_balance_loss_weight=0.001
```

(turn it off)

```shell
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh
```