Skip to content

Commit 063f780

Browse files
ekayaaslancopybara-github
authored andcommitted
Push shardy outliner up RemovePropagationDebugInfo pass.
RemovePropagationDebugInfo removes debug attributes from op results and func arg/results. It already supports funcOps. CallOps are supported generically as part of non-func operations. PiperOrigin-RevId: 895864642
1 parent c00b51a commit 063f780

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,10 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
100100
runShardyPartitioner(pm, dumpIndex, options);
101101
}
102102

103+
pm.addPass(createExportNamedComputationsPass());
103104
if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
104105
pm.addPass(createRemovePropagationDebugInfoPass());
105106
}
106-
pm.addPass(createExportNamedComputationsPass());
107107
if (!options.keepShardingRules) {
108108
pm.addNestedPass<func::FuncOp>(createDropShardingRulesPass());
109109
}

shardy/dialect/sdy/transforms/export/test/remove_propagation_debug_info.mlir

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: sdy_opt %s -sdy-remove-propagation-debug-info | FileCheck %s
1+
// RUN: sdy_opt %s -split-input-file -sdy-remove-propagation-debug-info | FileCheck %s
22

33
sdy.mesh @mesh_a = <["a"=2]>
44
sdy.mesh @mesh_c = <["c"=8]>
@@ -78,3 +78,35 @@ func.func @remove_origin_shardings_manual_computation_with_sharding_constraints(
7878
} {sdy.block_arg_sharding_origins = [{a = "self", b = "self"}, {a = "self", b = "self"}], sdy.result_sharding_origins = [{a = "self", b = "self"}], sdy.sharding_origin_name = "mc_1"} : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32>
7979
return %0, %2 : tensor<8x8xf32>, tensor<8x8xf32>
8080
}
81+
82+
// -----
83+
sdy.mesh @mesh_abc = <["a"=2, "b"=2, "c"=8]>
84+
// CHECK-LABEL: func @single_call(
85+
// CHECK-SAME: %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>},
86+
// CHECK-SAME: %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) -> (
87+
// CHECK-SAME: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) {
88+
func.func @single_call(
89+
%arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>, sdy.sharding_origins = {a = "mc_1_input: 0", b = "mc_1_input: 0"}},
90+
%arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) -> (
91+
tensor<8x8x8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"b" = result-0 -> [operand-0]}]}, {step-2 = [{"a" = operand-0 -> [result-0]}, {"c" = operand-0 -> [result-0]}]}]>, sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) {
92+
// CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>}
93+
// CHECK-NEXT: %1 = call @foo(%0, %0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>}
94+
%0 = stablehlo.add %arg0, %arg1 {sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = operand-0 -> [operand-1, result-0]}, {"b" = result-0 -> [operand-0, operand-1]}, {"c" = operand-1 -> [operand-0, result-0]}]}]>, sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} : tensor<8x8x8xf32>
95+
%1 = call @foo(%0,%0) {sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = operand-0 -> [operand-1, result-0]}, {"b" = result-0 -> [operand-0, operand-1]}, {"c" = operand-1 -> [operand-0, result-0]}]}]>, sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} : (tensor<8x8x8xf32>, tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
96+
return %1 : tensor<8x8x8xf32>
97+
}
98+
99+
// CHECK-LABEL: func private @foo(
100+
// CHECK-SAME: %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}
101+
// CHECK-SAME: %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) -> (
102+
// CHECK-SAME: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}
103+
func.func private @foo(
104+
%arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>},
105+
%arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>, sdy.sharding_origins = {a = "mc_1_input: 0", b = "mc_1_input: 0"}}) -> (
106+
tensor<8x8x8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"b" = result-0 -> [operand-0]}]}, {step-2 = [{"a" = operand-0 -> [result-0]}, {"c" = operand-0 -> [result-0]}]}]>, sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) {
107+
// CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>}
108+
%0 = stablehlo.multiply %arg0, %arg1 {sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = operand-0 -> [operand-1, result-0]}, {"b" = result-0 -> [operand-0, operand-1]}, {"c" = operand-1 -> [operand-0, result-0]}]}]>, sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} : tensor<8x8x8xf32>
109+
return %0 : tensor<8x8x8xf32>
110+
}
111+
112+

0 commit comments

Comments
 (0)