//===-- PreCGRewrite.cpp --------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ // //===----------------------------------------------------------------------===// #include "flang/Optimizer/CodeGen/CodeGen.h" #include "CGOps.h" #include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" namespace fir { #define GEN_PASS_DEF_CODEGENREWRITE #include "flang/Optimizer/CodeGen/CGPasses.h.inc" } // namespace fir //===----------------------------------------------------------------------===// // Codegen rewrite: rewriting of subgraphs of ops //===----------------------------------------------------------------------===// #define DEBUG_TYPE "flang-codegen-rewrite" static void populateShape(llvm::SmallVectorImpl &vec, fir::ShapeOp shape) { vec.append(shape.getExtents().begin(), shape.getExtents().end()); } // Operands of fir.shape_shift split into two vectors. static void populateShapeAndShift(llvm::SmallVectorImpl &shapeVec, llvm::SmallVectorImpl &shiftVec, fir::ShapeShiftOp shift) { for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end(); i != endIter;) { shiftVec.push_back(*i++); shapeVec.push_back(*i++); } } static void populateShift(llvm::SmallVectorImpl &vec, fir::ShiftOp shift) { vec.append(shift.getOrigins().begin(), shift.getOrigins().end()); } namespace { /// Convert fir.embox to the extended form where necessary. /// /// The embox operation can take arguments that specify multidimensional array /// properties at runtime. These properties may be shared between distinct /// objects that have the same properties. Before we lower these small DAGs to /// LLVM-IR, we gather all the information into a single extended operation. For /// example, /// ``` /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1> /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1> /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref>, /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box> /// ``` /// can be rewritten as /// ``` /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : /// (!fir.ref>, index, index, index, index, index) -> /// !fir.box> /// ``` class EmboxConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::EmboxOp embox, mlir::PatternRewriter &rewriter) const override { // If the embox does not include a shape, then do not convert it if (auto shapeVal = embox.getShape()) return rewriteDynamicShape(embox, rewriter, shapeVal); if (embox.getType().isa()) TODO(embox.getLoc(), "embox conversion for fir.class type"); if (auto boxTy = embox.getType().dyn_cast()) if (auto seqTy = boxTy.getEleTy().dyn_cast()) if (!seqTy.hasDynamicExtents()) return rewriteStaticShape(embox, rewriter, seqTy); return mlir::failure(); } mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox, mlir::PatternRewriter &rewriter, fir::SequenceType seqTy) const { auto loc = embox.getLoc(); llvm::SmallVector shapeOpers; auto idxTy = rewriter.getIndexType(); for (auto ext : seqTy.getShape()) { auto iAttr = rewriter.getIndexAttr(ext); auto extVal = rewriter.create(loc, idxTy, iAttr); shapeOpers.push_back(extVal); } auto xbox = rewriter.create( loc, embox.getType(), embox.getMemref(), shapeOpers, std::nullopt, std::nullopt, std::nullopt, std::nullopt, embox.getTypeparams(), embox.getSourceBox()); LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); rewriter.replaceOp(embox, xbox.getOperation()->getResults()); return mlir::success(); } mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox, mlir::PatternRewriter &rewriter, mlir::Value shapeVal) const { auto loc = embox.getLoc(); llvm::SmallVector shapeOpers; llvm::SmallVector shiftOpers; if (auto shapeOp = mlir::dyn_cast(shapeVal.getDefiningOp())) { populateShape(shapeOpers, shapeOp); } else { auto shiftOp = mlir::dyn_cast(shapeVal.getDefiningOp()); assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift"); populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); } llvm::SmallVector sliceOpers; llvm::SmallVector subcompOpers; llvm::SmallVector substrOpers; if (auto s = embox.getSlice()) if (auto sliceOp = mlir::dyn_cast_or_null(s.getDefiningOp())) { sliceOpers.assign(sliceOp.getTriples().begin(), sliceOp.getTriples().end()); subcompOpers.assign(sliceOp.getFields().begin(), sliceOp.getFields().end()); substrOpers.assign(sliceOp.getSubstr().begin(), sliceOp.getSubstr().end()); } auto xbox = rewriter.create( loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers, sliceOpers, subcompOpers, substrOpers, embox.getTypeparams(), embox.getSourceBox()); LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n'); rewriter.replaceOp(embox, xbox.getOperation()->getResults()); return mlir::success(); } }; /// Convert fir.rebox to the extended form where necessary. /// /// For example, /// ``` /// %5 = fir.rebox %3(%1) : (!fir.box>, !fir.shapeshift<1>) -> /// !fir.box> /// ``` /// converted to /// ``` /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box>, /// index, index) -> !fir.box> /// ``` class ReboxConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::ReboxOp rebox, mlir::PatternRewriter &rewriter) const override { auto loc = rebox.getLoc(); llvm::SmallVector shapeOpers; llvm::SmallVector shiftOpers; if (auto shapeVal = rebox.getShape()) { if (auto shapeOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShape(shapeOpers, shapeOp); else if (auto shiftOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); else if (auto shiftOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShift(shiftOpers, shiftOp); else return mlir::failure(); } llvm::SmallVector sliceOpers; llvm::SmallVector subcompOpers; llvm::SmallVector substrOpers; if (auto s = rebox.getSlice()) if (auto sliceOp = mlir::dyn_cast_or_null(s.getDefiningOp())) { sliceOpers.append(sliceOp.getTriples().begin(), sliceOp.getTriples().end()); subcompOpers.append(sliceOp.getFields().begin(), sliceOp.getFields().end()); substrOpers.append(sliceOp.getSubstr().begin(), sliceOp.getSubstr().end()); } auto xRebox = rewriter.create( loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers, sliceOpers, subcompOpers, substrOpers); LLVM_DEBUG(llvm::dbgs() << "rewriting " << rebox << " to " << xRebox << '\n'); rewriter.replaceOp(rebox, xRebox.getOperation()->getResults()); return mlir::success(); } }; /// Convert all fir.array_coor to the extended form. /// /// For example, /// ``` /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref>, /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref /// ``` /// converted to /// ``` /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : /// (!fir.ref>, index, index, index, index, index, index) -> /// !fir.ref /// ``` class ArrayCoorConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::ArrayCoorOp arrCoor, mlir::PatternRewriter &rewriter) const override { auto loc = arrCoor.getLoc(); llvm::SmallVector shapeOpers; llvm::SmallVector shiftOpers; if (auto shapeVal = arrCoor.getShape()) { if (auto shapeOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShape(shapeOpers, shapeOp); else if (auto shiftOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShapeAndShift(shapeOpers, shiftOpers, shiftOp); else if (auto shiftOp = mlir::dyn_cast(shapeVal.getDefiningOp())) populateShift(shiftOpers, shiftOp); else return mlir::failure(); } llvm::SmallVector sliceOpers; llvm::SmallVector subcompOpers; if (auto s = arrCoor.getSlice()) if (auto sliceOp = mlir::dyn_cast_or_null(s.getDefiningOp())) { sliceOpers.append(sliceOp.getTriples().begin(), sliceOp.getTriples().end()); subcompOpers.append(sliceOp.getFields().begin(), sliceOp.getFields().end()); assert(sliceOp.getSubstr().empty() && "Don't allow substring operations on array_coor. This " "restriction may be lifted in the future."); } auto xArrCoor = rewriter.create( loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers, sliceOpers, subcompOpers, arrCoor.getIndices(), arrCoor.getTypeparams()); LLVM_DEBUG(llvm::dbgs() << "rewriting " << arrCoor << " to " << xArrCoor << '\n'); rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults()); return mlir::success(); } }; class DeclareOpConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::DeclareOp declareOp, mlir::PatternRewriter &rewriter) const override { rewriter.replaceOp(declareOp, declareOp.getMemref()); return mlir::success(); } }; class CodeGenRewrite : public fir::impl::CodeGenRewriteBase { public: void runOn(mlir::Operation *op, mlir::Region ®ion) { auto &context = getContext(); mlir::ConversionTarget target(context); target.addLegalDialect(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addDynamicallyLegalOp([](fir::EmboxOp embox) { return !(embox.getShape() || embox.getType() .cast() .getEleTy() .isa()); }); mlir::RewritePatternSet patterns(&context); patterns.insert(&context); if (mlir::failed( mlir::applyPartialConversion(op, target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(&context), "error in running the pre-codegen conversions"); signalPassFailure(); return; } // Erase any residual (fir.shape, fir.slice...). mlir::IRRewriter rewriter(&context); (void)mlir::runRegionDCE(rewriter, op->getRegions()); } void runOnOperation() override final { // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp. auto mod = getOperation(); for (auto func : mod.getOps()) runOn(func, func.getBody()); for (auto global : mod.getOps()) runOn(global, global.getRegion()); } }; } // namespace std::unique_ptr fir::createFirCodeGenRewritePass() { return std::make_unique(); }