13 changes: 13 additions & 0 deletions shardy/dialect/sdy/ir/utils.cc
@@ -1108,5 +1108,18 @@ FuncOp cloneFuncRecursively(FuncOp funcOp, SymbolTable& symbolTable) {
return clonedFuncOp;
}

TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
Attribute meshOrRef) {
SmallVector<TensorShardingAttr> resultShardings;
resultShardings.reserve(values.size());
for (mlir::Value value : values) {
resultShardings.push_back(TensorShardingAttr::getFullyReplicated(
meshOrRef.getContext(), mlir::sdy::getTensorRank(value), meshOrRef,
/*isClosed=*/true));
}
return TensorShardingPerValueAttr::get(meshOrRef.getContext(),
resultShardings);
}

} // namespace sdy
} // namespace mlir
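For reference, a minimal usage sketch of the new helper (hypothetical: `callOp` and `meshOrRef` stand in for values taken from a surrounding pass, as in insert_explicit_reshards.cc below):

// Sketch: build one fully closed, fully replicated sharding per operand.
TensorShardingPerValueAttr shardings =
    getFullyClosedLike(callOp.getOperands(), meshOrRef);
// Each entry is closed and replicated over `meshOrRef`, with the rank of the
// corresponding operand; e.g. a rank-2 and a rank-1 operand on `@mesh` would
// yield [<@mesh, [{}, {}]>, <@mesh, [{}]>].
assert(shardings.getShardings().size() == callOp.getNumOperands());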
6 changes: 6 additions & 0 deletions shardy/dialect/sdy/ir/utils.h
@@ -678,6 +678,12 @@ Operation* getCommonSupportedReductionOp(stablehlo::ScatterOp scatter);
mlir::func::FuncOp cloneFuncRecursively(func::FuncOp funcOp,
SymbolTable& symbolTable);

// Returns a `TensorShardingPerValueAttr` holding a fully closed, fully
// replicated sharding for each of `values`, built on the given `meshOrRef`
// and the rank of the tensor corresponding to each value.
TensorShardingPerValueAttr getFullyClosedLike(mlir::ValueRange values,
Attribute meshOrRef);

} // namespace sdy
} // namespace mlir

3 changes: 2 additions & 1 deletion shardy/dialect/sdy/transforms/common/propagation_options.h
@@ -48,7 +48,8 @@ struct PropagationOptions {
// auto-partitioner will be invoked after propagation of user-specified
// shardings.
bool enableAutoPartitioning = false;
// Whether to avoid explicit reshards/collectives on named computations.
// Whether to avoid explicit reshards/collectives on named computations/calls.
// TODO(enver): Rename to avoidReshardsOnCalls.
bool avoidReshardsOnNamedComputations = false;
// Whether to update axes with non-divisible input/output shardings.
bool updateNonDivisibleInputOutputShardings = true;
9 changes: 4 additions & 5 deletions shardy/dialect/sdy/transforms/export/export_pipeline.cc
@@ -40,10 +40,9 @@ void runShardyPartitioner(OpPassManager& pm, int& dumpIndex,
const ExportOptions& options) {
InsertExplicitReshardsPassOptions passOptions;
passOptions.enableFullVersion = options.enableInsertExplicitCollectives;
passOptions.avoidReshardsOnNamedComputations =
options.avoidReshardsOnNamedComputations;
passOptions.avoidReshardsOnCalls = options.avoidReshardsOnCalls;
pm.addPass(createExportNamedComputationsPass());
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));

if (options.enableInsertExplicitCollectives) {
pm.addPass(mlir::sdy::createSaveModuleOpPass(
options.dumpDirectory, "after_explicit_reshards", dumpIndex++));
@@ -98,12 +97,12 @@ void addExportPipeline(OpPassManager& pm, int& dumpIndex,
// reshards/collectives.
if (!options.avoidExportForPartitioning) {
runShardyPartitioner(pm, dumpIndex, options);
} else {
pm.addPass(createExportNamedComputationsPass());
}

if (options.dumpPropagationEdges || options.dumpShardingOrigins) {
pm.addPass(createRemovePropagationDebugInfoPass());
}
pm.addPass(createExportNamedComputationsPass());
if (!options.keepShardingRules) {
pm.addNestedPass<func::FuncOp>(createDropShardingRulesPass());
}
42 changes: 40 additions & 2 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
@@ -51,6 +51,9 @@ namespace sdy {

namespace {

using func::CallOp;
using func::FuncOp;

void insertExplicitReshardsToTargetSharding(OpOperand& opOperand,
TensorShardingAttr targetSharding,
IRRewriter& rewriter,
@@ -132,6 +135,33 @@ void insertExplicitReshardsOnDataFlowOp(
}
}

void insertExplicitReshardsOnCallOp(CallOp callOp, IRRewriter& rewriter,
const SymbolTable& symbolTable,
const bool onFullVersion) {
FuncOp funcOp = symbolTable.lookup<FuncOp>(callOp.getCallee());
TensorShardingPerValueAttr funcArgShardings =
mlir::sdy::getFuncArgShardings(funcOp, symbolTable);
if (!funcArgShardings) {
mlir::Attribute meshOrRef = getMeshOrRef(
callOp.getNumOperands(), symbolTable,
[&](int64_t i) { return getSharding(callOp.getOperand(i)); });
// Return without inserting reshards, since neither the func arguments nor
// the call operands have a sharding with a non-maximal mesh.
if (!meshOrRef) {
return;
}
funcArgShardings = getFullyClosedLike(callOp.getOperands(), meshOrRef);
}
rewriter.setInsertionPoint(callOp);
for (auto [funcArgSharding, sourceOpOperand] : llvm::zip_equal(
funcArgShardings.getShardings(), callOp->getOpOperands())) {
insertExplicitReshardsToTargetSharding(
sourceOpOperand,
/*targetSharding=*/funcArgSharding, rewriter, symbolTable,
/*insertAfterOperand=*/true, onFullVersion);
}
}

// Reshard the result of a dot operation if all the following hold:
//
// 1. LHS and RHS have fully compatible shardings.
@@ -382,7 +412,7 @@ bool isOnFullVersion(Operation* op, const bool enableFullVersion) {
}
// To avoid copies of the same functions with mismatching shardings on the
// arguments onto multiple callsites.
if (isa<NamedComputationOp>(op)) {
if (isa<NamedComputationOp, func::CallOp>(op)) {
return true;
}

@@ -473,7 +503,15 @@ struct InsertExplicitReshardsPass
// sharded in the same way.
insertExplicitReshardsOnDataFlowOp(shardableDataFlowOp, rewriter,
symbolTable, onFullVersion,
avoidReshardsOnNamedComputations);
avoidReshardsOnCalls);
return;
}

if (CallOp callOp = dyn_cast<CallOp>(op)) {
if (!avoidReshardsOnCalls) {
insertExplicitReshardsOnCallOp(callOp, rewriter, symbolTable,
onFullVersion);
}
return;
}

8 changes: 4 additions & 4 deletions shardy/dialect/sdy/transforms/export/passes.h
@@ -76,10 +76,10 @@ struct ExportOptions : public PassPipelineOptions<ExportOptions> {
llvm::cl::desc("Sink sdy.propagation_edges attr."),
llvm::cl::init(false)};

Option<bool> avoidReshardsOnNamedComputations{
*this, "avoid-reshards-on-named-computations",
llvm::cl::desc("Avoid inserting explicit reshards/collectives for named "
"computations."),
Option<bool> avoidReshardsOnCalls{
*this, "avoid-reshards-on-calls",
llvm::cl::desc(
"Avoid inserting explicit reshards/collectives for calls."),
llvm::cl::init(false)};

Option<bool> updateNonDivisibleInputOutputShardings{
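For completeness, a hedged sketch of setting the renamed option when building the pass programmatically, mirroring the wiring in export_pipeline.cc above (assumes that file's includes and pass-manager setup):

InsertExplicitReshardsPassOptions passOptions;
passOptions.enableFullVersion = true;
passOptions.avoidReshardsOnCalls = true;  // renamed from avoidReshardsOnNamedComputations
pm.addNestedPass<func::FuncOp>(createInsertExplicitReshardsPass(passOptions));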
6 changes: 3 additions & 3 deletions shardy/dialect/sdy/transforms/export/passes.td
@@ -128,10 +128,10 @@ def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::Fun
Option<"enableFullVersion", "enable-full-version",
"bool", /*default=*/"false",
"Enable full version.">,
Option<"avoidReshardsOnNamedComputations",
"avoid-reshards-on-named-computations",
Option<"avoidReshardsOnCalls",
"avoid-reshards-on-calls",
"bool", /*default=*/"false",
"Avoid explicit reshards/collectives on named computations.">
"Avoid explicit reshards/collectives on calls.">
];
}

@@ -1,12 +1,7 @@
// RUN: sdy_opt %s -sdy-insert-explicit-reshards='avoid-reshards-on-named-computations=true' -sdy-insert-explicit-reshards='avoid-reshards-on-named-computations=true' | FileCheck %s
// RUN: sdy_opt %s -sdy-insert-explicit-reshards='enable-full-version=true avoid-reshards-on-calls=true' | FileCheck %s

sdy.mesh @mesh = <["x"=2, "y"=2, "z"=4]>

//===----------------------------------------------------------------------===//
// Named computations tests
// More tests are in insert_explicit_reshards/data_flow_ops.mlir
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @named_computation
func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
// CHECK-NEXT: sdy.named_computation<"foo">(%arg0)
@@ -35,8 +30,6 @@ func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {s
%2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
sdy.return %2 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
// CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>}
// CHECK-NEXT: return %[[ADD]]
%3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
return %3 : tensor<210xf32>
}
@@ -49,9 +42,6 @@ func.func @different_arguments_to_multiple_named_computations_with_same_input_ou
%3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
sdy.return %3 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
// CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
// CHECK-NEXT: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
// CHECK-SAME: in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>
%1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
%2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
%3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
@@ -6,6 +6,7 @@

sdy.mesh @mesh3d = <["a"=4, "b"=4, "c"=4]>


// CHECK-LABEL: func @manual_computation_free_axes_non_divisible
func.func @manual_computation_free_axes_non_divisible(
%arg0: tensor<4xf32>, %arg1: tensor<12xf32>, %arg2: tensor<24xf32>,
@@ -472,64 +472,6 @@ func.func @concatenate_same_shardings_func_result_different_sharding(%arg0: tens
return %0 : tensor<4x80x256xf32>
}

//===----------------------------------------------------------------------===//
// Named computations tests
// More tests are in insert_explicit_reshards/data_flow_ops.mlir
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @named_computation
func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
// CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
%0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
%2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
// CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"z"}]> : tensor<210xf32>
// CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
sdy.return %2 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
%1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
return %1 : tensor<210xf32>
}

// CHECK-LABEL: func @one_argument_to_multiple_named_computations
func.func @one_argument_to_multiple_named_computations(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"z"}]>}) {
// CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{"z"}]> : tensor<210xf32>
// CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD1]])
%0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
%2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
sdy.return %2 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
// CHECK: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[RESHARD0]])
%1 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"z"}]>] out_shardings=[<@mesh, [{"z"}]>] (%arg1: tensor<210xf32>) {
%2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
sdy.return %2 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
// CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]]
// CHECK-NEXT: return %[[ADD]]
%3 = stablehlo.add %0, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"z"}]>]>} : tensor<210xf32>
return %3 : tensor<210xf32>
}

// CHECK-LABEL: func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings
func.func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD]])
%0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
%3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
sdy.return %3 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
// CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
// CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
%1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
%2 = sdy.named_computation<"foo">(%1) in_shardings=[<@mesh, [{"y"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
%3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
sdy.return %3 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
%4 = stablehlo.add %0, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
return %4 : tensor<210xf32>
}

//===----------------------------------------------------------------------===//
// Replicated and sharded to unreduced tests
//===----------------------------------------------------------------------===//
@@ -0,0 +1,48 @@
// RUN: sdy_opt %s -split-input-file -sdy-insert-explicit-reshards='enable-full-version=true' | FileCheck %s

sdy.mesh @mesh = <["x"=4, "y"=2]>

// CHECK-LABEL: func @named_computation
func.func @named_computation(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x":(1)2}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"x"}]> : tensor<210xf32>
// CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
%0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"x"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
%2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}]>]>} : tensor<210xf32>
// CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"y"}]> : tensor<210xf32>
// CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
sdy.return %2 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
%1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
return %1 : tensor<210xf32>
}

// -----
sdy.mesh @mesh = <["x"=4, "y"=2]>

// CHECK-LABEL: func @named_computation_empty_block
func.func @named_computation_empty_block(%arg0: tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}]>}) -> (tensor<210xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}]>}) {
// CHECK: sdy.named_computation<"foo">(%arg0)
%0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{"x"}]>] out_shardings=[<@mesh, [{"y"}]>] (%arg1: tensor<210xf32>) {
// CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"y"}]> : tensor<210xf32>
// CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
sdy.return %arg1 : tensor<210xf32>
} : (tensor<210xf32>) -> (tensor<210xf32>)
%1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>} : tensor<210xf32>
return %1 : tensor<210xf32>
}

// -----
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]>

// CHECK-LABEL: func @call_with_shardings
func.func @call_with_shardings(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> tensor<12x2xi32> {
%0:2 = call @foo(%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>, <@mesh, [{}, {}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>)
%1 = stablehlo.concatenate %0#0, %0#1, dim = 0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : (tensor<8x2xi32>, tensor<4x2xi32>) -> tensor<12x2xi32>
return %1 : tensor<12x2xi32>
}

func.func private @foo(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, %arg1: tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>})
-> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, tensor<4x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) {
%0 = stablehlo.abs %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} : tensor<8x2xi32>
return %0, %arg1 : tensor<8x2xi32>, tensor<4x2xi32>
}