//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Test pass for backward dense dataflow analysis. // //===----------------------------------------------------------------------===// #include "TestDenseDataFlowAnalysis.h" #include "TestDialect.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/Builders.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/TypeID.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::dataflow; using namespace mlir::dataflow::test; namespace { class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) using dataflow::AbstractDenseLattice::AbstractDenseLattice; ChangeResult meet(const AbstractDenseLattice &lattice) override { return AccessLatticeBase::merge(static_cast( static_cast(lattice))); } void print(raw_ostream &os) const override { return AccessLatticeBase::print(os); } }; class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { public: NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, bool assumeFuncReads = false) : DenseBackwardDataFlowAnalysis(solver, symbolTable), assumeFuncReads(assumeFuncReads) {} void visitOperation(Operation *op, const NextAccess &after, NextAccess *before) override; void visitCallControlFlowTransfer(CallOpInterface call, CallControlFlowAction action, const NextAccess &after, NextAccess *before) override; void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, RegionBranchPoint regionFrom, RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) override; // TODO: this isn't ideal for the analysis. When there is no next access, it // means "we don't know what the next access is" rather than "there is no next // access". But it's unclear how to differentiate the two cases... void setToExitState(NextAccess *lattice) override { propagateIfChanged(lattice, lattice->setKnownToUnknown()); } const bool assumeFuncReads; }; } // namespace void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, NextAccess *before) { auto memory = dyn_cast(op); // If we can't reason about the memory effects, conservatively assume we can't // say anything about the next access. if (!memory) return setToExitState(before); SmallVector effects; memory.getEffects(effects); // First, check if all underlying values are already known. Otherwise, avoid // propagating and stay in the "undefined" state to avoid incorrectly // propagating values that may be overwritten later on as that could be // problematic for convergence based on monotonicity of lattice updates. SmallVector underlyingValues; underlyingValues.reserve(effects.size()); for (const MemoryEffects::EffectInstance &effect : effects) { Value value = effect.getValue(); // Effects with unspecified value are treated conservatively and we cannot // assume anything about the next access. if (!value) return setToExitState(before); // If cannot find the most underlying value, we cannot assume anything about // the next accesses. std::optional underlyingValue = UnderlyingValueAnalysis::getMostUnderlyingValue( value, [&](Value value) { return getOrCreateFor(op, value); }); // If the underlying value is not known yet, don't propagate. if (!underlyingValue) return; underlyingValues.push_back(*underlyingValue); } // Update the state if all underlying values are known. ChangeResult result = before->meet(after); for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) { // If the underlying value is known to be unknown, set to fixpoint. if (!value) return setToExitState(before); result |= before->set(value, op); } propagateIfChanged(before, result); } void NextAccessAnalysis::visitCallControlFlowTransfer( CallOpInterface call, CallControlFlowAction action, const NextAccess &after, NextAccess *before) { if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { SmallVector underlyingValues; underlyingValues.reserve(call->getNumOperands()); for (Value operand : call.getArgOperands()) { std::optional underlyingValue = UnderlyingValueAnalysis::getMostUnderlyingValue( operand, [&](Value value) { return getOrCreateFor( call.getOperation(), value); }); if (!underlyingValue) return; underlyingValues.push_back(*underlyingValue); } ChangeResult result = before->meet(after); for (Value operand : underlyingValues) { result |= before->set(operand, call); } return propagateIfChanged(before, result); } auto testCallAndStore = dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && testCallAndStore.getStoreBeforeCall()) || (action == CallControlFlowAction::ExitCallee && !testCallAndStore.getStoreBeforeCall()))) { visitOperation(call, after, before); } else { AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( call, action, after, before); } } void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( RegionBranchOpInterface branch, RegionBranchPoint regionFrom, RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { auto testStoreWithARegion = dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); if (testStoreWithARegion && ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || (regionFrom.isParent() && testStoreWithARegion.getStoreBeforeRegion()))) { visitOperation(branch, static_cast(after), static_cast(before)); } else { propagateIfChanged(before, before->meet(after)); } } namespace { struct TestNextAccessPass : public PassWrapper> { TestNextAccessPass() = default; TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { interprocedural = other.interprocedural; assumeFuncReads = other.assumeFuncReads; } MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) StringRef getArgument() const override { return "test-next-access"; } Option interprocedural{ *this, "interprocedural", llvm::cl::init(true), llvm::cl::desc("perform interprocedural analysis")}; Option assumeFuncReads{ *this, "assume-func-reads", llvm::cl::init(false), llvm::cl::desc( "assume external functions have read effect on all arguments")}; static constexpr llvm::StringLiteral kTagAttrName = "name"; static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; static constexpr llvm::StringLiteral kAtEntryPointAttrName = "next_at_entry_point"; static Attribute makeNextAccessAttribute(Operation *op, const DataFlowSolver &solver, const NextAccess *nextAccess) { if (!nextAccess) return StringAttr::get(op->getContext(), "not computed"); // Note that if the underlying value could not be computed or is unknown, we // conservatively treat the result also unknown. SmallVector attrs; for (Value operand : op->getOperands()) { std::optional underlyingValue = UnderlyingValueAnalysis::getMostUnderlyingValue( operand, [&](Value value) { return solver.lookupState(value); }); if (!underlyingValue) { attrs.push_back(StringAttr::get(op->getContext(), "unknown")); continue; } Value value = *underlyingValue; const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); if (!nextAcc || !nextAcc->isKnown()) { attrs.push_back(StringAttr::get(op->getContext(), "unknown")); continue; } SmallVector innerAttrs; innerAttrs.reserve(nextAcc->get().size()); for (Operation *nextAccOp : nextAcc->get()) { if (auto nextAccTag = nextAccOp->getAttrOfType(kTagAttrName)) { innerAttrs.push_back(nextAccTag); continue; } std::string repr; llvm::raw_string_ostream os(repr); nextAccOp->print(os); innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); } attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); } return ArrayAttr::get(op->getContext(), attrs); } void runOnOperation() override { Operation *op = getOperation(); SymbolTableCollection symbolTable; auto config = DataFlowConfig().setInterprocedural(interprocedural); DataFlowSolver solver(config); solver.load(); solver.load(symbolTable, assumeFuncReads); solver.load(); solver.load(); if (failed(solver.initializeAndRun(op))) { emitError(op->getLoc(), "dataflow solver failed"); return signalPassFailure(); } op->walk([&](Operation *op) { auto tag = op->getAttrOfType(kTagAttrName); if (!tag) return; const NextAccess *nextAccess = solver.lookupState( op->getNextNode() == nullptr ? ProgramPoint(op->getBlock()) : op->getNextNode()); op->setAttr(kNextAccessAttrName, makeNextAccessAttribute(op, solver, nextAccess)); auto iface = dyn_cast(op); if (!iface) return; SmallVector entryPointNextAccess; SmallVector regionSuccessors; iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); for (const RegionSuccessor &successor : regionSuccessors) { if (!successor.getSuccessor() || successor.getSuccessor()->empty()) continue; Block &successorBlock = successor.getSuccessor()->front(); ProgramPoint successorPoint = successorBlock.empty() ? ProgramPoint(&successorBlock) : &successorBlock.front(); entryPointNextAccess.push_back(makeNextAccessAttribute( op, solver, solver.lookupState(successorPoint))); } op->setAttr(kAtEntryPointAttrName, ArrayAttr::get(op->getContext(), entryPointNextAccess)); }); } }; } // namespace namespace mlir::test { void registerTestNextAccessPass() { PassRegistration(); } } // namespace mlir::test