//===- InlineElementals.cpp - Inline chained hlfir.elemental ops ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // Chained elemental operations like a + b + c can inline the first elemental // at the hlfir.apply in the body of the second one (as described in // docs/HighLevelFIR.md). This has to be done in a pass rather than in lowering // so that it happens after the HLFIR intrinsic simplification pass. //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/HLFIRTools.h" #include "flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/HLFIR/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include namespace hlfir { #define GEN_PASS_DEF_INLINEELEMENTALS #include "flang/Optimizer/HLFIR/Passes.h.inc" } // namespace hlfir /// If the elemental has only two uses and those two are an apply operation and /// a destory operation, return those two, otherwise return {} static std::optional> getTwoUses(hlfir::ElementalOp elemental) { mlir::Operation::user_range users = elemental->getUsers(); // don't inline anything with more than one use (plus hfir.destroy) if (std::distance(users.begin(), users.end()) != 2) { return std::nullopt; } // If the ElementalOp must produce a temporary (e.g. for // finalization purposes), then we cannot inline it. if (hlfir::elementalOpMustProduceTemp(elemental)) return std::nullopt; hlfir::ApplyOp apply; hlfir::DestroyOp destroy; for (mlir::Operation *user : users) mlir::TypeSwitch(user) .Case([&](hlfir::ApplyOp op) { apply = op; }) .Case([&](hlfir::DestroyOp op) { destroy = op; }); if (!apply || !destroy) return std::nullopt; // we can't inline if the return type of the yield doesn't match the return // type of the apply auto yield = mlir::dyn_cast_or_null( elemental.getRegion().back().back()); assert(yield && "hlfir.elemental should always end with a yield"); if (apply.getResult().getType() != yield.getElementValue().getType()) return std::nullopt; return std::pair{apply, destroy}; } namespace { class InlineElementalConversion : public mlir::OpRewritePattern { public: using mlir::OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite(hlfir::ElementalOp elemental, mlir::PatternRewriter &rewriter) const override { std::optional> maybeTuple = getTwoUses(elemental); if (!maybeTuple) return rewriter.notifyMatchFailure( elemental, "hlfir.elemental does not have two uses"); if (elemental.isOrdered()) { // We can only inline the ordered elemental into a loop-like // construct that processes the indices in-order and does not // have the side effects itself. Adhere to conservative behavior // for the time being. return rewriter.notifyMatchFailure(elemental, "hlfir.elemental is ordered"); } auto [apply, destroy] = *maybeTuple; assert(elemental.getRegion().hasOneBlock() && "expect elemental region to have one block"); fir::FirOpBuilder builder{rewriter, elemental.getOperation()}; builder.setInsertionPointAfter(apply); hlfir::YieldElementOp yield = hlfir::inlineElementalOp( elemental.getLoc(), builder, elemental, apply.getIndices()); // remove the old elemental and all of the bookkeeping rewriter.replaceAllUsesWith(apply.getResult(), yield.getElementValue()); rewriter.eraseOp(yield); rewriter.eraseOp(apply); rewriter.eraseOp(destroy); rewriter.eraseOp(elemental); return mlir::success(); } }; class InlineElementalsPass : public hlfir::impl::InlineElementalsBase { public: void runOnOperation() override { mlir::func::FuncOp func = getOperation(); mlir::MLIRContext *context = &getContext(); mlir::GreedyRewriteConfig config; // Prevent the pattern driver from merging blocks. config.enableRegionSimplification = false; mlir::RewritePatternSet patterns(context); patterns.insert(context); if (mlir::failed(mlir::applyPatternsAndFoldGreedily( func, std::move(patterns), config))) { mlir::emitError(func->getLoc(), "failure in HLFIR elemental inlining"); signalPassFailure(); } } }; } // namespace std::unique_ptr hlfir::createInlineElementalsPass() { return std::make_unique(); }