Skip to content

Commit e2f1bad

Browse files
committed
Merge branch 'dev/3.0' into feature/split_granularity
2 parents 2c0441a + 37fbee1 commit e2f1bad

File tree

11 files changed

+178
-55
lines changed

11 files changed

+178
-55
lines changed

modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,25 @@ namespace Nncase.Evaluator.IR.NTT;
1717
public class VectorizedRoPEEvaluator : IEvaluator<VectorizedRoPE>, ITypeInferencer<VectorizedRoPE>, ICostEvaluator<VectorizedRoPE>,
1818
IMetricEvaluator<VectorizedRoPE>
1919
{
20+
public static bool AxisEqual(IRArray<SBP> a, IRArray<SBP> b, int startA, int startB)
21+
{
22+
var lenA = a.Count - startA;
23+
if (lenA != b.Count - startB)
24+
{
25+
return false;
26+
}
27+
28+
for (int i = 0; i < lenA; i++)
29+
{
30+
if (!Equals(a[startA + i], b[startB + i]))
31+
{
32+
return false;
33+
}
34+
}
35+
36+
return true;
37+
}
38+
2039
/// <inheritdoc/>
2140
public IValue Visit(IEvaluateContext context, VectorizedRoPE target)
2241
{
@@ -99,18 +118,16 @@ private IRType Visit(TensorType input)
99118

100119
private IRType Visit(DistributedType input, DistributedType cos, DistributedType sin)
101120
{
102-
var invalid = new InvalidType($"{input}, {cos}, {sin} not support");
121+
// only unsupported print without to-string
103122
if (input.Placement != cos.Placement || cos.Placement != sin.Placement
104-
|| !cos.AxisPolicies.SequenceEqual(sin.AxisPolicies))
105-
{
106-
return invalid;
107-
}
108-
109-
// [head, dim, seq]
110-
if (!input.AxisPolicies[1..].SequenceEqual(cos.AxisPolicies)
123+
|| !AxisEqual(input.AxisPolicies, cos.AxisPolicies, startA: 1, startB: 0)
124+
|| !AxisEqual(cos.AxisPolicies, sin.AxisPolicies, startA: 0, startB: 0)
111125
|| input.AxisPolicies[1] is not SBPBroadCast)
112126
{
113-
return invalid;
127+
return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)");
128+
129+
// optional(still ToString):
130+
// return new InvalidType($"RoPE mismatch: in={input.GetType().Name}, cos={cos.GetType().Name}, sin={sin.GetType().Name}");
114131
}
115132

116133
return input;

ntt/include/nncase/ntt/distributed/sharded_tensor.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ class sharded_tensor_view : public ntt::detail::shape_storage<TShape> {
4141
using shape_storage_type::shape;
4242
using shape_storage_type::size;
4343

44+
constexpr sharded_tensor_view() noexcept
45+
: shape_storage_type(TShape{}),
46+
sharding_(),
47+
local_(local_buffer_type{}, TLocalShape{}, LocalStrides{}) {}
48+
4449
constexpr sharded_tensor_view(local_buffer_type local_buffer,
4550
const TShape &shape,
4651
const TSharding &sharding,

ntt/include/nncase/ntt/distributed/sharding.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ template <class Mesh, class... AxisPolicies> struct sharding {
139139
return fixed_dim_v<sizeof...(AxisPolicies)>;
140140
}
141141

142+
constexpr sharding() noexcept requires(sizeof...(AxisPolicies) == 0) = default;
143+
142144
constexpr sharding(const AxisPolicies &...axis_policies) noexcept
145+
requires(sizeof...(AxisPolicies) != 0)
143146
: axis_policies(axis_policies...) {}
144147

145148
template <Shape GlobalShape, ShardIndex<Mesh> TShardIndex>

src/Nncase.Core/IR/BaseExpr.cs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
using CommunityToolkit.HighPerformance.Helpers;
1515
using Nncase.Diagnostics;
1616

17+
[assembly: InternalsVisibleTo("Nncase.Passes")]
18+
1719
namespace Nncase.IR;
1820

1921
/// <summary>
@@ -39,14 +41,17 @@ public abstract partial class BaseExpr
3941
private IRType? _checkedType;
4042
private int? _hashCodeCache;
4143

42-
internal BaseExpr(IEnumerable<BaseExpr> operands)
44+
internal BaseExpr(IEnumerable<BaseExpr> operands, bool tempora = true)
4345
{
4446
ExprScope.Current?.Add(this);
4547
_operands = operands.ToArray();
46-
foreach (var operand in _operands)
48+
if (tempora)
4749
{
48-
ValidateOperand(operand);
49-
operand.AddUser(this);
50+
foreach (var operand in _operands)
51+
{
52+
ValidateOperand(operand);
53+
operand.AddUser(this);
54+
}
5055
}
5156

5257
RefreshDepth();
@@ -238,6 +243,11 @@ public override bool Equals(object? obj)
238243

239244
internal void AddUser(BaseExpr user)
240245
{
246+
if (UserTrackingScope.IsSuppressed)
247+
{
248+
return;
249+
}
250+
241251
Trace.Assert(!ReferenceEquals(this, user));
242252
_users.TryAdd(user, default);
243253
}
@@ -396,3 +406,21 @@ private void RefreshDepth()
396406
Depth = _operands.Length == 0 ? 0 : _operands.Max(x => x.Depth) + 1;
397407
}
398408
}
409+
410+
internal static class UserTrackingScope
411+
{
412+
private static readonly System.Threading.AsyncLocal<int> _depth = new();
413+
414+
public static bool IsSuppressed => _depth.Value > 0;
415+
416+
public static IDisposable Suppress()
417+
{
418+
_depth.Value = _depth.Value + 1;
419+
return new Popper();
420+
}
421+
422+
private sealed class Popper : IDisposable
423+
{
424+
public void Dispose() => _depth.Value = _depth.Value - 1;
425+
}
426+
}

src/Nncase.Core/IR/Var.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public Var(string name, IRType typeAnnotation)
4242
: base(Array.Empty<BaseExpr>())
4343
{
4444
TypeAnnotation = typeAnnotation;
45-
CheckedType = TypeAnnotation;
45+
RawCheckedType = TypeAnnotation;
4646
GlobalVarIndex = GetNextId();
4747
Name = name;
4848
}
@@ -55,7 +55,7 @@ public Var(IRType typeAnnotation)
5555
: base(Array.Empty<BaseExpr>())
5656
{
5757
TypeAnnotation = typeAnnotation;
58-
CheckedType = TypeAnnotation;
58+
RawCheckedType = TypeAnnotation;
5959
GlobalVarIndex = GetNextId();
6060
Name = $"var_{GlobalVarIndex}";
6161
}

src/Nncase.Evaluator/NN/RoPE.cs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,25 @@ namespace Nncase.Evaluator.NN;
1717
public class RoPEEvaluator : IEvaluator<RoPE>, ITypeInferencer<RoPE>, ICostEvaluator<RoPE>,
1818
IMetricEvaluator<RoPE>
1919
{
20+
public static bool AxisEqual(IRArray<SBP> a, IRArray<SBP> b, int startA, int startB)
21+
{
22+
var lenA = a.Count - startA;
23+
if (lenA != b.Count - startB)
24+
{
25+
return false;
26+
}
27+
28+
for (int i = 0; i < lenA; i++)
29+
{
30+
if (!Equals(a[startA + i], b[startB + i]))
31+
{
32+
return false;
33+
}
34+
}
35+
36+
return true;
37+
}
38+
2039
/// <inheritdoc/>
2140
public IValue Visit(IEvaluateContext context, RoPE target)
2241
{
@@ -97,18 +116,16 @@ private IRType Visit(TensorType input)
97116

98117
private IRType Visit(DistributedType input, DistributedType scale, DistributedType bias)
99118
{
100-
var invalid = new InvalidType($"{input}, {scale}, {bias} not support");
119+
// only unsupported print without to-string
101120
if (input.Placement != scale.Placement || scale.Placement != bias.Placement
102-
|| !scale.AxisPolicies.SequenceEqual(bias.AxisPolicies))
121+
|| !AxisEqual(input.AxisPolicies, scale.AxisPolicies, startA: 1, startB: 0)
122+
|| !AxisEqual(scale.AxisPolicies, bias.AxisPolicies, startA: 0, startB: 0)
123+
|| input.AxisPolicies[^1] is not SBPBroadCast)
103124
{
104-
return invalid;
105-
}
125+
return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)");
106126

107-
// [head, seq, dim]
108-
if (!input.AxisPolicies[1..].SequenceEqual(scale.AxisPolicies)
109-
|| input.AxisPolicies[2] is not SBPBroadCast)
110-
{
111-
return invalid;
127+
// optional(still ToString):
128+
// return new InvalidType($"RoPE mismatch: in={input.GetType().Name}, cos={scale.GetType().Name}, sin={bias.GetType().Name}");
112129
}
113130

114131
return input;

src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ namespace Nncase.Passes.BufferSchedule;
1414

1515
public class LifeTimeUpdater : ExprFunctor<Unit, Unit, LifeTimeUpdater.Context>
1616
{
17-
protected override Unit DefaultVisit(BaseExpr expr, Context context) => default;
17+
protected internal override Unit DefaultVisit(BaseExpr expr, Context context) => default;
1818

19-
protected override Unit VisitTuple(IR.Tuple expr, Context context)
19+
protected internal override Unit VisitTuple(IR.Tuple expr, Context context)
2020
{
2121
foreach (var item in expr.Fields)
2222
{
@@ -26,7 +26,7 @@ protected override Unit VisitTuple(IR.Tuple expr, Context context)
2626
return default;
2727
}
2828

29-
protected override Unit VisitCall(Call expr, Context context)
29+
protected internal override Unit VisitCall(Call expr, Context context)
3030
{
3131
foreach (var item in expr.Arguments)
3232
{
@@ -104,7 +104,7 @@ public override Result VisitType(TupleType tupleType)
104104
return new(size, Array.Empty<long>(), Array.Empty<long>());
105105
}
106106

107-
protected override Result VisitCall(Call expr)
107+
protected internal override Result VisitCall(Call expr)
108108
{
109109
return VisitType(expr.CheckedType);
110110
}

src/Nncase.Passes/Distributed/AutoDistributed.cs

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,57 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassCo
8484
}
8585
}
8686

87+
internal static class UserRebuilder
88+
{
89+
public static void Rebuild(BaseExpr root)
90+
{
91+
var order = new List<BaseExpr>(256);
92+
var seen = new HashSet<BaseExpr>(ReferenceEqualityComparer.Instance);
93+
DfsIter(root, order, seen);
94+
95+
foreach (var n in order)
96+
{
97+
var users = n.Users.ToArray();
98+
for (int i = 0; i < users.Length; i++)
99+
{
100+
n.RemoveUser(users[i]);
101+
}
102+
}
103+
104+
foreach (var n in order)
105+
{
106+
var ops = n.Operands;
107+
for (int i = 0; i < ops.Length; i++)
108+
{
109+
ops[i].AddUser(n);
110+
}
111+
}
112+
}
113+
114+
private static void DfsIter(BaseExpr root, List<BaseExpr> order, HashSet<BaseExpr> seen)
115+
{
116+
var stack = new Stack<BaseExpr>();
117+
stack.Push(root);
118+
119+
while (stack.Count > 0)
120+
{
121+
var n = stack.Pop();
122+
if (!seen.Add(n))
123+
{
124+
continue;
125+
}
126+
127+
order.Add(n);
128+
129+
var ops = n.Operands;
130+
for (int i = ops.Length - 1; i >= 0; i--)
131+
{
132+
stack.Push(ops[i]);
133+
}
134+
}
135+
}
136+
}
137+
87138
internal sealed class SearchableNode
88139
{
89140
public SearchableNode(BaseExpr expr, IRType type, bool isBidirect = false)
@@ -310,19 +361,24 @@ bool Matched(SearchableNode node, (IRArray<SBP> Policies, Placement Placement) t
310361

311362
public Function Rewrite(Function function)
312363
{
313-
var body = function.Body;
314-
Visit(body);
315-
var rootCluster = TryInstertTerminator(body);
316-
317-
// if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost))
364+
BaseExpr post;
365+
using (Nncase.IR.UserTrackingScope.Suppress())
318366
{
319-
using (var stream = Diagnostics.DumpScope.Current.OpenFile("DistributedSearchGraph.dot"))
367+
Visit(function.Body);
368+
var root = TryInstertTerminator(function.Body);
369+
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost))
320370
{
321-
Dump(stream, new Dictionary<SearchableNode, bool>() { }, new Dictionary<SearchableNode, CostModel.Cost>() { });
371+
using (var stream = Diagnostics.DumpScope.Current.OpenFile("DistributedSearchGraph.dot"))
372+
{
373+
Dump(stream, new Dictionary<SearchableNode, bool>() { }, new Dictionary<SearchableNode, CostModel.Cost>() { });
374+
}
322375
}
376+
377+
post = SolveAndExtract(root);
323378
}
324379

325-
var post = SolveAndExtract(rootCluster);
380+
UserRebuilder.Rebuild(post);
381+
326382
return function.With(body: post);
327383
}
328384

@@ -546,6 +602,16 @@ string DescribeSbp(IRType? type)
546602
}
547603
}
548604

605+
if (expr.Target is not Boxing && ((Call)newExpr).Arguments.AsValueEnumerable().Any(a => a.CheckedType is DistributedType dt && dt.Partial is not null))
606+
{
607+
continue;
608+
}
609+
610+
if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType)
611+
{
612+
continue;
613+
}
614+
549615
if (!expr.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal)
550616
&& TargetOptions.HierarchyKind == HierarchyKind.SMT
551617
&& expr.Users.Any(u => u is Call call && (call.Target.GetType().FullName!.Contains("CustomNTT.MatMul", StringComparison.Ordinal) || call.Target is PagedAttention)))
@@ -557,16 +623,6 @@ string DescribeSbp(IRType? type)
557623
}
558624
}
559625

