|
1 | | -// RUN: sdy_opt %s -sdy-remove-propagation-debug-info | FileCheck %s |
| 1 | +// RUN: sdy_opt %s -split-input-file -sdy-remove-propagation-debug-info | FileCheck %s |
2 | 2 |
|
3 | 3 | sdy.mesh @mesh_a = <["a"=2]> |
4 | 4 | sdy.mesh @mesh_c = <["c"=8]> |
@@ -78,3 +78,35 @@ func.func @remove_origin_shardings_manual_computation_with_sharding_constraints( |
78 | 78 | } {sdy.block_arg_sharding_origins = [{a = "self", b = "self"}, {a = "self", b = "self"}], sdy.result_sharding_origins = [{a = "self", b = "self"}], sdy.sharding_origin_name = "mc_1"} : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> |
79 | 79 | return %0, %2 : tensor<8x8xf32>, tensor<8x8xf32> |
80 | 80 | } |
| 81 | + |
| 82 | +// ----- |
| 83 | +sdy.mesh @mesh_abc = <["a"=2, "b"=2, "c"=8]> |
| 84 | +// CHECK-LABEL: func @single_call( |
| 85 | +// CHECK-SAME: %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}, |
| 86 | +// CHECK-SAME: %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) -> ( |
| 87 | +// CHECK-SAME: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) { |
| 88 | +func.func @single_call( |
| 89 | + %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>, sdy.sharding_origins = {a = "mc_1_input: 0", b = "mc_1_input: 0"}}, |
| 90 | + %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) -> ( |
| 91 | + tensor<8x8x8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"b" = result-0 -> [operand-0]}]}, {step-2 = [{"a" = operand-0 -> [result-0]}, {"c" = operand-0 -> [result-0]}]}]>, sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) { |
| 92 | + // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} |
| 93 | + // CHECK-NEXT: %1 = call @foo(%0, %0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} |
| 94 | + %0 = stablehlo.add %arg0, %arg1 {sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = operand-0 -> [operand-1, result-0]}, {"b" = result-0 -> [operand-0, operand-1]}, {"c" = operand-1 -> [operand-0, result-0]}]}]>, sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} : tensor<8x8x8xf32> |
| 95 | + %1 = call @foo(%0,%0) {sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = operand-0 -> [operand-1, result-0]}, {"b" = result-0 -> [operand-0, operand-1]}, {"c" = operand-1 -> [operand-0, result-0]}]}]>, sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} : (tensor<8x8x8xf32>, tensor<8x8x8xf32>) -> tensor<8x8x8xf32> |
| 96 | + return %1 : tensor<8x8x8xf32> |
| 97 | +} |
| 98 | + |
| 99 | +// CHECK-LABEL: func private @foo( |
| 100 | +// CHECK-SAME: %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>} |
| 101 | +// CHECK-SAME: %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) -> ( |
| 102 | +// CHECK-SAME: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>} |
| 103 | +func.func private @foo( |
| 104 | + %arg0: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}, |
| 105 | + %arg1: tensor<8x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>, sdy.sharding_origins = {a = "mc_1_input: 0", b = "mc_1_input: 0"}}) -> ( |
| 106 | + tensor<8x8x8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"b" = result-0 -> [operand-0]}]}, {step-2 = [{"a" = operand-0 -> [result-0]}, {"c" = operand-0 -> [result-0]}]}]>, sdy.sharding = #sdy.sharding<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>}) { |
| 107 | + // CHECK-NEXT: %0 = stablehlo.multiply %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} |
| 108 | + %0 = stablehlo.multiply %arg0, %arg1 {sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = operand-0 -> [operand-1, result-0]}, {"b" = result-0 -> [operand-0, operand-1]}, {"c" = operand-1 -> [operand-0, result-0]}]}]>, sdy.sharding = #sdy.sharding_per_value<[<@mesh_abc, [{"a", ?}, {"b", ?}, {"c", ?}]>]>} : tensor<8x8x8xf32> |
| 109 | + return %0 : tensor<8x8x8xf32> |
| 110 | +} |
| 111 | + |
| 112 | + |
0 commit comments