Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions modules/Nncase.Modules.NTT/Evaluator/NTT/VectorizedRoPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@ namespace Nncase.Evaluator.IR.NTT;
public class VectorizedRoPEEvaluator : IEvaluator<VectorizedRoPE>, ITypeInferencer<VectorizedRoPE>, ICostEvaluator<VectorizedRoPE>,
IMetricEvaluator<VectorizedRoPE>
{
public static bool AxisEqual(IRArray<SBP> a, IRArray<SBP> b, int startA, int startB)
{
var lenA = a.Count - startA;
if (lenA != b.Count - startB)
{
return false;
}

for (int i = 0; i < lenA; i++)
{
if (!Equals(a[startA + i], b[startB + i]))
{
return false;
}
}

return true;
}

/// <inheritdoc/>
public IValue Visit(IEvaluateContext context, VectorizedRoPE target)
{
Expand Down Expand Up @@ -99,18 +118,16 @@ private IRType Visit(TensorType input)

private IRType Visit(DistributedType input, DistributedType cos, DistributedType sin)
{
var invalid = new InvalidType($"{input}, {cos}, {sin} not support");
// only unsupported print without to-string
if (input.Placement != cos.Placement || cos.Placement != sin.Placement
|| !cos.AxisPolicies.SequenceEqual(sin.AxisPolicies))
{
return invalid;
}

// [head, dim, seq]
if (!input.AxisPolicies[1..].SequenceEqual(cos.AxisPolicies)
|| !AxisEqual(input.AxisPolicies, cos.AxisPolicies, startA: 1, startB: 0)
|| !AxisEqual(cos.AxisPolicies, sin.AxisPolicies, startA: 0, startB: 0)
|| input.AxisPolicies[1] is not SBPBroadCast)
{
return invalid;
return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)");

// optional(still ToString):
// return new InvalidType($"RoPE mismatch: in={input.GetType().Name}, cos={cos.GetType().Name}, sin={sin.GetType().Name}");
}

return input;
Expand Down
36 changes: 32 additions & 4 deletions src/Nncase.Core/IR/BaseExpr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
using CommunityToolkit.HighPerformance.Helpers;
using Nncase.Diagnostics;

[assembly: InternalsVisibleTo("Nncase.Passes")]

namespace Nncase.IR;

/// <summary>
Expand All @@ -39,14 +41,17 @@ public abstract partial class BaseExpr
private IRType? _checkedType;
private int? _hashCodeCache;

internal BaseExpr(IEnumerable<BaseExpr> operands)
internal BaseExpr(IEnumerable<BaseExpr> operands, bool tempora = true)
{
ExprScope.Current?.Add(this);
_operands = operands.ToArray();
foreach (var operand in _operands)
if (tempora)
{
ValidateOperand(operand);
operand.AddUser(this);
foreach (var operand in _operands)
{
ValidateOperand(operand);
operand.AddUser(this);
}
}

RefreshDepth();
Expand Down Expand Up @@ -238,6 +243,11 @@ public override bool Equals(object? obj)

internal void AddUser(BaseExpr user)
{
if (UserTrackingScope.IsSuppressed)
{
return;
}

Trace.Assert(!ReferenceEquals(this, user));
_users.TryAdd(user, default);
}
Expand Down Expand Up @@ -396,3 +406,21 @@ private void RefreshDepth()
Depth = _operands.Length == 0 ? 0 : _operands.Max(x => x.Depth) + 1;
}
}

internal static class UserTrackingScope
{
private static readonly System.Threading.AsyncLocal<int> _depth = new();

public static bool IsSuppressed => _depth.Value > 0;

public static IDisposable Suppress()
{
_depth.Value = _depth.Value + 1;
return new Popper();
}

private sealed class Popper : IDisposable
{
public void Dispose() => _depth.Value = _depth.Value - 1;
}
}
4 changes: 2 additions & 2 deletions src/Nncase.Core/IR/Var.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public Var(string name, IRType typeAnnotation)
: base(Array.Empty<BaseExpr>())
{
TypeAnnotation = typeAnnotation;
CheckedType = TypeAnnotation;
RawCheckedType = TypeAnnotation;
GlobalVarIndex = GetNextId();
Name = name;
}
Expand All @@ -55,7 +55,7 @@ public Var(IRType typeAnnotation)
: base(Array.Empty<BaseExpr>())
{
TypeAnnotation = typeAnnotation;
CheckedType = TypeAnnotation;
RawCheckedType = TypeAnnotation;
GlobalVarIndex = GetNextId();
Name = $"var_{GlobalVarIndex}";
}
Expand Down
35 changes: 26 additions & 9 deletions src/Nncase.Evaluator/NN/RoPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,25 @@ namespace Nncase.Evaluator.NN;
public class RoPEEvaluator : IEvaluator<RoPE>, ITypeInferencer<RoPE>, ICostEvaluator<RoPE>,
IMetricEvaluator<RoPE>
{
public static bool AxisEqual(IRArray<SBP> a, IRArray<SBP> b, int startA, int startB)
{
var lenA = a.Count - startA;
if (lenA != b.Count - startB)
{
return false;
}

for (int i = 0; i < lenA; i++)
{
if (!Equals(a[startA + i], b[startB + i]))
{
return false;
}
}

return true;
}

/// <inheritdoc/>
public IValue Visit(IEvaluateContext context, RoPE target)
{
Expand Down Expand Up @@ -97,18 +116,16 @@ private IRType Visit(TensorType input)

private IRType Visit(DistributedType input, DistributedType scale, DistributedType bias)
{
var invalid = new InvalidType($"{input}, {scale}, {bias} not support");
// only unsupported print without to-string
if (input.Placement != scale.Placement || scale.Placement != bias.Placement
|| !scale.AxisPolicies.SequenceEqual(bias.AxisPolicies))
|| !AxisEqual(input.AxisPolicies, scale.AxisPolicies, startA: 1, startB: 0)
|| !AxisEqual(scale.AxisPolicies, bias.AxisPolicies, startA: 0, startB: 0)
|| input.AxisPolicies[^1] is not SBPBroadCast)
{
return invalid;
}
return new InvalidType("RoPE: distributed types mismatch (placement/axis/SBP)");

// [head, seq, dim]
if (!input.AxisPolicies[1..].SequenceEqual(scale.AxisPolicies)
|| input.AxisPolicies[2] is not SBPBroadCast)
{
return invalid;
// optional(still ToString):
// return new InvalidType($"RoPE mismatch: in={input.GetType().Name}, cos={scale.GetType().Name}, sin={bias.GetType().Name}");
}

return input;
Expand Down
8 changes: 4 additions & 4 deletions src/Nncase.Passes/BufferSchedule/LifeTimeCollector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ namespace Nncase.Passes.BufferSchedule;

public class LifeTimeUpdater : ExprFunctor<Unit, Unit, LifeTimeUpdater.Context>
{
protected override Unit DefaultVisit(BaseExpr expr, Context context) => default;
protected internal override Unit DefaultVisit(BaseExpr expr, Context context) => default;

protected override Unit VisitTuple(IR.Tuple expr, Context context)
protected internal override Unit VisitTuple(IR.Tuple expr, Context context)
{
foreach (var item in expr.Fields)
{
Expand All @@ -26,7 +26,7 @@ protected override Unit VisitTuple(IR.Tuple expr, Context context)
return default;
}

protected override Unit VisitCall(Call expr, Context context)
protected internal override Unit VisitCall(Call expr, Context context)
{
foreach (var item in expr.Arguments)
{
Expand Down Expand Up @@ -104,7 +104,7 @@ public override Result VisitType(TupleType tupleType)
return new(size, Array.Empty<long>(), Array.Empty<long>());
}

protected override Result VisitCall(Call expr)
protected internal override Result VisitCall(Call expr)
{
return VisitType(expr.CheckedType);
}
Expand Down
89 changes: 71 additions & 18 deletions src/Nncase.Passes/Distributed/AutoDistributed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,57 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassCo
}
}

internal static class UserRebuilder
{
public static void Rebuild(BaseExpr root)
{
var order = new List<BaseExpr>(256);
var seen = new HashSet<BaseExpr>(ReferenceEqualityComparer.Instance);
DfsIter(root, order, seen);

foreach (var n in order)
{
var users = n.Users.ToArray();
for (int i = 0; i < users.Length; i++)
{
n.RemoveUser(users[i]);
}
}

foreach (var n in order)
{
var ops = n.Operands;
for (int i = 0; i < ops.Length; i++)
{
ops[i].AddUser(n);
}
}
}

private static void DfsIter(BaseExpr root, List<BaseExpr> order, HashSet<BaseExpr> seen)
{
var stack = new Stack<BaseExpr>();
stack.Push(root);

while (stack.Count > 0)
{
var n = stack.Pop();
if (!seen.Add(n))
{
continue;
}

order.Add(n);

var ops = n.Operands;
for (int i = ops.Length - 1; i >= 0; i--)
{
stack.Push(ops[i]);
}
}
}
}

internal sealed class SearchableNode
{
public SearchableNode(BaseExpr expr, IRType type, bool isBidirect = false)
Expand Down Expand Up @@ -310,19 +361,24 @@ bool Matched(SearchableNode node, (IRArray<SBP> Policies, Placement Placement) t

public Function Rewrite(Function function)
{
var body = function.Body;
Visit(body);
var rootCluster = TryInstertTerminator(body);

if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost))
BaseExpr post;
using (Nncase.IR.UserTrackingScope.Suppress())
{
using (var stream = Diagnostics.DumpScope.Current.OpenFile("DistributedSearchGraph.dot"))
Visit(function.Body);
var root = TryInstertTerminator(function.Body);
if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost))
{
Dump(stream, new Dictionary<SearchableNode, bool>() { }, new Dictionary<SearchableNode, CostModel.Cost>() { });
using (var stream = Diagnostics.DumpScope.Current.OpenFile("DistributedSearchGraph.dot"))
{
Dump(stream, new Dictionary<SearchableNode, bool>() { }, new Dictionary<SearchableNode, CostModel.Cost>() { });
}
}

post = SolveAndExtract(root);
}

var post = SolveAndExtract(rootCluster);
UserRebuilder.Rebuild(post);

return function.With(body: post);
}

Expand Down Expand Up @@ -546,6 +602,11 @@ string DescribeSbp(IRType? type)
}
}

