//===-- AffineDemotion.cpp -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This transformation is a prototype that demote affine dialects operations // after optimizations to FIR loops operations. // It is used after the AffinePromotion pass. // It is not part of the production pipeline and would need more work in order // to be used in production. // More information can be found in this presentation: // https://slides.com/rajanwalia/deck // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" namespace fir { #define GEN_PASS_DEF_AFFINEDIALECTDEMOTION #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir #define DEBUG_TYPE "flang-affine-demotion" using namespace fir; using namespace mlir; namespace { class AffineLoadConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector indices(adaptor.getIndices()); auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) return failure(); auto coorOp = rewriter.create( op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), adaptor.getMemref(), *maybeExpandedMap); rewriter.replaceOpWithNewOp(op, coorOp.getResult()); return success(); } }; class AffineStoreConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector indices(op.getIndices()); auto maybeExpandedMap = affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) return failure(); auto coorOp = rewriter.create( op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), adaptor.getMemref(), *maybeExpandedMap); rewriter.replaceOpWithNewOp(op, adaptor.getValue(), coorOp.getResult()); return success(); } }; class ConvertConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(fir::ConvertOp op, mlir::PatternRewriter &rewriter) const override { if (op.getRes().getType().isa()) { // due to index calculation moving to affine maps we still need to // add converts for sequence types this has a side effect of losing // some information about arrays with known dimensions by creating: // fir.convert %arg0 : (!fir.ref>) -> // !fir.ref> if (auto refTy = op.getValue().getType().dyn_cast()) if (auto arrTy = refTy.getEleTy().dyn_cast()) { fir::SequenceType::Shape flatShape = { fir::SequenceType::getUnknownExtent()}; auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); auto flatTy = fir::ReferenceType::get(flatArrTy); rewriter.replaceOpWithNewOp(op, flatTy, op.getValue()); return success(); } rewriter.startOpModification(op->getParentOp()); op.getResult().replaceAllUsesWith(op.getValue()); rewriter.finalizeOpModification(op->getParentOp()); rewriter.eraseOp(op); } return success(); } }; mlir::Type convertMemRef(mlir::MemRefType type) { return fir::SequenceType::get( SmallVector(type.getShape().begin(), type.getShape().end()), type.getElementType()); } class StdAllocConversion : public mlir::OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(memref::AllocOp op, mlir::PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, convertMemRef(op.getType()), op.getMemref()); return success(); } }; class AffineDialectDemotion : public fir::impl::AffineDialectDemotionBase { public: void runOnOperation() override { auto *context = &getContext(); auto function = getOperation(); LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; function.print(llvm::dbgs());); mlir::RewritePatternSet patterns(context); patterns.insert(context); patterns.insert(context); patterns.insert(context); patterns.insert(context); mlir::ConversionTarget target(*context); target.addIllegalOp(); target.addDynamicallyLegalOp([](fir::ConvertOp op) { if (op.getRes().getType().isa()) return false; return true; }); target .addLegalDialect(); if (mlir::failed(mlir::applyPartialConversion(function, target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(context), "error in converting affine dialect\n"); signalPassFailure(); } } }; } // namespace std::unique_ptr fir::createAffineDemotionPass() { return std::make_unique(); }