// RUN: mlir-opt -lower-vector-mask -split-input-file %s | FileCheck %s

// Each test case below wraps an operation in a vector.mask region. After the
// pass runs, the region is expected to be gone and its mask value passed to
// the operation directly.

// Masked read from a tensor: the vector.mask mask becomes the trailing mask
// operand of the resulting vector.transfer_read.
func.func @vector_transfer_read(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) -> vector<16xf32> {
  %ft0 = arith.constant 0.0 : f32
  %0 = vector.mask %m0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32>
  return %0 : vector<16xf32>
}

// CHECK-LABEL:   func.func @vector_transfer_read(
// CHECK-SAME:                                    %[[VAL_0:.*]]: tensor<?xf32>,
// CHECK-SAME:                                    %[[VAL_1:.*]]: index,
// CHECK-SAME:                                    %[[VAL_2:.*]]: vector<16xi1>) -> vector<16xf32> {
// CHECK-NOT:       vector.mask
// CHECK:           %[[VAL_4:.*]] = vector.transfer_read {{.*}}, %[[VAL_2]] : tensor<?xf32>, vector<16xf32>
// CHECK:           return %[[VAL_4]] : vector<16xf32>
// CHECK:         }

// -----

// Masked write into a memref: the write produces no result, and the mask
// becomes the trailing mask operand of vector.transfer_write.
func.func @vector_transfer_write_on_memref(%val: vector<16xf32>, %t0: memref<?xf32>, %idx: index, %m0: vector<16xi1>) {
  vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
  return
}

// CHECK-LABEL:   func.func @vector_transfer_write_on_memref(
// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<16xf32>,
// CHECK-SAME:                                               %[[VAL_1:.*]]: memref<?xf32>,
// CHECK-SAME:                                               %[[VAL_2:.*]]: index,
// CHECK-SAME:                                               %[[VAL_3:.*]]: vector<16xi1>) {
// CHECK-NOT:       vector.mask
// CHECK:           vector.transfer_write %[[VAL_0]], {{.*}}, %[[VAL_3]] : vector<16xf32>, memref<?xf32>
// CHECK:           return
// CHECK:         }

// -----

// Masked write into a tensor: same as the memref case, but the tensor
// produced by the unmasked transfer_write is what the function returns.
func.func @vector_transfer_write_on_tensor(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
  %res = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
  return %res : tensor<?xf32>
}

// CHECK-LABEL:   func.func @vector_transfer_write_on_tensor(
// CHECK-SAME:                                               %[[VAL_0:.*]]: vector<16xf32>,
// CHECK-SAME:                                               %[[VAL_1:.*]]: tensor<?xf32>,
// CHECK-SAME:                                               %[[VAL_2:.*]]: index,
// CHECK-SAME:                                               %[[VAL_3:.*]]: vector<16xi1>) -> tensor<?xf32> {
// CHECK-NOT:       vector.mask
// CHECK:           %[[VAL_4:.*]] = vector.transfer_write %[[VAL_0]], {{.*}}, %[[VAL_3]] : vector<16xf32>, tensor<?xf32>
// CHECK:           return %[[VAL_4]] : tensor<?xf32>
// CHECK:         }

// -----

// Masked gather followed by a masked write, all under one create_mask.
// The expected output uses the create_mask result as the gather's mask, in
// place of the all-true constant the input gather carried. It also expects a
// single zero index constant to serve both the gather base and the write
// index. The unused masked transfer_read (%1) is not checked.
func.func @vector_gather(%arg0: tensor<64xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %c3 = arith.constant 3 : index
  %0 = vector.create_mask %c3 : vector<4xi1>
  %1 = vector.mask %0 { vector.transfer_read %arg1[%c0], %cst {in_bounds = [true]} : tensor<3xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
  %cst_0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
  %cst_1 = arith.constant dense<true> : vector<4xi1>
  %cst_2 = arith.constant dense<0.000000e+00> : vector<4xf32>
  %c0_3 = arith.constant 0 : index
  %2 = vector.mask %0 { vector.gather %arg0[%c0_3] [%cst_0], %cst_1, %cst_2 : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32> } : vector<4xi1> -> vector<4xf32>
  %c0_4 = arith.constant 0 : index
  %3 = vector.mask %0 { vector.transfer_write %2, %arg1[%c0_4] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32> } : vector<4xi1> -> tensor<3xf32>
  return %3 : tensor<3xf32>
}

// CHECK-LABEL:   func.func @vector_gather(
// CHECK-SAME:                             %[[VAL_0:.*]]: tensor<64xf32>,
// CHECK-SAME:                             %[[VAL_1:.*]]: tensor<3xf32>) -> tensor<3xf32> {
// CHECK:           %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK:           %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK:           %[[VAL_5:.*]] = arith.constant 3 : index
// CHECK:           %[[VAL_6:.*]] = vector.create_mask %[[VAL_5]] : vector<4xi1>
// CHECK:           %[[VAL_7:.*]] = vector.gather %[[VAL_0]][%[[VAL_4]]] [%[[VAL_3]]], %[[VAL_6]], %[[VAL_2]] : tensor<64xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
// CHECK:           %[[VAL_8:.*]] = vector.transfer_write %[[VAL_7]], %[[VAL_1]][%[[VAL_4]]], %[[VAL_6]] {in_bounds = [true]} : vector<4xf32>, tensor<3xf32>
// CHECK:           return %[[VAL_8]] : tensor<3xf32>

// -----

// A vector.mask whose region only yields an existing value carries no masked
// operation, so it folds away and the yielded value is returned directly.
// CHECK-LABEL: func @empty_vector_mask_with_return
//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>
func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1>) -> vector<8xf32> {
//   CHECK-NOT:   vector.mask
//       CHECK:   return %[[IN]] : vector<8xf32>
  %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
  return %0 : vector<8xf32>
}