// RUN: mlir-opt %s | mlir-opt | FileCheck %s // CHECK: mesh.cluster @mesh0 mesh.cluster @mesh0(shape = 2x2x4) // CHECK: mesh.cluster @mesh1(shape = 4x?) mesh.cluster @mesh1(shape = 4x?) // CHECK: mesh.cluster @mesh2(shape = ?x4) mesh.cluster @mesh2(shape = ?x4) // CHECK: mesh.cluster @mesh3(shape = ?x?) mesh.cluster @mesh3(shape = ?x?) mesh.cluster @mesh4(shape = 3) // CHECK: mesh.cluster @mesh5(shape = ?) mesh.cluster @mesh5(shape = ?) // CHECK-LABEL: func @mesh_shard_encoding_fully_replicated func.func @mesh_shard_encoding_fully_replicated( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}]]>> %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>>) -> tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> { return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[]]>> } // CHECK-LABEL: func @mesh_shard_encoding_1st_dim func.func @mesh_shard_encoding_1st_dim( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh0, {{\[\[}}0]]>> %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>>) -> tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> { return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh0, [[0]]>> } // CHECK-LABEL: func @mesh_shard_encoding_2nd_dim func.func @mesh_shard_encoding_2nd_dim( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32, #mesh.shard<@mesh1, {{\[\[}}], [0]]>> %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>>) -> tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> { return %arg0 : tensor<4x8xf32, #mesh.shard<@mesh1, [[], [0]]>> } // CHECK-LABEL: func @mesh_shard_encoding_1st_and_3rd_dim func.func @mesh_shard_encoding_1st_and_3rd_dim( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32, #mesh.shard<@mesh3, {{\[\[}}0], [], [1]]>> %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>>) -> tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> { return %arg0 : tensor<4x8x16xf32, #mesh.shard<@mesh3, [[0], [], [1]]>> } // CHECK-LABEL: func @mesh_shard_op_1st_dim // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> func.func @mesh_shard_op_1st_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_2nd_dim // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> func.func @mesh_shard_op_2nd_dim(%arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh1, {{\[\[}}], [0]]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh1, [[], [0]]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_1st_and_3rd_dim func.func @mesh_shard_op_1st_and_3rd_dim( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8x16xf32> %arg0 : tensor<4x8x16xf32>) -> tensor<4x8x16xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0], [], [1]]> : tensor<4x8x16xf32> %0 = mesh.shard %arg0 to <@mesh3, [[0], [], [1]]> : tensor<4x8x16xf32> return %0 : tensor<4x8x16xf32> } // CHECK-LABEL: func @mesh_shard_op_partial_max func.func @mesh_shard_op_partial_max( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = max[1]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = max[1]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_partial_min func.func @mesh_shard_op_partial_min( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = min[1]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = min[1]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_partial_generic func.func @mesh_shard_op_partial_generic( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = generic[1]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = generic[1]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_partial_sum func.func @mesh_shard_op_partial_sum( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_partial_sum_multi_axes func.func @mesh_shard_op_partial_sum_multi_axes( // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> %arg0 : tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: mesh.shard %[[ARG]] to <@mesh3, {{\[\[}}0]], partial = sum[1, 2]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh3, [[0]], partial = sum[1, 2]> : tensor<4x8xf32> return %0 : tensor<4x8xf32> } // CHECK-LABEL: func @mesh_shard_op_two_users // CHECK-SAME: %[[ARG:.*]]: tensor<4x8xf32> func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { // CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG]] to <@mesh0, {{\[\[}}0]]> : tensor<4x8xf32> %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32> // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}1]]> annotate_for_users : tensor<4x8xf32> %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32> // CHECK-DAG: mesh.shard %[[V0]] to <@mesh0, {{\[\[}}2]]> annotate_for_users : tensor<4x8xf32> %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32> return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32> } // CHECK-LABEL: func @cluster_shape func.func @cluster_shape() -> (index, index) { // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } // CHECK-LABEL: func @cluster_shape_default_axes func.func @cluster_shape_default_axes() -> (index, index, index) { // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index %0:3 = mesh.cluster_shape @mesh0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @cluster_shape_empty_axes func.func @cluster_shape_empty_axes() -> (index, index, index) { // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_multi_index func.func @process_multi_index() -> (index, index) { // CHECK: %[[RES:.*]]:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index %0:2 = mesh.process_multi_index on @mesh0 axes = [0, 1] : index, index // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index return %0#0, %0#1 : index, index } // CHECK-LABEL: func @process_multi_index_default_axes func.func @process_multi_index_default_axes() -> (index, index, index) { // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index %0:3 = mesh.process_multi_index on @mesh0 : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_multi_index_empty_axes func.func @process_multi_index_empty_axes() -> (index, index, index) { // CHECK: %[[RES:.*]]:3 = mesh.process_multi_index on @mesh0 : index, index, index %0:3 = mesh.process_multi_index on @mesh0 axes = [] : index, index, index // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index return %0#0, %0#1, %0#2 : index, index, index } // CHECK-LABEL: func @process_linear_index func.func @process_linear_index() -> index { // CHECK: %[[RES:.*]] = mesh.process_linear_index on @mesh0 : index %0 = mesh.process_linear_index on @mesh0 : index // CHECK: return %[[RES]] : index return %0 : index } // CHECK-LABEL: func @all_reduce func.func @all_reduce( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x4xf64> { // CHECK-NEXT: mesh.all_reduce %[[ARG]] on @mesh0 mesh_axes = [1, 0] reduction = // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x4xf64> %0 = mesh.all_reduce %arg0 on @mesh0 mesh_axes = [1, 0] reduction = : tensor<3x4xf32> -> tensor<3x4xf64> return %0 : tensor<3x4xf64> } // CHECK-LABEL: func @all_gather func.func @all_gather( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x16xf32> { // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x16xf32> %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x16xf32> return %0 : tensor<3x16xf32> } // CHECK-LABEL: func @all_gather_dynamic_dims_in_tensor func.func @all_gather_dynamic_dims_in_tensor( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh0 mesh_axes = [2] gather_axis = 1 // CHECK-SAME: : tensor -> tensor %0 = mesh.all_gather %arg0 on @mesh0 mesh_axes = [2] gather_axis = 1 : tensor -> tensor return %0 : tensor } // CHECK-LABEL: func @all_gather_dynamic_dims_in_mesh func.func @all_gather_dynamic_dims_in_mesh( // CHECK-SAME: %[[ARG:.*]]: tensor<5x6xf32> %arg0 : tensor<5x6xf32>) -> tensor<5x?xf32> { // CHECK-NEXT: mesh.all_gather %[[ARG]] on @mesh3 mesh_axes = [1] gather_axis = 1 // CHECK-SAME: : tensor<5x6xf32> -> tensor<5x?xf32> %0 = mesh.all_gather %arg0 on @mesh3 mesh_axes = [1] gather_axis = 1 : tensor<5x6xf32> -> tensor<5x?xf32> return %0 : tensor<5x?xf32> } // CHECK-LABEL: func @all_to_all func.func @all_to_all( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { // CHECK-NEXT: mesh.all_to_all %[[ARG]] // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0 // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x6xi8> %0 = mesh.all_to_all %arg0 on @mesh4 split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x6xi8> return %0 : tensor<3x6xi8> } // CHECK-LABEL: func @all_to_all_dynamic_dims_in_result func.func @all_to_all_dynamic_dims_in_result( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x?xi8> { // CHECK-NEXT: mesh.all_to_all %[[ARG]] // CHECK-SAME: on @mesh4 split_axis = 1 concat_axis = 0 // CHECK-SAME: : tensor<3x6xi8> -> tensor<3x?xi8> %0 = mesh.all_to_all %arg0 on @mesh4 split_axis = 1 concat_axis = 0 : tensor<3x6xi8> -> tensor<3x?xi8> return %0 : tensor<3x?xi8> } // CHECK-LABEL: func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size func.func @all_to_all_same_split_concat_dim_with_dynamic_device_group_size( // CHECK-SAME: %[[ARG:.*]]: tensor<3xi8> %arg0 : tensor<3xi8>) -> tensor<3xi8> { // CHECK-NEXT: mesh.all_to_all %[[ARG]] // CHECK-SAME: @mesh4 split_axis = 0 concat_axis = 0 // CHECK-SAME: : tensor<3xi8> -> tensor<3xi8> %0 = mesh.all_to_all %arg0 on @mesh4 split_axis = 0 concat_axis = 0 : tensor<3xi8> -> tensor<3xi8> return %0 : tensor<3xi8> } // CHECK-LABEL: func @all_to_all_non_divisible_split_axis_size func.func @all_to_all_non_divisible_split_axis_size( // CHECK-SAME: %[[ARG:.*]]: tensor<2x3xi8> %arg0 : tensor<2x3xi8>) -> tensor { // CHECK-NEXT: mesh.all_to_all %[[ARG]] // CHECK-SAME: @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1 // CHECK-SAME: : tensor<2x3xi8> -> tensor %0 = mesh.all_to_all %arg0 on @mesh0 mesh_axes = [0, 1] split_axis = 0 concat_axis = 1 : tensor<2x3xi8> -> tensor return %0 : tensor } // CHECK-LABEL: func @broadcast_static_root func.func @broadcast_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<3x6xi8> { // CHECK-NEXT: mesh.broadcast %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<3x6xi8> %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2] root = [0, 1] : (tensor<3x6xi8>) -> tensor<3x6xi8> return %0 : tensor<3x6xi8> } // CHECK-LABEL: func @broadcast_dynamic_root func.func @broadcast_dynamic_root( // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>, // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<3x6xi8> { // CHECK-NEXT: mesh.broadcast %[[ARG0]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<3x6xi8> %0 = mesh.broadcast %arg0 on @mesh0 mesh_axes = [0, 2] root = [1, %arg1] : (tensor<3x6xi8>, index) -> tensor<3x6xi8> return %0 : tensor<3x6xi8> } // CHECK-LABEL: func @gather_static_root func.func @gather_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>) -> tensor<24x6xi8> { // CHECK-NEXT: mesh.gather %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: gather_axis = 0 // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<3x6xi8>) -> tensor<24x6xi8> %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2] gather_axis = 0 root = [0, 1] : (tensor<3x6xi8>) -> tensor<24x6xi8> return %0 : tensor<24x6xi8> } // CHECK-LABEL: func @gather_dynamic_root func.func @gather_dynamic_root( // CHECK-SAME: %[[ARG0:.*]]: tensor<3x6xi8> %arg0 : tensor<3x6xi8>, // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<24x6xi8> { // CHECK-NEXT: mesh.gather %[[ARG0]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: gather_axis = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<3x6xi8>, index) -> tensor<24x6xi8> %0 = mesh.gather %arg0 on @mesh0 mesh_axes = [0, 2] gather_axis = 0 root = [1, %arg1] : (tensor<3x6xi8>, index) -> tensor<24x6xi8> return %0 : tensor<24x6xi8> } // CHECK-LABEL: func @receive_static_source func.func @receive_static_source( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { // CHECK-NEXT: mesh.recv %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: source = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] source = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @receive_dynamic_source func.func @receive_dynamic_source( // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>, // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { // CHECK-NEXT: mesh.recv %[[ARG0]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: source = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] source = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @receive_no_source func.func @receive_no_source( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { // CHECK-NEXT: mesh.recv %[[ARG]] // CHECK-NOT: source %0 = mesh.recv %arg0 on @mesh0 mesh_axes = [0, 2] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @reduce_static_root func.func @reduce_static_root( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { // CHECK-NEXT: mesh.reduce %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] root = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @reduce_dynamic_root func.func @reduce_dynamic_root( // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>, // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { // CHECK-NEXT: mesh.reduce %[[ARG0]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] root = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @reduce_different_return_element_type func.func @reduce_different_return_element_type( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi16> { // CHECK-NEXT: mesh.reduce %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: root = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi16> %0 = mesh.reduce %arg0 on @mesh0 mesh_axes = [0, 2] root = [0, 1] : (tensor<2xi8>) -> tensor<2xi16> return %0 : tensor<2xi16> } // CHECK-LABEL: func @reduce_scatter_static_dimensions func.func @reduce_scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf64> { // CHECK-NEXT: mesh.reduce_scatter %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [2] reduction = scatter_axis = 1 // CHECK-SAME: : tensor<3x4xf32> -> tensor<3x1xf64> %0 = mesh.reduce_scatter %arg0 on @mesh0 mesh_axes = [2] reduction = scatter_axis = 1 : tensor<3x4xf32> -> tensor<3x1xf64> return %0 : tensor<3x1xf64> } // CHECK-LABEL: func @reduce_scatter_dynamic_dimensions func.func @reduce_scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { // CHECK-NEXT: mesh.reduce_scatter %[[ARG]] // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 // CHECK-SAME: : tensor -> tensor %0 = mesh.reduce_scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 : tensor -> tensor return %0 : tensor } // CHECK-LABEL: func @scatter_static_dimensions func.func @scatter_static_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x1xf32> { // CHECK-NEXT: mesh.scatter %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [2] // CHECK-SAME: scatter_axis = 1 root = [1] // CHECK-SAME: : (tensor<3x4xf32>) -> tensor<3x1xf32> %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [2] scatter_axis = 1 root = [1] : (tensor<3x4xf32>) -> tensor<3x1xf32> return %0 : tensor<3x1xf32> } // CHECK-LABEL: func @scatter_dynamic_dimensions func.func @scatter_dynamic_dimensions( // CHECK-SAME: %[[ARG:.*]]: tensor %arg0 : tensor) -> tensor { // CHECK-NEXT: mesh.scatter %[[ARG]] // CHECK-SAME: on @mesh3 mesh_axes = [0, 1] // CHECK-SAME: scatter_axis = 0 root = [1, 2] // CHECK-SAME: : (tensor) -> tensor %0 = mesh.scatter %arg0 on @mesh3 mesh_axes = [0, 1] scatter_axis = 0 root = [1, 2] : (tensor) -> tensor return %0 : tensor } // CHECK-LABEL: func @scatter_dynamic_root func.func @scatter_dynamic_root( // CHECK-SAME: %[[ARG0:.*]]: tensor<8xi8> %arg0 : tensor<8xi8>, // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<1xi8> { // CHECK-NEXT: mesh.scatter %[[ARG0]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: scatter_axis = 0 // CHECK-SAME: root = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<8xi8>, index) -> tensor<1xi8> %0 = mesh.scatter %arg0 on @mesh0 mesh_axes = [0, 2] scatter_axis = 0 root = [1, %arg1] : (tensor<8xi8>, index) -> tensor<1xi8> return %0 : tensor<1xi8> } // CHECK-LABEL: func @send_static_destination func.func @send_static_destination( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { // CHECK-NEXT: mesh.send %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: destination = [0, 1] // CHECK-SAME: : (tensor<2xi8>) -> tensor<2xi8> %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2] destination = [0, 1] : (tensor<2xi8>) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @send_dynamic_destination func.func @send_dynamic_destination( // CHECK-SAME: %[[ARG0:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>, // CHECK-SAME: %[[ARG1:.*]]: index %arg1 : index ) -> tensor<2xi8> { // CHECK-NEXT: mesh.send %[[ARG0]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: destination = [1, %[[ARG1]]] // CHECK-SAME: : (tensor<2xi8>, index) -> tensor<2xi8> %0 = mesh.send %arg0 on @mesh0 mesh_axes = [0, 2] destination = [1, %arg1] : (tensor<2xi8>, index) -> tensor<2xi8> return %0 : tensor<2xi8> } // CHECK-LABEL: func @shift func.func @shift( // CHECK-SAME: %[[ARG:.*]]: tensor<2xi8> %arg0 : tensor<2xi8>) -> tensor<2xi8> { // CHECK-NEXT: mesh.shift %[[ARG]] // CHECK-SAME: on @mesh0 mesh_axes = [0, 2] // CHECK-SAME: shift_axis = 2 offset = -2 rotate // CHECK-SAME: : tensor<2xi8> -> tensor<2xi8> %0 = mesh.shift %arg0 on @mesh0 mesh_axes = [0, 2] shift_axis = 2 offset = -2 rotate : tensor<2xi8> -> tensor<2xi8> return %0 : tensor<2xi8> }