if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType)
{
continue;
}

if (!expr.Target.GetType().FullName!.Contains("CustomNTT", StringComparison.Ordinal)
&& TargetOptions.HierarchyKind == HierarchyKind.SMT
&& expr.Users.Any(u => u is Call call && (call.Target.GetType().FullName!.Contains("CustomNTT.MatMul", StringComparison.Ordinal) || call.Target is PagedAttention)))
Expand All @@ -557,11 +618,6 @@ string DescribeSbp(IRType? type)
}
}

if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType)
{
continue;
}

var checkType = newExpr.CheckedType;
if (!bucketMemo.TryGetValue(checkType, out var dbucket))
{
Expand Down Expand Up @@ -1272,12 +1328,9 @@ private BaseExpr SolveAndExtract(DistributedSearchGraph rootCluster)
}

var picks = _rootSearchGraph.Vertices.ToDictionary(e => e, e => solver.BooleanValue(varMemo[e]));
if (enableDump)
using (var stream = enableDump ? Diagnostics.DumpScope.Current.OpenFile("Costs/Pick.dot") : Stream.Null)
{
using (var stream = Diagnostics.DumpScope.Current.OpenFile("Costs/Pick.dot"))
{
Dump(stream, picks, costMemo);
}
Dump(stream, picks, costMemo);
}

