//===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir #define DEBUG_TYPE "arith-to-arm-sme" using namespace mlir; //===----------------------------------------------------------------------===// // Conversion helpers //===----------------------------------------------------------------------===// /// Returns true if 'val' is a splat of zero, false otherwise. static bool isSplatZero(Type elemType, DenseElementsAttr val) { if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); if (llvm::isa(elemType)) return val && val.isSplat() && val.getSplatValue().isZero(); return false; } namespace { //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// /// Conversion pattern for dense arith.constant. struct ConstantOpToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::ConstantOp constantOp, PatternRewriter &rewriter) const final { auto tileType = dyn_cast(constantOp.getType()); if (!tileType || !arm_sme::isValidSMETileVectorType(tileType)) return failure(); auto denseAttr = dyn_cast(constantOp.getValueAttr()); if (!denseAttr || !denseAttr.isSplat()) return failure(); auto tileElementType = tileType.getElementType(); // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op. if (isSplatZero(tileElementType, denseAttr)) { rewriter.replaceOpWithNewOp(constantOp, tileType); return success(); } // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice' // ops that broadcast the constant to each tile slice. auto loc = constantOp.getLoc(); // To fill a tile with a constant, we create a 1-D splat of the constant, // then move that into each tile slice (the largest unit we can set at once, // outside of operations like the outerproduct). VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue()); auto constantOp1D = rewriter.create(loc, denseAttr1D); auto initTile = rewriter.create(loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile // slice. auto nextTile = b.create( loc, tileType, constantOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; auto forOp = mlir::arm_sme::createLoopOverTileSlices( rewriter, loc, initTile, makeLoopBody); rewriter.replaceOp(constantOp, forOp.getResult(0)); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// void mlir::arith::populateArithToArmSMEConversionPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); } //===----------------------------------------------------------------------===// // Pass definition //===----------------------------------------------------------------------===// namespace { struct ArithToArmSMEConversionPass final : impl::ArithToArmSMEConversionPassBase { using impl::ArithToArmSMEConversionPassBase< ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase; void runOnOperation() override { RewritePatternSet patterns(&getContext()); arith::populateArithToArmSMEConversionPatterns(patterns); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; } // namespace