//===- VectorToArmSME.cpp - Conversion from Vector to the ArmSME dialect --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Casting.h" using namespace mlir; namespace { /// Conversion pattern for vector.transfer_read. /// /// --- /// /// Example 1: op with identity permutation map to horizontal /// arm_sme.tile_load: /// /// vector.transfer_read ... permutation_map: (d0, d1) -> (d0, d1) /// /// is converted to: /// /// arm_sme.tile_load ... /// /// --- /// /// Example 2: op with transpose permutation map to vertical arm_sme.tile_load /// (in-flight transpose): /// /// vector.transfer_read ... permutation_map: (d0, d1) -> (d1, d0) /// /// is converted to: /// /// arm_sme.tile_load ... layout struct TransferReadToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp, PatternRewriter &rewriter) const final { // The permutation map must have two results. if (transferReadOp.getTransferRank() != 2) return rewriter.notifyMatchFailure(transferReadOp, "not a 2 result permutation map"); auto vectorType = transferReadOp.getVectorType(); if (!arm_sme::isValidSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(transferReadOp, "not a valid vector type for SME"); if (!llvm::isa(transferReadOp.getSource().getType())) return rewriter.notifyMatchFailure(transferReadOp, "not a memref source"); // Out-of-bounds dims are not supported. if (transferReadOp.hasOutOfBoundsDim()) return rewriter.notifyMatchFailure(transferReadOp, "not inbounds transfer read"); arm_sme::TileSliceLayout layout; AffineExpr d0, d1; bindDims(transferReadOp.getContext(), d0, d1); AffineMap map = transferReadOp.getPermutationMap(); if (map.isIdentity()) layout = arm_sme::TileSliceLayout::Horizontal; else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, transferReadOp.getContext())) layout = arm_sme::TileSliceLayout::Vertical; else return rewriter.notifyMatchFailure(transferReadOp, "unsupported permutation map"); // Padding isn't optional for transfer_read, but is only used in the case // of out-of-bounds accesses (not supported here) and/or masking. Mask is // optional, if it's not present don't pass padding. auto mask = transferReadOp.getMask(); auto padding = mask ? transferReadOp.getPadding() : nullptr; rewriter.replaceOpWithNewOp( transferReadOp, vectorType, transferReadOp.getSource(), transferReadOp.getIndices(), padding, mask, layout); return success(); } }; /// Conversion pattern for vector.transfer_write. /// /// --- /// /// Example 1: op with identity permutation map to horizontal /// arm_sme.tile_store: /// /// vector.transfer_write %vector, %source[%c0, %c0] /// {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref /// /// is converted to: /// /// arm_sme.tile_store %vector, %source[%c0, %c0] : memref, /// vector<[16]x[16]xi8> /// --- /// /// Example 2: op with transpose permutation map to vertical arm_sme.tile_store /// (in-flight transpose): /// /// vector.transfer_write %vector, %source[%c0, %c0] /// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, /// in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref /// /// is converted to: /// /// arm_sme.tile_store %vector, %source[%c0, %c0] layout /// : memref, vector<[16]x[16]xi8> struct TransferWriteToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const final { auto vType = writeOp.getVectorType(); if (!arm_sme::isValidSMETileVectorType(vType)) return failure(); if (!llvm::isa(writeOp.getSource().getType())) return failure(); // Out-of-bounds dims are not supported. if (writeOp.hasOutOfBoundsDim()) return rewriter.notifyMatchFailure(writeOp, "not inbounds transfer write"); AffineExpr d0, d1; bindDims(writeOp.getContext(), d0, d1); AffineMap map = writeOp.getPermutationMap(); bool isTranspose = (map == AffineMap::get(map.getNumDims(), 0, {d1, d0}, writeOp.getContext())); if (!map.isIdentity() && !isTranspose) return rewriter.notifyMatchFailure(writeOp, "unsupported permutation map"); arm_sme::TileSliceLayout layout = isTranspose ? arm_sme::TileSliceLayout::Vertical : arm_sme::TileSliceLayout::Horizontal; rewriter.replaceOpWithNewOp( writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(), writeOp.getMask(), layout); return success(); } }; /// Conversion pattern for vector.load. struct VectorLoadToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::LoadOp load, PatternRewriter &rewriter) const override { if (!arm_sme::isValidSMETileVectorType(load.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( load, load.getVectorType(), load.getBase(), load.getIndices()); return success(); } }; /// Conversion pattern for vector.store. struct VectorStoreToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::StoreOp store, PatternRewriter &rewriter) const override { if (!arm_sme::isValidSMETileVectorType(store.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( store, store.getValueToStore(), store.getBase(), store.getIndices()); return success(); } }; /// Conversion pattern for vector.broadcast. /// /// Example: /// /// %broadcast_to_tile = vector.broadcast %src : i32 to vector<[4]x[4]xi32> /// /// is converted to: /// /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) /// { /// %tile_update = arm_sme.move_vector_to_tile_slice /// %broadcast_to_1d, %iter_tile, %tile_slice_index : /// vector<[4]xi32> into vector<[4]x[4]xi32> /// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// /// Supports scalar, 0-d vector, and 1-d vector broadcasts. struct BroadcastOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, PatternRewriter &rewriter) const final { auto tileType = broadcastOp.getResultVectorType(); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); auto loc = broadcastOp.getLoc(); auto srcType = broadcastOp.getSourceType(); auto srcVectorType = dyn_cast(srcType); Value broadcastOp1D; if (srcType.isIntOrFloat() || (srcVectorType && (srcVectorType.getRank() == 0))) { // Broadcast scalar or 0-d vector to 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); broadcastOp1D = rewriter.create( loc, tileSliceType, broadcastOp.getSource()); } else if (srcVectorType && (srcVectorType.getRank() == 1)) // Value to broadcast is already a 1-d vector, nothing to do. broadcastOp1D = broadcastOp.getSource(); else return failure(); auto initTile = rewriter.create(loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.move_vector_to_tile_slice' to broadcast the value // to each tile slice. auto nextTile = b.create( loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; // Create a loop over ZA tile slices. auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); rewriter.replaceOp(broadcastOp, forOp.getResult(0)); return success(); } }; /// Conversion pattern for vector.splat. /// /// Example: /// /// %splat_to_tile = vector.splat %src : i32 to vector<[4]x[4]xi32> /// /// is converted to: /// /// %broadcast_to_1d = vector.broadcast %src : i32 to vector<[4]xi32> /// %broadcast_to_tile = scf.for %tile_slice_index = %c0 to %num_tile_slices /// step %c1 iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) /// { /// %tile_update = arm_sme.move_vector_to_tile_slice /// %broadcast_to_1d, %iter_tile, %tile_slice_index : /// vector<[4]xi32> into vector<[4]x[4]xi32> /// scf.yield %tile_update : vector<[4]x[4]xi32> /// } /// /// This is identical to vector.broadcast of a scalar. struct SplatOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::SplatOp splatOp, PatternRewriter &rewriter) const final { auto tileType = splatOp.getResult().getType(); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); auto loc = splatOp.getLoc(); auto srcType = splatOp.getOperand().getType(); assert(srcType.isIntOrFloat() && "Invalid source type for vector.splat"); // Avoid unused-variable warning when building without assertions. (void)srcType; // First, broadcast the scalar to a 1-d vector. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); Value broadcastOp1D = rewriter.create( loc, tileSliceType, splatOp.getInput()); auto initTile = rewriter.create(loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { auto nextTile = b.create( loc, tileType, broadcastOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; // Next, create a loop over ZA tile slices and "move" the generated 1-d // vector to each slice. auto forOp = createLoopOverTileSlices(rewriter, loc, initTile, makeLoopBody); rewriter.replaceOp(splatOp, forOp.getResult(0)); return success(); } }; /// Conversion pattern for vector.transpose. /// /// Stores the input tile to memory and reloads vertically. /// /// Example: /// /// %transposed_src = vector.transpose %src, [1, 0] /// : vector<[4]x[4]xi32> to vector<[4]x[4]xi32> /// /// is converted to: /// /// %alloca = memref.alloca(%svl_s, %svl_s) : memref /// %arm_sme.tile_store %src, , %alloca[%c0, %c0] /// : memref, vector<[4]x[4]xi32> /// %transposed_src = arm_sme.tile_load %alloca[%c0, %c0] /// layout : memref, vector<[4]x[4]xi32> /// /// NOTE: Tranposing via memory is obviously expensive, the current intention /// is to avoid the transpose if possible, this is therefore intended as a /// fallback and to provide base support for Vector ops. If it turns out /// transposes can't be avoided then this should be replaced with a more optimal /// implementation, perhaps with tile <-> vector (MOVA) ops. struct TransposeOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const final { auto tileType = transposeOp.getResultVectorType(); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); // Bail unless this is a true 2-D matrix transpose. ArrayRef permutation = transposeOp.getPermutation(); if (permutation[0] != 1 || permutation[1] != 0) return failure(); auto loc = transposeOp.getLoc(); // Allocate buffer to store input tile to. Value vscale = rewriter.create(loc, rewriter.getIndexType()); Value minTileSlices = rewriter.create( loc, rewriter.getIndexAttr(tileType.getDimSize(0))); Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value numTileSlices = rewriter.create(loc, vscale, minTileSlices); auto bufferType = MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic}, tileType.getElementType()); auto buffer = rewriter.create( loc, bufferType, ValueRange{numTileSlices, numTileSlices}); Value input = transposeOp.getVector(); // Store input tile. auto tileStoreOp = rewriter.create( loc, input, buffer, ValueRange{c0, c0}); // Reload input tile vertically. rewriter.replaceOpWithNewOp( transposeOp, tileType, tileStoreOp.getBase(), tileStoreOp.getIndices(), arm_sme::TileSliceLayout::Vertical); return success(); } }; /// Conversion pattern for vector.outerproduct. /// /// If the vector.outerproduct is masked (and the mask is from a /// vector.create_mask), then the mask is decomposed into two 1-D masks for the /// operands. /// /// Example: /// /// %mask = vector.create_mask %dimA, %dimB : vector<[4]x[4]xi1> /// %result = vector.mask %mask { /// vector.outerproduct %vecA, %vecB /// : vector<[4]xf32>, vector<[4]xf32> /// } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> /// /// is converted to: /// /// %maskA = vector.create_mask %dimA : vector<[4]xi1> /// %maskB = vector.create_mask %dimB : vector<[4]xi1> /// %result = arm_sme.outerproduct %vecA, %vecB masks(%maskA, %maskB) /// : vector<[4]xf32>, vector<[4]xf32> /// /// Unmasked outerproducts can be directly replaced with the arm_sme op. /// /// Example: /// /// %result = vector.outerproduct %vecA, %vecB /// : vector<[4]xf32>, vector<[4]xf32> /// /// is converted to: /// /// %result = arm_sme.outerproduct %vecA, %vecB /// : vector<[4]xf32>, vector<[4]xf32> /// struct VectorOuterProductToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::OuterProductOp outerProductOp, PatternRewriter &rewriter) const override { // We don't yet support lowering AXPY operations to SME. These could be // lowered by masking out all but the first element of the LHS. if (!isa(outerProductOp.getOperandTypeRHS())) return rewriter.notifyMatchFailure(outerProductOp, "AXPY operations not supported"); if (!arm_sme::isValidSMETileVectorType( outerProductOp.getResultVectorType())) return rewriter.notifyMatchFailure( outerProductOp, "outer product does not fit into SME tile"); auto kind = outerProductOp.getKind(); if (kind != vector::CombiningKind::ADD) return rewriter.notifyMatchFailure( outerProductOp, "unsupported kind (lowering to SME only supports ADD at the moment)"); Value lhsMask = {}; Value rhsMask = {}; Operation *rootOp = outerProductOp; auto loc = outerProductOp.getLoc(); if (outerProductOp.isMasked()) { auto maskOp = outerProductOp.getMaskingOp(); rewriter.setInsertionPoint(maskOp); rootOp = maskOp; auto operandMasks = decomposeResultMask(loc, maskOp.getMask(), rewriter); if (failed(operandMasks)) return failure(); std::tie(lhsMask, rhsMask) = *operandMasks; } rewriter.replaceOpWithNewOp( rootOp, outerProductOp.getResultVectorType(), outerProductOp.getLhs(), outerProductOp.getRhs(), lhsMask, rhsMask, outerProductOp.getAcc()); return success(); } static FailureOr> decomposeResultMask(Location loc, Value mask, PatternRewriter &rewriter) { // Attempt to extract masks from vector.create_mask. // TODO: Add support for other mask sources. auto createMaskOp = mask.getDefiningOp(); if (!createMaskOp) return failure(); auto maskType = createMaskOp.getVectorType(); Value lhsMaskDim = createMaskOp.getOperand(0); Value rhsMaskDim = createMaskOp.getOperand(1); VectorType operandMaskType = VectorType::Builder(maskType).dropDim(0); Value lhsMask = rewriter.create(loc, operandMaskType, lhsMaskDim); Value rhsMask = rewriter.create(loc, operandMaskType, rhsMaskDim); return std::make_pair(lhsMask, rhsMask); } }; /// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`. /// /// Example: /// ``` /// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32> /// ``` /// Becomes: /// ``` /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row] /// : vector<[4]xi32> from vector<[4]x[4]xi32> /// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32> /// ``` struct VectorExtractToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { VectorType sourceType = extractOp.getSourceVectorType(); if (!arm_sme::isValidSMETileVectorType(sourceType)) return failure(); auto loc = extractOp.getLoc(); auto position = extractOp.getMixedPosition(); Value sourceVector = extractOp.getVector(); // Extract entire vector. Should be handled by folder, but just to be safe. if (position.empty()) { rewriter.replaceOp(extractOp, sourceVector); return success(); } Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); auto moveTileSliceToVector = rewriter.create(loc, sourceVector, sliceIndex); if (position.size() == 1) { // Single index case: Extracts a 1D slice. rewriter.replaceOp(extractOp, moveTileSliceToVector); return success(); } // Two indices case: Extracts a single element. assert(position.size() == 2); rewriter.replaceOpWithNewOp( extractOp, moveTileSliceToVector, position[1]); return success(); } }; /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and /// `arm_sme.move_tile_slice_to_vector`. /// /// Example: /// ``` /// %new_tile = vector.insert %el, %tile[%row, %col] /// : i32 into vector<[4]x[4]xi32> /// ``` /// Becomes: /// ``` /// %slice = arm_sme.move_tile_slice_to_vector %tile[%row] /// : vector<[4]xi32> from vector<[4]x[4]xi32> /// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32> /// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row /// : vector<[4]xi32> into vector<[4]x[4]xi32> /// ``` struct VectorInsertToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { VectorType resultType = insertOp.getResult().getType(); if (!arm_sme::isValidSMETileVectorType(resultType)) return failure(); auto loc = insertOp.getLoc(); auto position = insertOp.getMixedPosition(); Value source = insertOp.getSource(); // Overwrite entire vector with value. Should be handled by folder, but // just to be safe. if (position.empty()) { rewriter.replaceOp(insertOp, source); return success(); } Value tileSlice = source; Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front(); if (position.size() == 2) { // Two indices case: Insert single element into tile. // We need to first extract the existing slice and update the element. tileSlice = rewriter.create( loc, insertOp.getDest(), sliceIndex); tileSlice = rewriter.create(loc, source, tileSlice, position[1]); } // Insert the slice into the destination tile. rewriter.replaceOpWithNewOp( insertOp, tileSlice, insertOp.getDest(), sliceIndex); return success(); } }; /// Lowers `vector.print` of a tile into a loop over the rows of the tile, /// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with /// a 1D `vector.print`. /// /// BEFORE: /// ```mlir /// vector.print %tile : vector<[4]x[4]xf32> /// ``` /// AFTER: /// ```mlir /// %c0 = arith.constant 0 : index /// %c1 = arith.constant 1 : index /// %c4 = arith.constant 4 : index /// %vscale = vector.vscale /// %svl_s = arith.muli %c4, %vscale : index /// scf.for %i = %c0 to %svl_s step %c1 { /// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i] /// : vector<[4]xf32> from vector<[4]x[4]xf32> /// vector.print %tile_slice : vector<[4]xf32> /// } /// ``` struct VectorPrintToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(vector::PrintOp printOp, PatternRewriter &rewriter) const override { if (!printOp.getSource()) return failure(); VectorType vectorType = dyn_cast(printOp.getPrintType()); if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType)) return failure(); auto loc = printOp.getLoc(); // Create a loop over the rows of the tile. auto vscale = rewriter.create(loc); auto minTileRows = rewriter.create(loc, vectorType.getDimSize(0)); auto lowerBound = rewriter.create(loc, 0); auto upperBound = rewriter.create(loc, minTileRows, vscale); auto step = rewriter.create(loc, 1); auto forOp = rewriter.create(loc, lowerBound, upperBound, step); { // Loop body. rewriter.setInsertionPointToStart(forOp.getBody()); // Extract the current row from the tile. Value rowIndex = forOp.getInductionVar(); auto tileSlice = rewriter.create( loc, printOp.getSource(), rowIndex); // Print the row with a 1D vector.print. rewriter.create(loc, tileSlice, printOp.getPunctuation()); } rewriter.eraseOp(printOp); return success(); } }; } // namespace void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx) { patterns.add(&ctx); }