//===- TestBackwardDataFlowAnalysis.cpp - Test dead code analysis ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" using namespace mlir; using namespace mlir::dataflow; namespace { /// This lattice represents, for a given value, the set of memory resources that /// this value, or anything derived from this value, is potentially written to. struct WrittenTo : public AbstractSparseLattice { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo) using AbstractSparseLattice::AbstractSparseLattice; void print(raw_ostream &os) const override { os << "["; llvm::interleave( writes, os, [&](const StringAttr &a) { os << a.str(); }, " "); os << "]"; } ChangeResult addWrites(const SetVector &writes) { int sizeBefore = this->writes.size(); this->writes.insert(writes.begin(), writes.end()); int sizeAfter = this->writes.size(); return sizeBefore == sizeAfter ? ChangeResult::NoChange : ChangeResult::Change; } ChangeResult meet(const AbstractSparseLattice &other) override { const auto *rhs = reinterpret_cast(&other); return addWrites(rhs->writes); } SetVector writes; }; /// An analysis that, by going backwards along the dataflow graph, annotates /// each value with all the memory resources it (or anything derived from it) /// is eventually written to. class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis { public: WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, bool assumeFuncWrites) : SparseBackwardDataFlowAnalysis(solver, symbolTable), assumeFuncWrites(assumeFuncWrites) {} void visitOperation(Operation *op, ArrayRef operands, ArrayRef results) override; void visitBranchOperand(OpOperand &operand) override; void visitCallOperand(OpOperand &operand) override; void visitExternalCall(CallOpInterface call, ArrayRef operands, ArrayRef results) override; void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); } private: bool assumeFuncWrites; }; void WrittenToAnalysis::visitOperation(Operation *op, ArrayRef operands, ArrayRef results) { if (auto store = dyn_cast(op)) { SetVector newWrites; newWrites.insert(op->getAttrOfType("tag_name")); propagateIfChanged(operands[0], operands[0]->addWrites(newWrites)); return; } // By default, every result of an op depends on every operand. for (const WrittenTo *r : results) { for (WrittenTo *operand : operands) { meet(operand, *r); } addDependency(const_cast(r), op); } } void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) { // Mark branch operands as "brancharg%d", with %d the operand number. WrittenTo *lattice = getLatticeElement(operand.get()); SetVector newWrites; newWrites.insert( StringAttr::get(operand.getOwner()->getContext(), "brancharg" + Twine(operand.getOperandNumber()))); propagateIfChanged(lattice, lattice->addWrites(newWrites)); } void WrittenToAnalysis::visitCallOperand(OpOperand &operand) { // Mark call operands as "callarg%d", with %d the operand number. WrittenTo *lattice = getLatticeElement(operand.get()); SetVector newWrites; newWrites.insert( StringAttr::get(operand.getOwner()->getContext(), "callarg" + Twine(operand.getOperandNumber()))); propagateIfChanged(lattice, lattice->addWrites(newWrites)); } void WrittenToAnalysis::visitExternalCall(CallOpInterface call, ArrayRef operands, ArrayRef results) { if (!assumeFuncWrites) { return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, results); } for (WrittenTo *lattice : operands) { SetVector newWrites; StringAttr name = call->getAttrOfType("tag_name"); if (!name) { name = StringAttr::get(call->getContext(), call.getOperation()->getName().getStringRef()); } newWrites.insert(name); propagateIfChanged(lattice, lattice->addWrites(newWrites)); } } } // end anonymous namespace namespace { struct TestWrittenToPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass) TestWrittenToPass() = default; TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) { interprocedural = other.interprocedural; assumeFuncWrites = other.assumeFuncWrites; } StringRef getArgument() const override { return "test-written-to"; } Option interprocedural{ *this, "interprocedural", llvm::cl::init(true), llvm::cl::desc("perform interprocedural analysis")}; Option assumeFuncWrites{ *this, "assume-func-writes", llvm::cl::init(false), llvm::cl::desc( "assume external functions have write effect on all arguments")}; void runOnOperation() override { Operation *op = getOperation(); SymbolTableCollection symbolTable; DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); solver.load(); solver.load(); solver.load(symbolTable, assumeFuncWrites); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); raw_ostream &os = llvm::outs(); op->walk([&](Operation *op) { auto tag = op->getAttrOfType("tag"); if (!tag) return; os << "test_tag: " << tag.getValue() << ":\n"; for (auto [index, operand] : llvm::enumerate(op->getOperands())) { const WrittenTo *writtenTo = solver.lookupState(operand); assert(writtenTo && "expected a sparse lattice"); os << " operand #" << index << ": "; writtenTo->print(os); os << "\n"; } for (auto [index, operand] : llvm::enumerate(op->getResults())) { const WrittenTo *writtenTo = solver.lookupState(operand); assert(writtenTo && "expected a sparse lattice"); os << " result #" << index << ": "; writtenTo->print(os); os << "\n"; } }); } }; } // end anonymous namespace namespace mlir { namespace test { void registerTestWrittenToPass() { PassRegistration(); } } // end namespace test } // end namespace mlir