@@ -472,64 +472,6 @@ func.func @concatenate_same_shardings_func_result_different_sharding(%arg0: tens
472472 return %0 : tensor <4 x80 x256 xf32 >
473473}
474474
475- //===----------------------------------------------------------------------===//
476- // Named computations tests
477- // More tests are in insert_explicit_reshards/data_flow_ops.mlir
478- //===----------------------------------------------------------------------===//
479-
480- // CHECK-LABEL: func @named_computation
481- func.func @named_computation (%arg0: tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }]>}) -> (tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" z" }]>}) {
482- // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
483- // CHECK-NEXT: sdy.named_computation<"foo">(%[[RESHARD]])
484- %0 = sdy.named_computation <" foo" >(%arg0 ) in_shardings =[<@mesh , [{" y" }]>] out_shardings =[<@mesh , [{" z" }]>] (%arg1: tensor <210 xf32 >) {
485- %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
486- // CHECK: %[[RESHARD:.*]] = sdy.reshard %{{.*}} <@mesh, [{"z"}]> : tensor<210xf32>
487- // CHECK-NEXT: sdy.return %[[RESHARD]] : tensor<210xf32>
488- sdy.return %2 : tensor <210 xf32 >
489- } : (tensor <210 xf32 >) -> (tensor <210 xf32 >)
490- %1 = stablehlo.negate %0 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" z" }]>]>} : tensor <210 xf32 >
491- return %1 : tensor <210 xf32 >
492- }
493-
494- // CHECK-LABEL: func @one_argument_to_multiple_named_computations
495- func.func @one_argument_to_multiple_named_computations (%arg0: tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }]>}) -> (tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" z" }]>}) {
496- // CHECK: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{"z"}]> : tensor<210xf32>
497- // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
498- // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD1]])
499- %0 = sdy.named_computation <" foo" >(%arg0 ) in_shardings =[<@mesh , [{" y" }]>] out_shardings =[<@mesh , [{" y" }]>] (%arg1: tensor <210 xf32 >) {
500- %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
501- sdy.return %2 : tensor <210 xf32 >
502- } : (tensor <210 xf32 >) -> (tensor <210 xf32 >)
503- // CHECK: %[[NC1:.*]] = sdy.named_computation<"foo">(%[[RESHARD0]])
504- %1 = sdy.named_computation <" foo" >(%arg0 ) in_shardings =[<@mesh , [{" z" }]>] out_shardings =[<@mesh , [{" z" }]>] (%arg1: tensor <210 xf32 >) {
505- %2 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" z" }]>]>} : tensor <210 xf32 >
506- sdy.return %2 : tensor <210 xf32 >
507- } : (tensor <210 xf32 >) -> (tensor <210 xf32 >)
508- // CHECK: %[[ADD:.*]] = stablehlo.add %[[NC0]], %[[NC1]]
509- // CHECK-NEXT: return %[[ADD]]
510- %3 = stablehlo.add %0 , %1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" z" }]>]>} : tensor <210 xf32 >
511- return %3 : tensor <210 xf32 >
512- }
513-
514- // CHECK-LABEL: func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings
515- func.func @different_arguments_to_multiple_named_computations_with_same_input_output_shardings (%arg0: tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" x" }]>}) -> (tensor <210 xf32 > {sdy.sharding = #sdy.sharding <@mesh , [{" y" }]>}) {
516- // CHECK: %[[RESHARD:.*]] = sdy.reshard %arg0 <@mesh, [{"y"}]> : tensor<210xf32>
517- // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[RESHARD]])
518- %0 = sdy.named_computation <" foo" >(%arg0 ) in_shardings =[<@mesh , [{" y" }]>] out_shardings =[<@mesh , [{" y" }]>] (%arg1: tensor <210 xf32 >) {
519- %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
520- sdy.return %3 : tensor <210 xf32 >
521- } : (tensor <210 xf32 >) -> (tensor <210 xf32 >)
522- // CHECK: %[[NEGATE:.*]] = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"y"}]>]>}
523- // CHECK-NEXT: %[[NC0:.*]] = sdy.named_computation<"foo">(%[[NEGATE]])
524- %1 = stablehlo.negate %arg0 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
525- %2 = sdy.named_computation <" foo" >(%1 ) in_shardings =[<@mesh , [{" y" }]>] out_shardings =[<@mesh , [{" y" }]>] (%arg1: tensor <210 xf32 >) {
526- %3 = stablehlo.abs %arg1 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
527- sdy.return %3 : tensor <210 xf32 >
528- } : (tensor <210 xf32 >) -> (tensor <210 xf32 >)
529- %4 = stablehlo.add %0 , %2 {sdy.sharding = #sdy.sharding_per_value <[<@mesh , [{" y" }]>]>} : tensor <210 xf32 >
530- return %4 : tensor <210 xf32 >
531- }
532-
533475//===----------------------------------------------------------------------===//
534476// Replicated and sharded to unreduced tests
535477//===----------------------------------------------------------------------===//
0 commit comments