Skip to content

Commit 2c0441a

Browse files
committed
try to fix k80 CI
1 parent 1c3320e commit 2c0441a

File tree

4 files changed

+67
-1
lines changed

4 files changed

+67
-1
lines changed

modules/Nncase.Modules.NTT/Evaluator/Distributed/Boxing.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ IRType VisitD2D(DistributedType inv, DistributedType outv)
4747
{
4848
if (inv.Partial is not null && outv.AxisPolicies[i] is SBPSplit s)
4949
{
50+
if (inv.AxisPolicies[i] is SBPSplit splitIn)
51+
{
52+
if (splitIn.Axes.Except(s.Axes).Any())
53+
{
54+
return new InvalidType("Not Supported Split-> Split.");
55+
}
56+
}
57+
5058
if (s.Axes.Except(inv.Partial.Axes).ToArray() != s.Axes)
5159
{
5260
if (s.Axes.Except(inv.Partial.Axes).Any())

src/Nncase.Passes/Distributed/AutoDistributed.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,11 @@ string DescribeSbp(IRType? type)
557557
}
558558
}
559559

560+
if (expr.Target is not Boxing && ((Call)newExpr).Arguments.AsValueEnumerable().Any(a => a.CheckedType is DistributedType dt && dt.Partial is not null))
561+
{
562+
continue;
563+
}
564+
560565
if (!newExpr.InferenceType(_inferencer_cache) || newExpr.CheckedType is InvalidType)
561566
{
562567
continue;
@@ -969,7 +974,7 @@ private IRType CheckBoxingType(IRType inType, IRType outType, bool isReshape = f
969974
{
970975
IRType VisitD2D(DistributedType inv, DistributedType outv)
971976
{
972-
if (DistributedUtility.AreSamePolicies(inv.AxisPolicies, outv.AxisPolicies))
977+
if (inv.Partial == outv.Partial && DistributedUtility.AreSamePolicies(inv.AxisPolicies, outv.AxisPolicies))
973978
{
974979
return new InvalidType("Same DistributedType");
975980
}
@@ -984,6 +989,14 @@ IRType VisitD2D(DistributedType inv, DistributedType outv)
984989
{
985990
if (inv.Partial is not null && outv.AxisPolicies[i] is SBPSplit s)
986991
{
992+
if (inv.AxisPolicies[i] is SBPSplit splitIn)
993+
{
994+
if (splitIn.Axes.Except(s.Axes).Any())
995+
{
996+
return new InvalidType("Not Supported Split-> Split.");
997+
}
998+
}
999+
9871000
if (s.Axes.Except(inv.Partial.Axes).ToArray() != s.Axes)
9881001
{
9891002
if (s.Axes.Except(inv.Partial.Axes).Any())
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) Canaan Inc. All rights reserved.
2+
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
using Nncase.IR;
10+
using Nncase.IR.Distributed;
11+
using Nncase.PatternMatch;
12+
using static Nncase.IR.F.NN;
13+
14+
using static Nncase.IR.TypePatternUtility;
15+
using static Nncase.PatternMatch.F.Distributed;
16+
using static Nncase.PatternMatch.Utility;
17+
18+
namespace Nncase.Passes.Rules;
19+
20+
[RuleGenerator]
21+
public partial class UpdateBoxingTensorType : RewriteRule<Pattern>
22+
{
23+
/// <inheritdoc/>
24+
public override Pattern Pattern { get; } = IsBoxing(
25+
target_name: "boxing",
26+
_ => true,
27+
IsWildcard("input"));
28+
29+
private Expr? GetReplace(Boxing boxing, Expr input, RunPassContext context)
30+
{
31+
var type = input.CheckedType;
32+
if (type is DistributedType dt1 && boxing.NewType is DistributedType dt2)
33+
{
34+
var ttype = dt1.TensorType;
35+
var dtype = dt2 with { TensorType = ttype };
36+
var newBoxing = new Call(new IR.Distributed.Boxing(dtype), input);
37+
context.MatchOptions.SuppressPattern(newBoxing, Pattern);
38+
return newBoxing;
39+
}
40+
41+
return null;
42+
}
43+
}

src/Nncase.Schedule/Transforms/AutoTilePass.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Nncase.IR;
1212
using Nncase.IR.Affine;
1313
using Nncase.Passes.GraphPartition;
14+
using Nncase.Passes.Rules;
1415
using Nncase.Schedule;
1516
using QuikGraph;
1617
using QuikGraph.Algorithms;
@@ -116,6 +117,7 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassCo
116117

117118
var constructor = new AutoTileReconstructor(tiler, ModuleKind, CompileOptions, condenseAlgo, dimVars.ToArray());
118119
var post = constructor.Construct();
120+
post = CompilerServices.Rewrite(post, [new UpdateBoxingTensorType()], new());
119121
return Task.FromResult((BaseFunction)func.With(body: post));
120122
}
121123
}

0 commit comments

Comments
 (0)