560-
if (expr.Target is not Boxing && ((Call)newExpr).Arguments.AsValueEnumerable().Any(a => a.CheckedType is DistributedType dt && dt.Partial is not null))
561-
{
562-
continue;
563-
}
564-
565-
if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType)
566-
{
567-
continue;
568-
}
569-
570626
var checkType = newExpr.CheckedType;
571627
if (!bucketMemo.TryGetValue(checkType, out var dbucket))
572628
{
@@ -1311,12 +1367,9 @@ private BaseExpr SolveAndExtract(DistributedSearchGraph rootCluster)
13111367
}
13121368

13131369
var picks = _rootSearchGraph.Vertices.ToDictionary(e => e, e => solver.BooleanValue(varMemo[e]));
1314-
if (enableDump)
1370+
using (var stream = enableDump ? Diagnostics.DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null)
13151371
{
1316-
using (var stream = Diagnostics.DumpScope.Current.OpenFile("Costs/Pick.dot"))
1317-
{
1318-
Dump(stream, picks, costMemo);
1319-
}
1372+
Dump(stream, picks, costMemo);
13201373
}
13211374

13221375
if (_phase == AutoDistributedPhase.SearchConstant)

src/Nncase.Passes/Mutators/FusionGroupMutator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public bool TryMergeFusion(IMergeRewriteRule rule, Call old_call, out Call new_c
100100
return false;
101101
}
102102

103-
protected override BaseExpr VisitFusion(Fusion expr, Unit context) => base.VisitFusion(expr, context);
103+
protected internal override BaseExpr VisitFusion(Fusion expr, Unit context) => base.VisitFusion(expr, context);
104104

105105
/// <inheritdoc/>
106106
protected override Expr RewriteLeafCall(Call expr)

0 commit comments

Comments
 (0)