//===- MemRefDataFlowOpt.cpp - Memory DataFlow Optimization pass ----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include namespace fir { #define GEN_PASS_DEF_MEMREFDATAFLOWOPT #include "flang/Optimizer/Transforms/Passes.h.inc" } // namespace fir #define DEBUG_TYPE "fir-memref-dataflow-opt" using namespace mlir; namespace { template static std::vector getSpecificUsers(mlir::Value v) { std::vector ops; for (mlir::Operation *user : v.getUsers()) if (auto op = dyn_cast(user)) ops.push_back(op); return ops; } /// This is based on MLIR's MemRefDataFlowOpt which is specialized on AffineRead /// and AffineWrite interface template class LoadStoreForwarding { public: LoadStoreForwarding(mlir::DominanceInfo *di) : domInfo(di) {} // FIXME: This algorithm has a bug. It ignores escaping references between a // store and a load. std::optional findStoreToForward(ReadOp loadOp, std::vector &&storeOps) { llvm::SmallVector candidateSet; for (auto storeOp : storeOps) if (domInfo->dominates(storeOp, loadOp)) candidateSet.push_back(storeOp); if (candidateSet.empty()) return {}; std::optional nearestStore; for (auto candidate : candidateSet) { auto nearerThan = [&](WriteOp otherStore) { if (candidate == otherStore) return false; bool rv = domInfo->properlyDominates(candidate, otherStore); if (rv) { LLVM_DEBUG(llvm::dbgs() << "candidate " << candidate << " is not the nearest to " << loadOp << " because " << otherStore << " is closer\n"); } return rv; }; if (!llvm::any_of(candidateSet, nearerThan)) { nearestStore = mlir::cast(candidate); break; } } if (!nearestStore) { LLVM_DEBUG( llvm::dbgs() << "load " << loadOp << " has " << candidateSet.size() << " store candidates, but this algorithm can't find a best.\n"); } return nearestStore; } std::optional findReadForWrite(WriteOp storeOp, std::vector &&loadOps) { for (auto &loadOp : loadOps) { if (domInfo->dominates(storeOp, loadOp)) return loadOp; } return {}; } private: mlir::DominanceInfo *domInfo; }; class MemDataFlowOpt : public fir::impl::MemRefDataFlowOptBase { public: void runOnOperation() override { mlir::func::FuncOp f = getOperation(); auto *domInfo = &getAnalysis(); LoadStoreForwarding lsf(domInfo); f.walk([&](fir::LoadOp loadOp) { auto maybeStore = lsf.findStoreToForward( loadOp, getSpecificUsers(loadOp.getMemref())); if (maybeStore) { auto storeOp = *maybeStore; LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName() << " erasing load " << loadOp << " with value from " << storeOp << '\n'); loadOp.getResult().replaceAllUsesWith(storeOp.getValue()); loadOp.erase(); } }); f.walk([&](fir::AllocaOp alloca) { for (auto &storeOp : getSpecificUsers(alloca.getResult())) { if (!lsf.findReadForWrite( storeOp, getSpecificUsers(storeOp.getMemref()))) { LLVM_DEBUG(llvm::dbgs() << "FlangMemDataFlowOpt: In " << f.getName() << " erasing store " << storeOp << '\n'); storeOp.erase(); } } }); } }; } // namespace std::unique_ptr fir::createMemDataFlowOptPass() { return std::make_unique(); }