195 lines
7.2 KiB
C++
195 lines
7.2 KiB
C++
//===- TestBackwardDataFlowAnalysis.cpp - Test backward dataflow analysis -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
|
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
|
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::dataflow;
|
|
|
|
namespace {
|
|
|
|
/// This lattice represents, for a given value, the set of memory resources that
|
|
/// this value, or anything derived from this value, is potentially written to.
|
|
struct WrittenTo : public AbstractSparseLattice {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
|
|
using AbstractSparseLattice::AbstractSparseLattice;
|
|
|
|
void print(raw_ostream &os) const override {
|
|
os << "[";
|
|
llvm::interleave(
|
|
writes, os, [&](const StringAttr &a) { os << a.str(); }, " ");
|
|
os << "]";
|
|
}
|
|
ChangeResult addWrites(const SetVector<StringAttr> &writes) {
|
|
int sizeBefore = this->writes.size();
|
|
this->writes.insert(writes.begin(), writes.end());
|
|
int sizeAfter = this->writes.size();
|
|
return sizeBefore == sizeAfter ? ChangeResult::NoChange
|
|
: ChangeResult::Change;
|
|
}
|
|
ChangeResult meet(const AbstractSparseLattice &other) override {
|
|
const auto *rhs = reinterpret_cast<const WrittenTo *>(&other);
|
|
return addWrites(rhs->writes);
|
|
}
|
|
|
|
SetVector<StringAttr> writes;
|
|
};
|
|
|
|
/// An analysis that, by going backwards along the dataflow graph, annotates
|
|
/// each value with all the memory resources it (or anything derived from it)
|
|
/// is eventually written to.
|
|
class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
|
|
public:
|
|
WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
|
|
bool assumeFuncWrites)
|
|
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
|
|
assumeFuncWrites(assumeFuncWrites) {}
|
|
|
|
void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
|
|
ArrayRef<const WrittenTo *> results) override;
|
|
|
|
void visitBranchOperand(OpOperand &operand) override;
|
|
|
|
void visitCallOperand(OpOperand &operand) override;
|
|
|
|
void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
|
|
ArrayRef<const WrittenTo *> results) override;
|
|
|
|
void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }
|
|
|
|
private:
|
|
bool assumeFuncWrites;
|
|
};
|
|
|
|
/// Transfer function for one operation. It propagates information from the
/// result lattices into the operand lattices.
///
/// - `memref.store`: the stored value (operand #0) is tagged with the op's
///   "tag_name" attribute. When that attribute is absent, the op name is used
///   instead, the same fallback visitExternalCall uses.
/// - Any other op: every operand receives the writes of every result.
void WrittenToAnalysis::visitOperation(Operation *op,
                                       ArrayRef<WrittenTo *> operands,
                                       ArrayRef<const WrittenTo *> results) {
  if (isa<memref::StoreOp>(op)) {
    StringAttr name = op->getAttrOfType<StringAttr>("tag_name");
    // Without this fallback a null StringAttr would be inserted into the
    // lattice, and printing the lattice would dereference it.
    if (!name)
      name = StringAttr::get(op->getContext(), op->getName().getStringRef());
    SetVector<StringAttr> newWrites;
    newWrites.insert(name);
    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
    return;
  }

  // By default, every result of an op depends on every operand.
  for (const WrittenTo *r : results) {
    for (WrittenTo *operand : operands)
      meet(operand, *r);
    // Have `op` revisited whenever this result's lattice changes.
    addDependency(const_cast<WrittenTo *>(r), op);
  }
}
|
|
|
|
/// Tags the value used by a branch operand with "brancharg<N>", where N is the
/// operand's index in its owning operation.
void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
  auto *ctx = operand.getOwner()->getContext();
  StringAttr tag =
      StringAttr::get(ctx, "brancharg" + Twine(operand.getOperandNumber()));

  SetVector<StringAttr> tags;
  tags.insert(tag);

  WrittenTo *state = getLatticeElement(operand.get());
  propagateIfChanged(state, state->addWrites(tags));
}
|
|
|
|
/// Tags the value used by a call operand with "callarg<N>", where N is the
/// operand's index in its owning operation.
void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
  auto *ctx = operand.getOwner()->getContext();
  StringAttr tag =
      StringAttr::get(ctx, "callarg" + Twine(operand.getOperandNumber()));

  SetVector<StringAttr> tags;
  tags.insert(tag);

  WrittenTo *state = getLatticeElement(operand.get());
  propagateIfChanged(state, state->addWrites(tags));
}
|
|
|
|
/// Models a call to a function whose body is not available.
///
/// If `assumeFuncWrites` is false, the base-class handling is used unchanged.
/// Otherwise every call operand is tagged as written to by the call. The tag
/// is the call's "tag_name" attribute if present, and the call op's name
/// otherwise.
void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
                                          ArrayRef<WrittenTo *> operands,
                                          ArrayRef<const WrittenTo *> results) {
  if (!assumeFuncWrites) {
    SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, results);
    return;
  }

  // The tag depends only on the call, not on the individual operand, so the
  // attribute lookup and set construction are done once, not per operand.
  StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
  if (!name) {
    name = StringAttr::get(call->getContext(),
                           call.getOperation()->getName().getStringRef());
  }
  SetVector<StringAttr> newWrites;
  newWrites.insert(name);

  for (WrittenTo *lattice : operands)
    propagateIfChanged(lattice, lattice->addWrites(newWrites));
}
|
|
|
|
} // end anonymous namespace
|
|
|
|
namespace {
|
|
struct TestWrittenToPass
|
|
: public PassWrapper<TestWrittenToPass, OperationPass<>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)
|
|
|
|
TestWrittenToPass() = default;
|
|
TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
|
|
interprocedural = other.interprocedural;
|
|
assumeFuncWrites = other.assumeFuncWrites;
|
|
}
|
|
|
|
StringRef getArgument() const override { return "test-written-to"; }
|
|
|
|
Option<bool> interprocedural{
|
|
*this, "interprocedural", llvm::cl::init(true),
|
|
llvm::cl::desc("perform interprocedural analysis")};
|
|
Option<bool> assumeFuncWrites{
|
|
*this, "assume-func-writes", llvm::cl::init(false),
|
|
llvm::cl::desc(
|
|
"assume external functions have write effect on all arguments")};
|
|
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
|
|
SymbolTableCollection symbolTable;
|
|
|
|
DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
|
|
solver.load<DeadCodeAnalysis>();
|
|
solver.load<SparseConstantPropagation>();
|
|
solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
|
|
if (failed(solver.initializeAndRun(op)))
|
|
return signalPassFailure();
|
|
|
|
raw_ostream &os = llvm::outs();
|
|
op->walk([&](Operation *op) {
|
|
auto tag = op->getAttrOfType<StringAttr>("tag");
|
|
if (!tag)
|
|
return;
|
|
os << "test_tag: " << tag.getValue() << ":\n";
|
|
for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
|
|
const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
|
|
assert(writtenTo && "expected a sparse lattice");
|
|
os << " operand #" << index << ": ";
|
|
writtenTo->print(os);
|
|
os << "\n";
|
|
}
|
|
for (auto [index, operand] : llvm::enumerate(op->getResults())) {
|
|
const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand);
|
|
assert(writtenTo && "expected a sparse lattice");
|
|
os << " result #" << index << ": ";
|
|
writtenTo->print(os);
|
|
os << "\n";
|
|
}
|
|
});
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); }
|
|
} // end namespace test
|
|
} // end namespace mlir
|