diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc
index d470b30aa..efd482d26 100644
--- a/shardy/dialect/sdy/ir/utils.cc
+++ b/shardy/dialect/sdy/ir/utils.cc
@@ -1108,5 +1108,18 @@ FuncOp cloneFuncRecursively(FuncOp funcOp, SymbolTable& symbolTable) {
   return clonedFuncOp;
 }
 
+TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
+                                              Attribute meshOrRef) {
+  SmallVector<TensorShardingAttr> resultShardings;
+  resultShardings.reserve(values.size());
+  for (mlir::Value value : values) {
+    resultShardings.push_back(TensorShardingAttr::getFullyReplicated(
+        meshOrRef.getContext(), mlir::sdy::getTensorRank(value), meshOrRef,
+        /*isClosed=*/true));
+  }
+  return TensorShardingPerValueAttr::get(meshOrRef.getContext(),
+                                         resultShardings);
+}
+
 }  // namespace sdy
 }  // namespace mlir
diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h
index 13c646f96..77363695a 100644
--- a/shardy/dialect/sdy/ir/utils.h
+++ b/shardy/dialect/sdy/ir/utils.h
@@ -678,6 +678,12 @@ Operation* getCommonSupportedReductionOp(stablehlo::ScatterOp scatter);
 mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
                                         SymbolTable& symbolTable);
 
+// Returns a `TensorShardingPerValueAttr` with a fully replicated, fully
+// closed sharding for each of the given `values`, built on the given
+// `meshOrRef` and the rank of the tensor corresponding to each value.
+TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
+                                              Attribute meshOrRef);
+
 }  // namespace sdy
 }  // namespace mlir
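A minimal usage sketch of the new helper, mirroring the fallback added to insertExplicitReshardsOnCallOp further down in this patch; `values`, `meshOrRef`, and `op` are assumed to be in scope, and `kShardingAttr` is assumed to be the sdy sharding attribute name:

    // Build a fully replicated, fully closed per-value sharding on the common
    // mesh; useful when no explicit shardings are available.
    TensorShardingPerValueAttr shardings = getFullyClosedLike(values, meshOrRef);
    op->setAttr(kShardingAttr, shardings);
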
diff --git a/shardy/dialect/sdy/transforms/common/propagation_options.h b/shardy/dialect/sdy/transforms/common/propagation_options.h
index 6875b1bd2..aff8b749d 100644
--- a/shardy/dialect/sdy/transforms/common/propagation_options.h
+++ b/shardy/dialect/sdy/transforms/common/propagation_options.h
@@ -48,7 +48,8 @@ struct PropagationOptions {
   // auto-partitioner will be invoked after propagation of user-specified
   // shardings.
   bool enableAutoPartitioning = false;
-  // Whether to avoid explicit reshards/collectives on named computations.
+  // Whether to avoid explicit reshards/collectives on named computations/calls.
+  // TODO(enver): Rename to avoidReshardsOnCalls.
   bool avoidReshardsOnNamedComputations = false;
   // Whether to update axes with non-divisible input/output shardings.
   bool updateNonDivisibleInputOutputShardings = true;
diff --git a/shardy/dialect/sdy/transforms/export/export_pipeline.cc b/shardy/dialect/sdy/transforms/export/export_pipeline.cc
index 2901c8d06..1c74e72b8 100644
--- a/shardy/dialect/sdy/transforms/export/export_pipeline.cc
+++ b/shardy/dialect/sdy/transforms/export/export_pipeline.cc
@@ -40,10 +40,9 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
                           const ExportOptions& options) {
   InsertExplicitReshardsPassOptions passOptions;
   passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
-  passOptions.avoidReshardsOnNamedComputations =
-      options.avoidReshardsOnNamedComputations;
+  passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
+  pm.addPass(createExportNamedComputationsPass());
   pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
-
   if (options.enableInsertExplicitCollectives) {
     pm.addPass(mlir::sdy::createSaveModuleOpPass(
         options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
@@ -98,12 +97,12 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
   // reshards/collectives.
   if (!options.avoidExportForPartitioning) {
     runShardyPartitioner(pm, dumpIndex, options);
+  } else {
+    pm.addPass(createExportNamedComputationsPass());
   }
-
   if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
     pm.addPass(createRemovePropagationDebugInfoPass());
   }
-  pm.addPass(createExportNamedComputationsPass());
   if (!options.keepShardingRules) {
     pm.addNestedPass<func::FuncOp>(createDropShardingRulesPass());
   }
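The pipeline change above moves ExportNamedComputationsPass in front of the reshard pass, so sdy.named_computation ops are lowered to func.call ops before explicit reshards are inserted. A standalone sketch of the resulting ordering, assuming an OpPassManager `pm` is in scope:

    InsertExplicitReshardsPassOptions passOptions;
    passOptions.enableFullVersion = true;
    passOptions.avoidReshardsOnCalls = false;
    // Lower sdy.named_computation ops to func.call ops first; the reshard
    // pass then only needs to handle calls.
    pm.addPass(createExportNamedComputationsPass());
    pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));
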
diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
index 41e887afa..1e03f0623 100644
--- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
+++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -51,6 +51,9 @@ namespace sdy {
 
 namespace {
 
+using func::CallOp;
+using func::FuncOp;
+
 void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
                                             TensorShardingAttr targetSharding,
                                             IRRewriter& rewriter,
@@ -132,6 +135,33 @@
   }
 }
 
+void insertExplicitReshardsOnCallOp(CallOp callOp, IRRewriter& rewriter,
+                                    const SymbolTable& symbolTable,
+                                    const bool onFullVersion) {
+  FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
+  TensorShardingPerValueAttr funcArgShardings =
+      mlir::sdy::getFuncArgShardings(funcOp, symbolTable);
+  if (!funcArgShardings) {
+    mlir::Attribute meshOrRef = getMeshOrRef(
+        callOp.getNumOperands(), symbolTable,
+        [&](int64_t i) { return getSharding(callOp.getOperand(i)); });
+    // Return without inserting reshards as neither func arguments nor call
+    // operands have a sharding with non-maximal mesh.
+    if (!meshOrRef) {
+      return;
+    }
+    funcArgShardings = getFullyClosedLike(callOp.getOperands(), meshOrRef);
+  }
+  rewriter.setInsertionPoint(callOp);
+  for (auto [funcArgSharding, sourceOpOperand] : llvm::zip_equal(
+           funcArgShardings.getShardings(), callOp->getOpOperands())) {
+    insertExplicitReshardsToTargetSharding(
+        sourceOpOperand,
+        /*targetSharding=*/funcArgSharding, rewriter, symbolTable,
+        /*insertAfterOperand=*/true, onFullVersion);
+  }
+}
+
 // Reshard the result of a dot operation if all the following hold:
 //
 // 1. LHS and RHS have fully compatible shardings.
@@ -382,7 +412,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
   }
   // To avoid copies of the same functions with mismatching shardings on the
   // arguments onto multiple callsites.
-  if (isa<NamedComputationOp>(op)) {
+  if (isa<CallOp>(op)) {
     return true;
   }
 
@@ -473,7 +503,15 @@ struct InsertExplicitReshardsPass
       // sharded in the same way.
       insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
                                          symbolTable, onFullVersion,
-                                         avoidReshardsOnNamedComputations);
+                                         avoidReshardsOnCalls);
+      return;
+    }
+
+    if (CallOp callOp = dyn_cast<CallOp>(op)) {
+      if (!avoidReshardsOnCalls) {
+        insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
+                                       onFullVersion);
+      }
       return;
     }
 
diff --git a/shardy/dialect/sdy/transforms/export/passes.h b/shardy/dialect/sdy/transforms/export/passes.h
index 0a8f5823a..bf60286af 100644
--- a/shardy/dialect/sdy/transforms/export/passes.h
+++ b/shardy/dialect/sdy/transforms/export/passes.h
@@ -76,10 +76,10 @@ struct ExportOptions : public PassPipelineOptions<ExportOptions> {
       llvm::cl::desc("Sink sdy.propagation_edges attr."),
       llvm::cl::init(false)};
 
-  Option<bool> avoidReshardsOnNamedComputations{
-      *this, "avoid-reshards-on-named-computations",
-      llvm::cl::desc("Avoid inserting explicit reshards/collectives for named "
-                     "computations."),
+  Option<bool> avoidReshardsOnCalls{
+      *this, "avoid-reshards-on-calls",
+      llvm::cl::desc(
+          "Avoid inserting explicit reshards/collectives for calls."),
       llvm::cl::init(false)};
 
   Option<bool> updateNonDivisibleInputOutputShardings{
diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td
index 0374fdf1c..082c7bffe 100644
--- a/shardy/dialect/sdy/transforms/export/passes.td
+++ b/shardy/dialect/sdy/transforms/export/passes.td
@@ -128,10 +128,10 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> {
     Option<"enableFullVersion", "enable-full-version",
            "bool", /*default=*/"false",
            "Enable full version.">,
-    Option<"avoidReshardsOnNamedComputations",
-           "avoid-reshards-on-named-computations",
+    Option<"avoidReshardsOnCalls",
+           "avoid-reshards-on-calls",
            "bool", /*default=*/"false",
-           "Avoid explicit reshards/collectives on named computations.">
+           "Avoid explicit reshards/collectives on calls.">
   ];
 }
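The renamed pass option is driven from the command line in the same way as the updated RUN lines below, for example (hypothetical input file name):

    sdy_opt input.mlir -sdy-insert-explicit-reshards='enable-full-version=true avoid-reshards-on-calls=true'
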
diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards_avoid_reshards_on_named_computations.mlir b/shardy/dialect/sdy/transforms/export/test/call_ops_avoid_reshards_on_calls_true.mlir
similarity index 81%
rename from shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards_avoid_reshards_on_named_computations.mlir
rename to shardy/dialect/sdy/transforms/export/test/call_ops_avoid_reshards_on_calls_true.mlir
index 15acf582c..7f7a2b142 100644
--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards_avoid_reshards_on_named_computations.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/call_ops_avoid_reshards_on_calls_true.mlir
@@ -1,12 +1,7 @@
-// RUN: sdy_opt %s -sdy-insert-explicit-reshards='avoid-reshards-on-named-computations=true' -sdy-insert-explicit-reshards='avoid-reshards-on-named-computations=true' | FileCheck %s
+// RUN: sdy_opt %s -sdy-insert-explicit-reshards='enable-full-version=true avoid-reshards-on-calls=true' | FileCheck %s
 
 sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>
 
-//===----------------------------------------------------------------------===//
-// Named computations tests
-// More tests are in insert_explicit_reshards/data_flow_ops.mlir
-//===----------------------------------------------------------------------===//
-
 // CHECK-LABEL: func @named_computation
 func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
   // CHECK-NEXT: sdy.named_computation<"foo">(%arg0)
@@ -35,8 +30,6 @@ func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
     %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
     sdy.return %2 : tensor<210xf32>
   } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>}
-  // CHECK-NEXT: return %[[ADD]]
   %3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
   return %3 : tensor<210xf32>
 }
@@ -49,9 +42,6 @@ func.func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
     %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
     sdy.return %3 : tensor<210xf32>
   } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
-  // CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
-  // CHECK-SAME: in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>]
   %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
   %2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
     %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
diff --git a/shardy/dialect/sdy/transforms/export/test/export_pipeline.mlir b/shardy/dialect/sdy/transforms/export/test/export_pipeline.mlir
index 03041c64c..a077d323e 100644
--- a/shardy/dialect/sdy/transforms/export/test/export_pipeline.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/export_pipeline.mlir
@@ -6,6 +6,7 @@
 sdy.mesh @mesh3d = <["a"=4, "b"=4, "c"=4]>
 
+
 // CHECK-LABEL: func @manual_computation_free_axes_non_divisible
 func.func @manual_computation_free_axes_non_divisible(
     %arg0: tensor<4xf32>, %arg1: tensor<12xf32>, %arg2: tensor<24xf32>,
diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
index b0f9aafc9..c88169c06 100644
--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir
@@ -472,64 +472,6 @@ func.func @concatenate_same_shardings_func_result_different_sharding(%arg0: tens
   return %0 : tensor<4x80x256xf32>
 }
 
-//===----------------------------------------------------------------------===//
-// Named computations tests
-// More tests are in insert_explicit_reshards/data_flow_ops.mlir
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @named_computation
-func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
-  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
-  // CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"z"}]> : tensor<210xf32>
-    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
-  return %1 : tensor<210xf32>
-}
-
-// CHECK-LABEL: func @one_argument_to_multiple_named_computations
-func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
-  // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{"z"}]> : tensor<210xf32>
-  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
-  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD1]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[RESHARD0]])
-  %1 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"z"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]]
-  // CHECK-NEXT: return %[[ADD]]
-  %3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
-  return %3 : tensor<210xf32>
-}
-
-// CHECK-LABEL: func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings
-func.func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
-  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
-  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    sdy.return %3 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
-  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
-  %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-  %2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-    sdy.return %3 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  %4 = stablehlo.add %0, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-  return %4 : tensor<210xf32>
-}
-
 //===----------------------------------------------------------------------===//
 // Replicated and sharded to unreduced tests
 //===----------------------------------------------------------------------===//
diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/call_ops.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/call_ops.mlir
new file mode 100644
index 000000000..31a08e6ee
--- /dev/null
+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/call_ops.mlir
@@ -0,0 +1,48 @@
+// RUN: sdy_opt %s -split-input-file -sdy-insert-explicit-reshards='enable-full-version=true' | FileCheck %s
+
+sdy.mesh @mesh = <["x"=4, "y"=2]>
+
+// CHECK-LABEL: func @named_computation
+func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)2}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}]> : tensor<210xf32>
+  // CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
+  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"x"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
+    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<210xf32>
+    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"y"}]> : tensor<210xf32>
+    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
+    sdy.return %2 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+  return %1 : tensor<210xf32>
+}
+
+// -----
+sdy.mesh @mesh = <["x"=4, "y"=2]>
+
+// CHECK-LABEL: func @named_computation_empty_block
+func.func @named_computation_empty_block(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
+  // CHECK: sdy.named_computation<"foo">(%arg0)
+  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"x"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
+    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"y"}]> : tensor<210xf32>
+    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
+    sdy.return %arg1 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+  return %1 : tensor<210xf32>
+}
+
+// -----
+sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>
+
+// CHECK-LABEL: func @named_computation_with_shardings
+func.func @named_computation_with_shardings(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> tensor<12x2xi32> {
+  %0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
+  %1 = stablehlo.concatenate %0#0, %0#1, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> tensor<12x2xi32>
+  return %1 : tensor<12x2xi32>
+}
+
+func.func private @foo(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
+    -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) {
+  %0 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : tensor<8x2xi32>
+  return %0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
+}
diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/call_ops_enable_full_version_false.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/call_ops_enable_full_version_false.mlir
new file mode 100644
index 000000000..9667c37aa
--- /dev/null
+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/call_ops_enable_full_version_false.mlir
@@ -0,0 +1,59 @@
+// RUN: sdy_opt %s -split-input-file -sdy-insert-explicit-reshards='enable-full-version=false' | FileCheck %s
+
+sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>
+
+// CHECK-LABEL: func @named_computation
+func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
+  // CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
+  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
+    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"z"}]> : tensor<210xf32>
+    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
+    sdy.return %2 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
+  return %1 : tensor<210xf32>
+}
+
+// -----
+sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>
+
+// CHECK-LABEL: func @one_argument_to_multiple_named_computations
+func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
+  // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{"z"}]> : tensor<210xf32>
+  // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
+  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD1]])
+  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
+    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+    sdy.return %2 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  // CHECK: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[RESHARD0]])
+  %1 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"z"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
+    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
+    sdy.return %2 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  // CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]]
+  // CHECK-NEXT: return %[[ADD]]
+  %3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
+  return %3 : tensor<210xf32>
+}
+
+// CHECK-LABEL: func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings
+func.func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
+  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
+  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD]])
+  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
+    %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+    sdy.return %3 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
+  // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
+  %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+  %2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
+    %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+    sdy.return %3 : tensor<210xf32>
+  } : (tensor<210xf32>) -> (tensor<210xf32>)
+  %4 = stablehlo.add %0, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
+  return %4 : tensor<210xf32>
+}
diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/data_flow_ops.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/data_flow_ops.mlir
index d265b5f7a..f2891d04a 100644
--- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/data_flow_ops.mlir
+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/data_flow_ops.mlir
@@ -2,32 +2,6 @@
 
 sdy.mesh @mesh = <["x"=4, "y"=2]>
 
-// CHECK-LABEL: func @named_computation
-func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)2}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
-  // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}]> : tensor<210xf32>
-  // CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"x"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<210xf32>
-    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"y"}]> : tensor<210xf32>
-    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
-    sdy.return %2 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-  return %1 : tensor<210xf32>
-}
-
-// CHECK-LABEL: func @named_computation_empty_block
-func.func @named_computation_empty_block(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
-  // CHECK: sdy.named_computation<"foo">(%arg0)
-  %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"x"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
-    // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"y"}]> : tensor<210xf32>
-    // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
-    sdy.return %arg1 : tensor<210xf32>
-  } : (tensor<210xf32>) -> (tensor<210xf32>)
-  %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
-  return %1 : tensor<210xf32>
-}
-
 // CHECK-LABEL: func @case
 func.func @case(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)2}]>}, %arg1: tensor<i32>) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
   %0 = "stablehlo.case"(%arg1) ({
diff --git a/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc b/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc
index 265074334..215078cdf 100644
--- a/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc
+++ b/shardy/dialect/sdy/transforms/propagation/propagation_pipeline.cc
@@ -40,8 +40,7 @@ void populateExportOptions(ExportOptions& options,
       propOptions.removeAllGatherReduceScatterForCMV1;
   options.dumpShardingOrigins = propOptions.debugShardingOrigins;
   options.dumpPropagationEdges = propOptions.debugPropagationEdgeSharding;
-  options.avoidReshardsOnNamedComputations =
-      propOptions.avoidReshardsOnNamedComputations;
+  options.avoidReshardsOnCalls = propOptions.avoidReshardsOnNamedComputations;
   options.updateNonDivisibleInputOutputShardings =
       propOptions.updateNonDivisibleInputOutputShardings;
   options.enableNativeNonFlatSupport = propOptions.enableNativeNonFlatSupport;
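Note the naming asymmetry this last hunk keeps: the propagation-level option retains its legacy name (see the TODO in propagation_options.h) and is forwarded to the renamed export option. A sketch of that plumbing; populateExportOptions is internal to this file, so the call below is illustrative only:

    PropagationOptions propOptions;
    propOptions.avoidReshardsOnNamedComputations = true;  // legacy name, see TODO
    ExportOptions exportOptions;
    populateExportOptions(exportOptions, propOptions);
    // exportOptions.avoidReshardsOnCalls is now true.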