Line 161 in 261859e:

```python
self.normC2 = Fp32LayerNorm(dim, bias=False)
```

Line 86 in 72feb0c:

```python
self.w1o = nn.Linear(dim, dim, bias=False)
```
These are not used in the last layer and should be moved into an `if not last` statement. Unused parameters make some distributed algos slow and sad: https://pytorch.org/docs/stable/notes/ddp.html#internal-design
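For concreteness, a minimal sketch of the suggested guard; the `last` flag and the surrounding block structure here are hypothetical, and plain `nn.LayerNorm` stands in for the repo's `Fp32LayerNorm`:

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self, dim: int, last: bool = False):
        super().__init__()
        self.last = last
        self.w1 = nn.Linear(dim, dim, bias=False)
        if not last:
            # These modules only feed the next block, so the final layer
            # can skip creating them and DDP never sees parameters that
            # receive no gradient.
            self.normC2 = nn.LayerNorm(dim)  # stand-in for Fp32LayerNorm
            self.w1o = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        x = self.w1(x)
        if not self.last:
            x = self.w1o(self.normC2(x))
        return x
```

The alternative, passing `find_unused_parameters=True` to `DistributedDataParallel`, also works, but it pays an extra autograd-graph traversal every iteration, which is exactly the slowdown the linked note describes.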
Edit: Also (unless I misread your code), you seem to put only the timestep embedding into the AdaLN scale/shift thingy, but the SD3 paper also feeds in a pooled vector made from the image description. Did you find timestep-only conditioning worked better?
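For reference, a hedged sketch of the SD3-style conditioning path being asked about (names are illustrative, not the repo's): the pooled caption embedding is summed with the timestep embedding before the modulation MLP produces the shift/scale/gate terms.

```python
from typing import Optional

import torch
import torch.nn as nn

class AdaLNModulation(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Produces shift, scale, and gate for one modulated sublayer.
        self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))

    def forward(self, t_emb: torch.Tensor,
                text_emb: Optional[torch.Tensor] = None):
        # SD3 fuses the timestep embedding with a pooled caption embedding;
        # timestep-only conditioning just drops the second term.
        cond = t_emb if text_emb is None else t_emb + text_emb
        shift, scale, gate = self.mlp(cond).chunk(3, dim=-1)
        return shift, scale, gate
```

A block would then apply it roughly as `x = x + gate * sublayer(norm(x) * (1 + scale) + shift)`, with the usual broadcast over the sequence dimension.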
Edit 2: Also also, did your muP optimization lead that far from a 1e-4 learning rate? Can you share the results of your hparam search?