@@ -84,6 +84,57 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassCo
8484 }
8585}
8686
87+ internal static class UserRebuilder
88+ {
89+ public static void Rebuild ( BaseExpr root )
90+ {
91+ var order = new List < BaseExpr > ( 256 ) ;
92+ var seen = new HashSet < BaseExpr > ( ReferenceEqualityComparer . Instance ) ;
93+ DfsIter ( root , order , seen ) ;
94+
95+ foreach ( var n in order )
96+ {
97+ var users = n . Users . ToArray ( ) ;
98+ for ( int i = 0 ; i < users . Length ; i ++ )
99+ {
100+ n . RemoveUser ( users [ i ] ) ;
101+ }
102+ }
103+
104+ foreach ( var n in order )
105+ {
106+ var ops = n . Operands ;
107+ for ( int i = 0 ; i < ops . Length ; i ++ )
108+ {
109+ ops [ i ] . AddUser ( n ) ;
110+ }
111+ }
112+ }
113+
114+ private static void DfsIter ( BaseExpr root , List < BaseExpr > order , HashSet < BaseExpr > seen )
115+ {
116+ var stack = new Stack < BaseExpr > ( ) ;
117+ stack . Push ( root ) ;
118+
119+ while ( stack . Count > 0 )
120+ {
121+ var n = stack . Pop ( ) ;
122+ if ( ! seen . Add ( n ) )
123+ {
124+ continue ;
125+ }
126+
127+ order . Add ( n ) ;
128+
129+ var ops = n . Operands ;
130+ for ( int i = ops . Length - 1 ; i >= 0 ; i -- )
131+ {
132+ stack . Push ( ops [ i ] ) ;
133+ }
134+ }
135+ }
136+ }
137+
87138internal sealed class SearchableNode
88139{
89140 public SearchableNode ( BaseExpr expr , IRType type , bool isBidirect = false )
@@ -310,19 +361,24 @@ bool Matched(SearchableNode node, (IRArray<SBP> Policies, Placement Placement) t
310361
311362 public Function Rewrite ( Function function )
312363 {
313- var body = function . Body ;
314- Visit ( body ) ;
315- var rootCluster = TryInstertTerminator ( body ) ;
316-
317- // if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost))
364+ BaseExpr post ;
365+ using ( Nncase . IR . UserTrackingScope . Suppress ( ) )
318366 {
319- using ( var stream = Diagnostics . DumpScope . Current . OpenFile ( "DistributedSearchGraph.dot" ) )
367+ Visit ( function . Body ) ;
368+ var root = TryInstertTerminator ( function . Body ) ;
369+ if ( Diagnostics . DumpScope . Current . IsEnabled ( Diagnostics . DumpFlags . EGraphCost ) )
320370 {
321- Dump ( stream , new Dictionary < SearchableNode , bool > ( ) { } , new Dictionary < SearchableNode , CostModel . Cost > ( ) { } ) ;
371+ using ( var stream = Diagnostics . DumpScope . Current . OpenFile ( "DistributedSearchGraph.dot" ) )
372+ {
373+ Dump ( stream , new Dictionary < SearchableNode , bool > ( ) { } , new Dictionary < SearchableNode , CostModel . Cost > ( ) { } ) ;
374+ }
322375 }
376+
377+ post = SolveAndExtract ( root ) ;
323378 }
324379
325- var post = SolveAndExtract ( rootCluster ) ;
380+ UserRebuilder . Rebuild ( post ) ;
381+
326382 return function . With ( body : post ) ;
327383 }
328384
@@ -546,6 +602,16 @@ string DescribeSbp(IRType? type)
546602 }
547603 }
548604
605+ if ( expr . Target is not Boxing && ( ( Call ) newExpr ) . Arguments . AsValueEnumerable ( ) . Any ( a => a . CheckedType is DistributedType dt && dt . Partial is not null ) )
606+ {
607+ continue ;
608+ }
609+
610+ if ( ! newExpr . InferenceType ( _inferencer_cache ) || newExpr . CheckedType is InvalidType )
611+ {
612+ continue ;
613+ }
614+
549615 if ( ! expr . Target . GetType ( ) . FullName ! . Contains ( "CustomNTT" , StringComparison . Ordinal )
550616 && TargetOptions . HierarchyKind == HierarchyKind . SMT
551617 && expr . Users . Any ( u => u is Call call && ( call . Target . GetType ( ) . FullName ! . Contains ( "CustomNTT.MatMul" , StringComparison . Ordinal ) || call . Target is PagedAttention ) ) )
@@ -557,16 +623,6 @@ string DescribeSbp(IRType? type)
557623 }
558624 }
559625
560- if ( expr . Target is not Boxing && ( ( Call ) newExpr ) . Arguments . AsValueEnumerable ( ) . Any ( a => a . CheckedType is DistributedType dt && dt . Partial is not null ) )
561- {
562- continue ;
563- }
564-
565- if ( ! newExpr . InferenceType ( _inferencer_cache ) || newExpr . CheckedType is InvalidType )
566- {
567- continue ;
568- }
569-
570626 var checkType = newExpr . CheckedType ;
571627 if ( ! bucketMemo . TryGetValue ( checkType , out var dbucket ) )
572628 {
@@ -1311,12 +1367,9 @@ private BaseExpr SolveAndExtract(DistributedSearchGraph rootCluster)
13111367 }
13121368
13131369 var picks = _rootSearchGraph . Vertices . ToDictionary ( e => e , e => solver . BooleanValue ( varMemo [ e ] ) ) ;
1314- if ( enableDump )
1370+ using ( var stream = enableDump ? Diagnostics . DumpScope . Current . OpenFile ( "Costs/Pick.dot" ) : Stream . Null )
13151371 {
1316- using ( var stream = Diagnostics . DumpScope . Current . OpenFile ( "Costs/Pick.dot" ) )
1317- {
1318- Dump ( stream , picks , costMemo ) ;
1319- }
1372+ Dump ( stream , picks , costMemo ) ;
13201373 }
13211374
13221375 if ( _phase == AutoDistributedPhase . SearchConstant )
0 commit comments