
Commit 8fc810d

ekayaaslan authored and copybara-github committed
Push shardy outliner up past InsertExplicitReshards pass.
InsertExplicitReshards is for resolving sharding conflicts by inserting reshards.

PiperOrigin-RevId: 895891342
1 parent c00b51a commit 8fc810d
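For orientation before the diffs: `InsertExplicitReshards` materializes an explicit `sdy.reshard` op wherever a value's sharding disagrees with the sharding its consumer expects. A minimal sketch with hypothetical shardings, in the same test syntax the .mlir files below use:

    sdy.mesh @mesh = <["x"=2, "y"=2]>

    // %arg0 is sharded on "x" but its consumer expects "y"; the pass
    // resolves the conflict by inserting an explicit reshard:
    %0 = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>

This commit moves the named-computation outliner ahead of this pass, so the pass now also sees plain `func.call` ops (handled in insert_explicit_reshards.cc below).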

File tree

14 files changed: +182, -112 lines changed


shardy/dialect/sdy/ir/utils.cc

Lines changed: 13 additions & 0 deletions
@@ -1108,5 +1108,18 @@ FuncOp cloneFuncRecursively(FuncOp funcOp, SymbolTable& symbolTable) {
   return clonedFuncOp;
 }
 
+TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
+                                              Attribute meshOrRef) {
+  SmallVector<TensorShardingAttr> resultShardings;
+  resultShardings.reserve(values.size());
+  for (mlir::Value value : values) {
+    resultShardings.push_back(TensorShardingAttr::getFullyReplicated(
+        meshOrRef.getContext(), mlir::sdy::getTensorRank(value), meshOrRef,
+        /*isClosed=*/true));
+  }
+  return TensorShardingPerValueAttr::get(meshOrRef.getContext(),
+                                         resultShardings);
+}
+
 } // namespace sdy
 } // namespace mlir
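For orientation: the helper builds one fully replicated, fully closed `TensorShardingAttr` per entry of `values`, using each tensor's rank. For a rank-2 value the resulting sharding prints roughly as (assuming the attribute syntax used in the .mlir tests below, where a closed dimension is an empty axis list):

    #sdy.sharding<@mesh, [{}, {}]>

No axis is assigned to either dimension, and because the dimensions are closed, propagation may not shard them further.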

shardy/dialect/sdy/ir/utils.h

Lines changed: 6 additions & 0 deletions
@@ -678,6 +678,12 @@ Operation* getCommonSupportedReductionOp(stablehlo::ScatterOp scatter);
 mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
                                         SymbolTable& symbolTable);
 
+// Returns a `TensorShardingPerValueAttr` on the shardings of the `values`. If
+// the sharding of a value is null, it creates a fully closed sharding for it on
+// the given `meshOrRef` and the rank of the tensor corresponding to the value.
+TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
+                                              Attribute meshOrRef);
+
 } // namespace sdy
 } // namespace mlir

shardy/dialect/sdy/transforms/common/propagation_options.h

Lines changed: 2 additions & 1 deletion
@@ -48,7 +48,8 @@ struct PropagationOptions {
   // auto-partitioner will be invoked after propagation of user-specified
   // shardings.
   bool enableAutoPartitioning = false;
-  // Whether to avoid explicit reshards/collectives on named computations.
+  // Whether to avoid explicit reshards/collectives on named computations/calls.
+  // TODO(enver): Rename to avoidReshardsOnCalls.
   bool avoidReshardsOnNamedComputations = false;
   // Whether to update axes with non-divisible input/output shardings.
   bool updateNonDivisibleInputOutputShardings = true;

shardy/dialect/sdy/transforms/export/export_pipeline.cc

Lines changed: 4 additions & 5 deletions
@@ -40,10 +40,9 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
                           const ExportOptions& options) {
   InsertExplicitReshardsPassOptions passOptions;
   passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
-  passOptions.avoidReshardsOnNamedComputations =
-      options.avoidReshardsOnNamedComputations;
+  passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
+  pm.addPass(createExportNamedComputationsPass());
   pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
-
   if (options.enableInsertExplicitCollectives) {
     pm.addPass(mlir::sdy::createSaveModuleOpPass(
         options.dumpDirectory, "after_explicit_reshards", dumpIndex++));

@@ -98,12 +97,12 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
   // reshards/collectives.
   if (!options.avoidExportForPartitioning) {
     runShardyPartitioner(pm, dumpIndex, options);
+  } else {
+    pm.addPass(createExportNamedComputationsPass());
   }
-
   if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
     pm.addPass(createRemovePropagationDebugInfoPass());
   }
-  pm.addPass(createExportNamedComputationsPass());
   if (!options.keepShardingRules) {
     pm.addNestedPass<func::FuncOp>(createDropShardingRulesPass());
   }
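Reading the two hunks together (an inference from this diff rather than a statement in it): the outliner pass `createExportNamedComputationsPass` used to run near the end of the export pipeline and now runs before `createInsertExplicitReshardsPass`, in both branches of `avoidExportForPartitioning`. Named computations are therefore already outlined into functions and `func.call` ops when reshards are inserted, which is what the `CallOp` handling added below serves. As a RUN-line sketch (the reshard flag appears verbatim in the tests below; the export pass's flag name is assumed to follow the same `sdy-` convention):

    // RUN: sdy_opt %s -sdy-export-named-computations -sdy-insert-explicit-reshards='enable-full-version=true' | FileCheck %s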

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 40 additions & 2 deletions
@@ -51,6 +51,9 @@ namespace sdy {
 
 namespace {
 
+using func::CallOp;
+using func::FuncOp;
+
 void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
                                             TensorShardingAttr targetSharding,
                                             IRRewriter& rewriter,

@@ -132,6 +135,33 @@ void insertExplicitReshardsOnDataFlowOp(
   }
 }
 
+void insertExplicitReshardsOnCallOp(CallOp callOp, IRRewriter& rewriter,
+                                    const SymbolTable& symbolTable,
+                                    const bool onFullVersion) {
+  FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
+  TensorShardingPerValueAttr funcArgShardings =
+      mlir::sdy::getFuncArgShardings(funcOp, symbolTable);
+  if (!funcArgShardings) {
+    mlir::Attribute meshOrRef = getMeshOrRef(
+        callOp.getNumOperands(), symbolTable,
+        [&](int64_t i) { return getSharding(callOp.getOperand(i)); });
+    // Return without inserting reshards as neither func arguments nor call
+    // operands have a sharding with non-maximal mesh.
+    if (!meshOrRef) {
+      return;
+    }
+    funcArgShardings = getFullyClosedLike(callOp.getOperands(), meshOrRef);
+  }
+  rewriter.setInsertionPoint(callOp);
+  for (auto [funcArgSharding, sourceOpOperand] : llvm::zip_equal(
+           funcArgShardings.getShardings(), callOp->getOpOperands())) {
+    insertExplicitReshardsToTargetSharding(
+        sourceOpOperand,
+        /*targetSharding=*/funcArgSharding, rewriter, symbolTable,
+        /*insertAfterOperand=*/true, onFullVersion);
+  }
+}
+
 // Reshard the result of a dot operation if all the following hold:
 //
 // 1. LHS and RHS have fully compatible shardings.

@@ -382,7 +412,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
   }
   // To avoid copies of the same functions with mismatching shardings on the
   // arguments onto multiple callsites.
-  if (isa<NamedComputationOp>(op)) {
+  if (isa<NamedComputationOp, func::CallOp>(op)) {
     return true;
   }
 
@@ -473,7 +503,15 @@ struct InsertExplicitReshardsPass
       // sharded in the same way.
       insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
                                          symbolTable, onFullVersion,
-                                         avoidReshardsOnNamedComputations);
+                                         avoidReshardsOnCalls);
+      return;
+    }
+
+    if (CallOp callOp = dyn_cast<CallOp>(op)) {
+      if (!avoidReshardsOnCalls) {
+        insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
+                                       onFullVersion);
+      }
       return;
     }
 
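A sketch of what `insertExplicitReshardsOnCallOp` above enables, with a hypothetical callee and shardings written in the dialect's test syntax: when a callee argument's sharding disagrees with the sharding of the corresponding call operand, the operand is resharded just before the call.

    sdy.mesh @mesh = <["x"=2, "y"=2]>

    func.func @callee(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) -> tensor<210xf32> {
      return %arg0 : tensor<210xf32>
    }

    func.func @main(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> tensor<210xf32> {
      // Expected after the pass: %0 = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>,
      // with the call consuming %0 instead of %arg0.
      %1 = func.call @callee(%arg0) : (tensor<210xf32>) -> tensor<210xf32>
      return %1 : tensor<210xf32>
    }

If the callee declares no argument shardings at all, the code above falls back to `getFullyClosedLike`, i.e. it reshards every operand to fully replicated.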

shardy/dialect/sdy/transforms/export/passes.h

Lines changed: 4 additions & 4 deletions
@@ -76,10 +76,10 @@ struct ExportOptions : public PassPipelineOptions<ExportOptions> {
       llvm::cl::desc("Sink sdy.propagation_edges attr."),
       llvm::cl::init(false)};
 
-  Option<bool> avoidReshardsOnNamedComputations{
-      *this, "avoid-reshards-on-named-computations",
-      llvm::cl::desc("Avoid inserting explicit reshards/collectives for named "
-                     "computations."),
+  Option<bool> avoidReshardsOnCalls{
+      *this, "avoid-reshards-on-calls",
+      llvm::cl::desc(
+          "Avoid inserting explicit reshards/collectives for calls."),
       llvm::cl::init(false)};
 
   Option<bool> updateNonDivisibleInputOutputShardings{

shardy/dialect/sdy/transforms/export/passes.td

Lines changed: 3 additions & 3 deletions
@@ -128,10 +128,10 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
     Option<"enableFullVersion", "enable-full-version",
            "bool", /*default=*/"false",
            "Enable full version.">,
-    Option<"avoidReshardsOnNamedComputations",
-           "avoid-reshards-on-named-computations",
+    Option<"avoidReshardsOnCalls",
+           "avoid-reshards-on-calls",
            "bool", /*default=*/"false",
-           "Avoid explicit reshards/collectives on named computations.">
+           "Avoid explicit reshards/collectives on calls.">
   ];
 }
 

shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards_avoid_reshards_on_named_computations.mlir renamed to shardy/dialect/sdy/transforms/export/test/call_ops_avoid_reshards_on_calls_true.mlir

Lines changed: 1 addition & 11 deletions
@@ -1,12 +1,7 @@
-// RUN: sdy_opt %s -sdy-insert-explicit-reshards='avoid-reshards-on-named-computations=true' -sdy-insert-explicit-reshards='avoid-reshards-on-named-computations=true' | FileCheck %s
+// RUN: sdy_opt %s -sdy-insert-explicit-reshards='enable-full-version=true avoid-reshards-on-calls=true' | FileCheck %s
 
 sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>
 
-//===----------------------------------------------------------------------===//
-// Named computations tests
-// More tests are in insert_explicit_reshards/data_flow_ops.mlir
-//===----------------------------------------------------------------------===//
-
 // CHECK-LABEL: func @named_computation
 func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
   // CHECK-NEXT: sdy.named_computation<"foo">(%arg0)

@@ -35,8 +30,6 @@ func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {s
     %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
     sdy.return %2 : tensor<210xf32>
   } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>}
-  // CHECK-NEXT: return %[[ADD]]
   %3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
   return %3 : tensor<210xf32>
 }

@@ -49,9 +42,6 @@ func.func @different_arguments_to_multiple_named_computations_with_same_input_ou
     %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
     sdy.return %3 : tensor<210xf32>
   } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
-  // CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
-  // CHECK-SAME: in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>
   %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
   %2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
     %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>

shardy/dialect/sdy/transforms/export/test/export_pipeline.mlir

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 sdy.mesh @mesh3d = <["a"=4, "b"=4, "c"=4]>
 
+
 // CHECK-LABEL: func @manual_computation_free_axes_non_divisible
 func.func @manual_computation_free_axes_non_divisible(
     %arg0: tensor<4xf32>, %arg1: tensor<12xf32>, %arg2: tensor<24xf32>,

shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir

Lines changed: 0 additions & 58 deletions
@@ -472,64 +472,6 @@ func.func @concatenate_same_shardings_func_result_different_sharding(%arg0: tens
   return %0 : tensor<4x80x256xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// Named computations tests
-// More tests are in insert_explicit_reshards/data_flow_ops.mlir
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @named_computation
-func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
-  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
-  // CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"z"}]> : tensor<210xf32>
-    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
-  return %1 : tensor<210xf32>
-}
-
-// CHECK-LABEL: func @one_argument_to_multiple_named_computations
-func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
-  // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{"z"}]> : tensor<210xf32>
-  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
-  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD1]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[RESHARD0]])
-  %1 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"z"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]]
-  // CHECK-NEXT: return %[[ADD]]
-  %3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
-  return %3 : tensor<210xf32>
-}
-
-// CHECK-LABEL: func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings
-func.func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
-  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
-  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    sdy.return %3 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
-  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
-  %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-  %2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    sdy.return %3 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  %4 = stablehlo.add %0, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-  return %4 : tensor<210xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // Replicated and sharded to unreduced tests
 //===----------------------------------------------------------------------===//