if (_phase == AutoDistributedPhase.SearchConstant)
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Mutators/FusionGroupMutator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public bool TryMergeFusion(IMergeRewriteRule rule, Call old_call, out Call new_c
return false;
}

protected override BaseExpr VisitFusion(Fusion expr, Unit context) => base.VisitFusion(expr, context);
protected internal override BaseExpr VisitFusion(Fusion expr, Unit context) => base.VisitFusion(expr, context);

/// <inheritdoc/>
protected override Expr RewriteLeafCall(Call expr)
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ private IR.Tuple WrapLSTMOutput(Call call, int outputSize, bool configExist, boo
{
var outputNames = new List<string>();
var getItem = IR.F.Tensors.GetItem(call, i);
outputNames.Add(call.Metadata.OutputNames?[i] ?? "LSTMOutput_" + i.ToString());
outputNames.Add(call.Metadata.OutputNames?[i] ?? ("LSTMOutput_" + i.ToString()));
outputs[i].Metadata.OutputNames = outputNames;
}

Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Passes/Rules/Neutral/InlineFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ public FunctionBodyCloner(Dictionary<IVar, BaseExpr> mapper)
_mapper = mapper;
}

protected override BaseExpr VisitLeafVar(Var expr, Unit context)
protected internal override BaseExpr VisitDimVar(DimVar expr, Unit context)
{
return _mapper[expr];
}

protected override BaseExpr VisitDimVar(DimVar expr, Unit context)
protected override BaseExpr VisitLeafVar(Var expr, Unit context)
{
return _mapper[expr];
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/RemoveUnusedVars.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ public VarReplacer(Dictionary<IVar, IVar> newVars)
_newVars = newVars;
}

protected override Expr VisitVar(Var var, Unit state)
protected internal override Expr VisitVar(Var var, Unit state)
{
return (Expr)_newVars[var];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public static OrtKISharp.Tensor ScaledDotProductAttention(OrtKISharp.Tensor quer
var curLen = query.Shape[^2];
var histLen = key.Shape[^2];

OrtKISharp.Tensor scaleFactor = scale ?? 1 / MathF.Sqrt(query.Length);
OrtKISharp.Tensor scaleFactor = scale ?? (1 / MathF.Sqrt(query.Length));
scaleFactor = scaleFactor.Cast(query.DataType);

var attnBias = OrtKI.Expand(OrtKISharp.Tensor.FromScalar(0f), OrtKISharp.Tensor.MakeTensor([curLen, histLen]));
Expand Down
Loading