From 8cb99ac03ef37b8fdaf88f59f5d2f8bd170795da Mon Sep 17 00:00:00 2001 From: starrryz <760668919@qq.com> Date: Wed, 19 Nov 2025 01:53:32 +0000 Subject: [PATCH 1/3] Revert "Revert "autodistribute optimize (#1461)" (#1465)" This reverts commit 23b6ba4f9bfb9b9648ef44afc8e21e42b71061c1. --- .../Evaluator/NTT/VectorizedRoPE.cs | 35 ++++++-- src/Nncase.Core/IR/BaseExpr.cs | 36 +++++++- src/Nncase.Core/IR/Var.cs | 4 +- src/Nncase.Evaluator/NN/RoPE.cs | 35 ++++++-- .../BufferSchedule/LifeTimeCollector.cs | 8 +- .../Distributed/AutoDistributed.cs | 89 +++++++++++++++---- .../Mutators/FusionGroupMutator.cs | 2 +- .../Rules/Neutral/InlineFunction.cs | 4 +- .../Rules/Neutral/RemoveUnusedVars.cs | 2 +- 9 files changed, 165 insertions(+), 50 deletions(-) diff --git a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs index 1b83c61649..444f95e12f 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs @@ -17,6 +17,25 @@ namespace Nncase.Evaluator.IR.NTT; public class VectorizedRoPEEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IMetricEvaluator { + public static bool AxisEqual(IRArray a, IRArray b, int startA, int startB) + { + var lenA = a.Count - startA; + if (lenA != b.Count - startB) + { + return false; + } + + for (int i = 0; i < lenA; i++) + { + if (!Equals(a[startA + i], b[startB + i])) + { + return false; + } + } + + return true; + } + /// public IValue Visit(IEvaluateContext context, VectorizedRoPE target) { @@ -99,18 +118,16 @@ private IRType Visit(TensorType input) private IRType Visit(DistributedType input, DistributedType cos, DistributedType sin) { - var invalid = new InvalidType($"{input}, {cos}, {sin} not support"); + // only unsupported print without to-string if (input.Placement != cos.Placement || cos.Placement != sin.Placement - || !cos.AxisPolicies.SequenceEqual(sin.AxisPolicies)) + || !AxisEqual(input.AxisPolicies, cos.AxisPolicies, startA: 1, startB: 0) + || !AxisEqual(cos.AxisPolicies, sin.AxisPolicies, startA: 0, startB: 0) + || input.AxisPolicies[^1] is not SBPBroadCast) { - return invalid; - } + return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)"); - // [head, dim, seq] - if (!input.AxisPolicies[1..].SequenceEqual(cos.AxisPolicies) - || input.AxisPolicies[1] is not SBPBroadCast) - { - return invalid; + // optional(still ToString): + // return new InvalidType($"RoPE mismatch: in={input.GetType().Name}, cos={cos.GetType().Name}, sin={sin.GetType().Name}"); } return input; diff --git a/src/Nncase.Core/IR/BaseExpr.cs b/src/Nncase.Core/IR/BaseExpr.cs index b7cac5de29..822982794e 100644 --- a/src/Nncase.Core/IR/BaseExpr.cs +++ b/src/Nncase.Core/IR/BaseExpr.cs @@ -14,6 +14,8 @@ using CommunityToolkit.HighPerformance.Helpers; using Nncase.Diagnostics; +[assembly: InternalsVisibleTo("Nncase.Passes")] + namespace Nncase.IR; /// @@ -39,14 +41,17 @@ public abstract partial class BaseExpr private IRType? _checkedType; private int? _hashCodeCache; - internal BaseExpr(IEnumerable operands) + internal BaseExpr(IEnumerable operands, bool tempora = true) { ExprScope.Current?.Add(this); _operands = operands.ToArray(); - foreach (var operand in _operands) + if (tempora) { - ValidateOperand(operand); - operand.AddUser(this); + foreach (var operand in _operands) + { + ValidateOperand(operand); + operand.AddUser(this); + } } RefreshDepth(); @@ -238,6 +243,11 @@ public override bool Equals(object? obj) internal void AddUser(BaseExpr user) { + if (UserTrackingScope.IsSuppressed) + { + return; + } + Trace.Assert(!ReferenceEquals(this, user)); _users.TryAdd(user, default); } @@ -396,3 +406,21 @@ private void RefreshDepth() Depth = _operands.Length == 0 ? 0 : _operands.Max(x => x.Depth) + 1; } } + +internal static class UserTrackingScope +{ + private static readonly System.Threading.AsyncLocal _depth = new(); + + public static bool IsSuppressed => _depth.Value > 0; + + public static IDisposable Suppress() + { + _depth.Value = _depth.Value + 1; + return new Popper(); + } + + private sealed class Popper : IDisposable + { + public void Dispose() => _depth.Value = _depth.Value - 1; + } +} diff --git a/src/Nncase.Core/IR/Var.cs b/src/Nncase.Core/IR/Var.cs index 1d9f6c9f78..f7c497c8c0 100644 --- a/src/Nncase.Core/IR/Var.cs +++ b/src/Nncase.Core/IR/Var.cs @@ -42,7 +42,7 @@ public Var(string name, IRType typeAnnotation) : base(Array.Empty()) { TypeAnnotation = typeAnnotation; - CheckedType = TypeAnnotation; + RawCheckedType = TypeAnnotation; GlobalVarIndex = GetNextId(); Name = name; } @@ -55,7 +55,7 @@ public Var(IRType typeAnnotation) : base(Array.Empty()) { TypeAnnotation = typeAnnotation; - CheckedType = TypeAnnotation; + RawCheckedType = TypeAnnotation; GlobalVarIndex = GetNextId(); Name = $"var_{GlobalVarIndex}"; } diff --git a/src/Nncase.Evaluator/NN/RoPE.cs b/src/Nncase.Evaluator/NN/RoPE.cs index 9b9ba6f166..a7b63166ca 100644 --- a/src/Nncase.Evaluator/NN/RoPE.cs +++ b/src/Nncase.Evaluator/NN/RoPE.cs @@ -17,6 +17,25 @@ namespace Nncase.Evaluator.NN; public class RoPEEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IMetricEvaluator { + public static bool AxisEqual(IRArray a, IRArray b, int startA, int startB) + { + var lenA = a.Count - startA; + if (lenA != b.Count - startB) + { + return false; + } + + for (int i = 0; i < lenA; i++) + { + if (!Equals(a[startA + i], b[startB + i])) + { + return false; + } + } + + return true; + } + /// public IValue Visit(IEvaluateContext context, RoPE target) { @@ -97,18 +116,16 @@ private IRType Visit(TensorType input) private IRType Visit(DistributedType input, DistributedType scale, DistributedType bias) { - var invalid = new InvalidType($"{input}, {scale}, {bias} not support"); + // only unsupported print without to-string if (input.Placement != scale.Placement || scale.Placement != bias.Placement - || !scale.AxisPolicies.SequenceEqual(bias.AxisPolicies)) + || !AxisEqual(input.AxisPolicies, scale.AxisPolicies, startA: 1, startB: 0) + || !AxisEqual(scale.AxisPolicies, bias.AxisPolicies, startA: 0, startB: 0) + || input.AxisPolicies[^1] is not SBPBroadCast) { - return invalid; - } + return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)"); - // [head, seq, dim] - if (!input.AxisPolicies[1..].SequenceEqual(scale.AxisPolicies) - || input.AxisPolicies[2] is not SBPBroadCast) - { - return invalid; + // optional(still ToString): + // return new InvalidType($"RoPE mismatch: in={input.GetType().Name}, cos={scale.GetType().Name}, sin={bias.GetType().Name}"); } return input; diff --git a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs index ebe9dc8047..498b43a642 100644 --- a/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs +++ b/src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs @@ -14,9 +14,9 @@ namespace Nncase.Passes.BufferSchedule; public class LifeTimeUpdater : ExprFunctor { - protected override Unit DefaultVisit(BaseExpr expr, Context context) => default; + protected internal override Unit DefaultVisit(BaseExpr expr, Context context) => default; - protected override Unit VisitTuple(IR.Tuple expr, Context context) + protected internal override Unit VisitTuple(IR.Tuple expr, Context context) { foreach (var item in expr.Fields) { @@ -26,7 +26,7 @@ protected override Unit VisitTuple(IR.Tuple expr, Context context) return default; } - protected override Unit VisitCall(Call expr, Context context) + protected internal override Unit VisitCall(Call expr, Context context) { foreach (var item in expr.Arguments) { @@ -104,7 +104,7 @@ public override Result VisitType(TupleType tupleType) return new(size, Array.Empty(), Array.Empty()); } - protected override Result VisitCall(Call expr) + protected internal override Result VisitCall(Call expr) { return VisitType(expr.CheckedType); } diff --git a/src/Nncase.Passes/Distributed/AutoDistributed.cs b/src/Nncase.Passes/Distributed/AutoDistributed.cs index 2afb5e7c0c..04549c4fb1 100644 --- a/src/Nncase.Passes/Distributed/AutoDistributed.cs +++ b/src/Nncase.Passes/Distributed/AutoDistributed.cs @@ -84,6 +84,57 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo } } +internal static class UserRebuilder +{ + public static void Rebuild(BaseExpr root) + { + var order = new List(256); + var seen = new HashSet(ReferenceEqualityComparer.Instance); + DfsIter(root, order, seen); + + foreach (var n in order) + { + var users = n.Users.ToArray(); + for (int i = 0; i < users.Length; i++) + { + n.RemoveUser(users[i]); + } + } + + foreach (var n in order) + { + var ops = n.Operands; + for (int i = 0; i < ops.Length; i++) + { + ops[i].AddUser(n); + } + } + } + + private static void DfsIter(BaseExpr root, List order, HashSet seen) + { + var stack = new Stack(); + stack.Push(root); + + while (stack.Count > 0) + { + var n = stack.Pop(); + if (!seen.Add(n)) + { + continue; + } + + order.Add(n); + + var ops = n.Operands; + for (int i = ops.Length - 1; i >= 0; i--) + { + stack.Push(ops[i]); + } + } + } +} + internal sealed class SearchableNode { public SearchableNode(BaseExpr expr, IRType type, bool isBidirect = false) @@ -310,19 +361,24 @@ bool Matched(SearchableNode node, (IRArray Policies, Placement Placement) t public Function Rewrite(Function function) { - var body = function.Body; - Visit(body); - var rootCluster = TryInstertTerminator(body); - - if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost)) + BaseExpr post; + using (Nncase.IR.UserTrackingScope.Suppress()) { - using (var stream = Diagnostics.DumpScope.Current.OpenFile("DistributedSearchGraph.dot")) + Visit(function.Body); + var root = TryInstertTerminator(function.Body); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost)) { - Dump(stream, new Dictionary() { }, new Dictionary() { }); + using (var stream = Diagnostics.DumpScope.Current.OpenFile("DistributedSearchGraph.dot")) + { + Dump(stream, new Dictionary() { }, new Dictionary() { }); + } } + + post = SolveAndExtract(root); } - var post = SolveAndExtract(rootCluster); + UserRebuilder.Rebuild(post); + return function.With(body: post); } @@ -546,6 +602,11 @@ string DescribeSbp(IRType? type) } } + if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType) + { + continue; + } + if (!expr.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal) && TargetOptions.HierarchyKind == HierarchyKind.SMT && expr.Users.Any(u => u is Call call && (call.Target.GetType().FullName!.Contains("CustomNTT.MatMul", StringComparison.Ordinal) || call.Target is PagedAttention))) @@ -557,11 +618,6 @@ string DescribeSbp(IRType? type) } } - if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType) - { - continue; - } - var checkType = newExpr.CheckedType; if (!bucketMemo.TryGetValue(checkType, out var dbucket)) { @@ -1272,12 +1328,9 @@ private BaseExpr SolveAndExtract(DistributedSearchGraph rootCluster) } var picks = _rootSearchGraph.Vertices.ToDictionary(e => e, e => solver.BooleanValue(varMemo[e])); - if (enableDump) + using (var stream = enableDump ? Diagnostics.DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null) { - using (var stream = Diagnostics.DumpScope.Current.OpenFile("Costs/Pick.dot")) - { - Dump(stream, picks, costMemo); - } + Dump(stream, picks, costMemo); } if (_phase == AutoDistributedPhase.SearchConstant) diff --git a/src/Nncase.Passes/Mutators/FusionGroupMutator.cs b/src/Nncase.Passes/Mutators/FusionGroupMutator.cs index 5d3f0e9c4e..de24c73d4e 100644 --- a/src/Nncase.Passes/Mutators/FusionGroupMutator.cs +++ b/src/Nncase.Passes/Mutators/FusionGroupMutator.cs @@ -100,7 +100,7 @@ public bool TryMergeFusion(IMergeRewriteRule rule, Call old_call, out Call new_c return false; } - protected override BaseExpr VisitFusion(Fusion expr, Unit context) => base.VisitFusion(expr, context); + protected internal override BaseExpr VisitFusion(Fusion expr, Unit context) => base.VisitFusion(expr, context); /// protected override Expr RewriteLeafCall(Call expr) diff --git a/src/Nncase.Passes/Rules/Neutral/InlineFunction.cs b/src/Nncase.Passes/Rules/Neutral/InlineFunction.cs index 869104fc2a..0ff7610b89 100644 --- a/src/Nncase.Passes/Rules/Neutral/InlineFunction.cs +++ b/src/Nncase.Passes/Rules/Neutral/InlineFunction.cs @@ -54,12 +54,12 @@ public FunctionBodyCloner(Dictionary mapper) _mapper = mapper; } - protected override BaseExpr VisitLeafVar(Var expr, Unit context) + protected internal override BaseExpr VisitDimVar(DimVar expr, Unit context) { return _mapper[expr]; } - protected override BaseExpr VisitDimVar(DimVar expr, Unit context) + protected override BaseExpr VisitLeafVar(Var expr, Unit context) { return _mapper[expr]; } diff --git a/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs b/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs index 1273b06f05..3dd396e9da 100644 --- a/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs +++ b/src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs @@ -139,7 +139,7 @@ public VarReplacer(Dictionary newVars) _newVars = newVars; } - protected override Expr VisitVar(Var var, Unit state) + protected internal override Expr VisitVar(Var var, Unit state) { return (Expr)_newVars[var]; } From feb51fce9be4a0d270c649cacc77f04d5b7c71ae Mon Sep 17 00:00:00 2001 From: starrryz Date: Wed, 19 Nov 2025 02:01:12 +0000 Subject: [PATCH 2/3] Apply code-format changes --- src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs | 2 +- .../TestFixture/PagedAttentionKVCacheTestFixture.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs b/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs index a1ceaedb5e..6656ef6000 100644 --- a/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs +++ b/src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs @@ -195,7 +195,7 @@ private IR.Tuple WrapLSTMOutput(Call call, int outputSize, bool configExist, boo { var outputNames = new List(); var getItem = IR.F.Tensors.GetItem(call, i); - outputNames.Add(call.Metadata.OutputNames?[i] ?? "LSTMOutput_" + i.ToString()); + outputNames.Add(call.Metadata.OutputNames?[i] ?? ("LSTMOutput_" + i.ToString())); outputs[i].Metadata.OutputNames = outputNames; } diff --git a/src/Nncase.Tests.TestFixture/TestFixture/PagedAttentionKVCacheTestFixture.cs b/src/Nncase.Tests.TestFixture/TestFixture/PagedAttentionKVCacheTestFixture.cs index c738b347f3..0ed117a43e 100644 --- a/src/Nncase.Tests.TestFixture/TestFixture/PagedAttentionKVCacheTestFixture.cs +++ b/src/Nncase.Tests.TestFixture/TestFixture/PagedAttentionKVCacheTestFixture.cs @@ -75,7 +75,7 @@ public static OrtKISharp.Tensor ScaledDotProductAttention(OrtKISharp.Tensor quer var curLen = query.Shape[^2]; var histLen = key.Shape[^2]; - OrtKISharp.Tensor scaleFactor = scale ?? 1 / MathF.Sqrt(query.Length); + OrtKISharp.Tensor scaleFactor = scale ?? (1 / MathF.Sqrt(query.Length)); scaleFactor = scaleFactor.Cast(query.DataType); var attnBias = OrtKI.Expand(OrtKISharp.Tensor.FromScalar(0f), OrtKISharp.Tensor.MakeTensor([curLen, histLen])); From a9772753ff7dfd8a6de115bf5705670fcd1b8a7c Mon Sep 17 00:00:00 2001 From: huochenghai Date: Tue, 2 Dec 2025 15:08:15 +0800 Subject: [PATCH 3/3] fix type infer of vectorized rope --- modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs index 444f95e12f..6cfb37aacd 100644 --- a/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs +++ b/modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs @@ -122,7 +122,7 @@ private IRType Visit(DistributedType input, DistributedType cos, DistributedType if (input.Placement != cos.Placement || cos.Placement != sin.Placement || !AxisEqual(input.AxisPolicies, cos.AxisPolicies, startA: 1, startB: 0) || !AxisEqual(cos.AxisPolicies, sin.AxisPolicies, startA: 0, startB: 0) - || input.AxisPolicies[^1] is not SBPBroadCast) + || input.AxisPolicies[1] is not SBPBroadCast) { return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)");