// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s

/// This tests the basic functionality of the -arm-sve-legalize-vector-storage pass.

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
// CHECK-SAME:   %[[MASK:.*]]: vector<[1]xi1>)
func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
  %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
  return %reload : vector<[1]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
// CHECK-SAME:   %[[MASK:.*]]: vector<[2]xi1>)
func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
  %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
  return %reload : vector<[2]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
// CHECK-SAME:   %[[MASK:.*]]: vector<[4]xi1>)
func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
  %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
  return %reload : vector<[4]xi1>
}

// -----

// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
// CHECK-SAME:   %[[MASK:.*]]: vector<[8]xi1>)
func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
  // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
  return %reload : vector<[8]xi1>
}

// -----
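/// The tests above all follow the same pattern: storage for a vector<[N]xi1>
/// mask (N < 16) is widened to an svbool-sized memref<vector<[16]xi1>>, with
/// arm_sve.convert_to_svbool/convert_from_svbool inserted around each store
/// and load. As a sketch (derived from the CHECK lines above, not verified
/// pass output), the store side for a vector<[4]xi1> mask becomes:
///
///   %alloca = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
///   %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
///   memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
///
/// A vector<[16]xi1> mask is already svbool-sized, so the test below expects
/// no conversion ops at all.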
// CHECK-LABEL: @store_and_reload_sve_predicate_nxv16i1(
// CHECK-SAME:   %[[MASK:.*]]: vector<[16]xi1>)
func.func @store_and_reload_sve_predicate_nxv16i1(%mask: vector<[16]xi1>) -> vector<[16]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
  %alloca = memref.alloca() : memref<vector<[16]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[16]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[16]xi1>
  return %reload : vector<[16]xi1>
}

// -----

/// This is not a valid SVE mask type, so it is ignored by the
/// `-arm-sve-legalize-vector-storage` pass.

// CHECK-LABEL: @store_and_reload_unsupported_type(
// CHECK-SAME:   %[[MASK:.*]]: vector<[7]xi1>)
func.func @store_and_reload_unsupported_type(%mask: vector<[7]xi1>) -> vector<[7]xi1> {
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[7]xi1>>
  %alloca = memref.alloca() : memref<vector<[7]xi1>>
  // CHECK-NEXT: memref.store %[[MASK]], %[[ALLOCA]][] : memref<vector<[7]xi1>>
  memref.store %mask, %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[7]xi1>>
  %reload = memref.load %alloca[] : memref<vector<[7]xi1>>
  // CHECK-NEXT: return %[[RELOAD]] : vector<[7]xi1>
  return %reload : vector<[7]xi1>
}

// -----

// CHECK-LABEL: @store_2d_mask_and_reload_slice(
// CHECK-SAME:   %[[MASK:.*]]: vector<3x[8]xi1>)
func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
  %c0 = arith.constant 0 : index
  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
  %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
  memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
  // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
  %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
  // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
  %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
  // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
  return %slice : vector<[8]xi1>
}

// -----

// CHECK-LABEL: @set_sve_alloca_alignment
func.func @set_sve_alloca_alignment() {
  /// This checks that the alignment of allocas of scalable vectors will be
  /// something the backend can handle. Currently, the backend sets the
  /// alignment of scalable vectors to their base size (i.e. their size at
  /// vscale = 1). This works for hardware-sized types, which always get a
  /// 16-byte alignment. The problem is that larger types, e.g. vector<[8]xf32>,
  /// end up with alignments larger than 16 bytes (e.g. 32 bytes here), which
  /// are unsupported. The `-arm-sve-legalize-vector-storage` pass avoids this
  /// issue by explicitly setting the alignment to 16 bytes for all scalable
  /// vectors.
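  /// As a worked example of the base-size rule described above (a sketch, not
  /// verified backend behaviour): vector<[8]xf32> occupies 8 x 4 = 32 bytes at
  /// vscale = 1, so base-size alignment would be 32 bytes, beyond the supported
  /// 16. After this pass every alloca below should instead carry
  /// `alignment = 16`, which is all the CHECK-COUNT lines match.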
  // CHECK-COUNT-6: alignment = 16
  %a1 = memref.alloca() : memref<vector<[32]xi8>>
  %a2 = memref.alloca() : memref<vector<[16]xi8>>
  %a3 = memref.alloca() : memref<vector<[8]xi8>>
  %a4 = memref.alloca() : memref<vector<[4]xi8>>
  %a5 = memref.alloca() : memref<vector<[2]xi8>>
  %a6 = memref.alloca() : memref<vector<[1]xi8>>

  // CHECK-COUNT-6: alignment = 16
  %b1 = memref.alloca() : memref<vector<[32]xi16>>
  %b2 = memref.alloca() : memref<vector<[16]xi16>>
  %b3 = memref.alloca() : memref<vector<[8]xi16>>
  %b4 = memref.alloca() : memref<vector<[4]xi16>>
  %b5 = memref.alloca() : memref<vector<[2]xi16>>
  %b6 = memref.alloca() : memref<vector<[1]xi16>>

  // CHECK-COUNT-6: alignment = 16
  %c1 = memref.alloca() : memref<vector<[32]xi32>>
  %c2 = memref.alloca() : memref<vector<[16]xi32>>
  %c3 = memref.alloca() : memref<vector<[8]xi32>>
  %c4 = memref.alloca() : memref<vector<[4]xi32>>
  %c5 = memref.alloca() : memref<vector<[2]xi32>>
  %c6 = memref.alloca() : memref<vector<[1]xi32>>

  // CHECK-COUNT-6: alignment = 16
  %d1 = memref.alloca() : memref<vector<[32]xi64>>
  %d2 = memref.alloca() : memref<vector<[16]xi64>>
  %d3 = memref.alloca() : memref<vector<[8]xi64>>
  %d4 = memref.alloca() : memref<vector<[4]xi64>>
  %d5 = memref.alloca() : memref<vector<[2]xi64>>
  %d6 = memref.alloca() : memref<vector<[1]xi64>>

  // CHECK-COUNT-6: alignment = 16
  %e1 = memref.alloca() : memref<vector<[32]xf32>>
  %e2 = memref.alloca() : memref<vector<[16]xf32>>
  %e3 = memref.alloca() : memref<vector<[8]xf32>>
  %e4 = memref.alloca() : memref<vector<[4]xf32>>
  %e5 = memref.alloca() : memref<vector<[2]xf32>>
  %e6 = memref.alloca() : memref<vector<[1]xf32>>

  // CHECK-COUNT-6: alignment = 16
  %f1 = memref.alloca() : memref<vector<[32]xf64>>
  %f2 = memref.alloca() : memref<vector<[16]xf64>>
  %f3 = memref.alloca() : memref<vector<[8]xf64>>
  %f4 = memref.alloca() : memref<vector<[4]xf64>>
  %f5 = memref.alloca() : memref<vector<[2]xf64>>
  %f6 = memref.alloca() : memref<vector<[1]xf64>>

  /// `prevent.dce` is an unregistered op (accepted via the
  /// -allow-unregistered-dialect flag in the RUN line) that keeps the allocas
  /// above from being dead-code-eliminated.
  "prevent.dce"(
    %a1, %a2, %a3, %a4, %a5, %a6,
    %b1, %b2, %b3, %b4, %b5, %b6,
    %c1, %c2, %c3, %c4, %c5, %c6,
    %d1, %d2, %d3, %d4, %d5, %d6,
    %e1, %e2, %e3, %e4, %e5, %e6,
    %f1, %f2, %f3, %f4, %f5, %f6)
    : (memref<vector<[32]xi8>>, memref<vector<[16]xi8>>, memref<vector<[8]xi8>>, memref<vector<[4]xi8>>, memref<vector<[2]xi8>>, memref<vector<[1]xi8>>,
       memref<vector<[32]xi16>>, memref<vector<[16]xi16>>, memref<vector<[8]xi16>>, memref<vector<[4]xi16>>, memref<vector<[2]xi16>>, memref<vector<[1]xi16>>,
       memref<vector<[32]xi32>>, memref<vector<[16]xi32>>, memref<vector<[8]xi32>>, memref<vector<[4]xi32>>, memref<vector<[2]xi32>>, memref<vector<[1]xi32>>,
       memref<vector<[32]xi64>>, memref<vector<[16]xi64>>, memref<vector<[8]xi64>>, memref<vector<[4]xi64>>, memref<vector<[2]xi64>>, memref<vector<[1]xi64>>,
       memref<vector<[32]xf32>>, memref<vector<[16]xf32>>, memref<vector<[8]xf32>>, memref<vector<[4]xf32>>, memref<vector<[2]xf32>>, memref<vector<[1]xf32>>,
       memref<vector<[32]xf64>>, memref<vector<[16]xf64>>, memref<vector<[8]xf64>>, memref<vector<[4]xf64>>, memref<vector<[2]xf64>>, memref<vector<[1]xf64>>)
    -> ()

  return
}