Skip to content

Commit 3d55d00

Browse files
committed
fix SBP check of Custom MoE
1 parent 83ec88f commit 3d55d00

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

modules/Nncase.Modules.NTT/Evaluator/CustomOp/NTT/SparseExperts.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,25 +165,25 @@ private bool CheckCustomSBP(
165165
{
166166
if (q is DistributedType a && gate is DistributedType b && down is DistributedType c && up is DistributedType d)
167167
{
168-
if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => DistributedUtility.IsSamePolicy(a.AxisPolicies[i], se.QSBPs[i], checkGranularity: false)))
168+
if (Enumerable.Range(0, a.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(a.AxisPolicies[i], se.QSBPs[i], checkGranularity: false)))
169169
{
170170
Console.WriteLine($"[SparseExperts] Q SBP not match: {string.Join(",", a.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.QSBPs.Select(p => p.ToString()))}");
171171
return false;
172172
}
173173

174-
if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => DistributedUtility.IsSamePolicy(b.AxisPolicies[i], se.GateSBPs[i], checkGranularity: false)))
174+
if (Enumerable.Range(0, b.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(b.AxisPolicies[i], se.GateSBPs[i], checkGranularity: false)))
175175
{
176176
Console.WriteLine($"[SparseExperts] Gate SBP not match: {string.Join(",", b.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.GateSBPs.Select(p => p.ToString()))}");
177177
return false;
178178
}
179179

180-
if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => DistributedUtility.IsSamePolicy(c.AxisPolicies[i], se.DownSBPs[i], checkGranularity: false)))
180+
if (Enumerable.Range(0, c.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(c.AxisPolicies[i], se.DownSBPs[i], checkGranularity: false)))
181181
{
182182
Console.WriteLine($"[SparseExperts] Down SBP not match: {string.Join(",", c.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.DownSBPs.Select(p => p.ToString()))}");
183183
return false;
184184
}
185185

186-
if (Enumerable.Range(0, d.TensorType.Shape.Rank).Any(i => DistributedUtility.IsSamePolicy(d.AxisPolicies[i], se.UpSBPs[i], checkGranularity: false)))
186+
if (Enumerable.Range(0, d.TensorType.Shape.Rank).Any(i => !DistributedUtility.IsSamePolicy(d.AxisPolicies[i], se.UpSBPs[i], checkGranularity: false)))
187187
{
188188
Console.WriteLine($"[SparseExperts] Up SBP not match: {string.Join(",", d.AxisPolicies.Select(p => p.ToString()))} != {string.Join(",", se.UpSBPs.Select(p => p.ToString()))}");
189189
return false;

0 commit comments

Comments
 (0)