bolt/deps/llvm-18.1.8/mlir/test/lib/Analysis/DataFlow/TestSparseBackwardDataFlowAnalysis.cpp

196 lines
7.2 KiB
C++
Raw Normal View History

2025-02-14 19:21:04 +01:00
//===- TestSparseBackwardDataFlowAnalysis.cpp - Written-to test analysis --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
using namespace mlir::dataflow;
namespace {
/// This lattice represents, for a given value, the set of memory resources that
/// this value, or anything derived from this value, is potentially written to.
///
/// Resources are identified by `StringAttr` names and kept in insertion order,
/// so the printed form lists them in the order they were first recorded.
struct WrittenTo : public AbstractSparseLattice {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo)
  using AbstractSparseLattice::AbstractSparseLattice;

  /// Print the recorded names as a space-separated list enclosed in `[` `]`.
  void print(raw_ostream &os) const override {
    os << "[";
    // Stream the underlying StringRef directly instead of materializing a
    // temporary std::string for every element.
    llvm::interleave(
        writes, os, [&](const StringAttr &a) { os << a.getValue(); }, " ");
    os << "]";
  }

  /// Add every name in `writes` to this lattice.
  ///
  /// Returns `ChangeResult::Change` if at least one name was not already
  /// present, and `ChangeResult::NoChange` otherwise.
  ChangeResult addWrites(const SetVector<StringAttr> &writes) {
    // The set only ever grows here, so comparing sizes is enough to detect
    // whether anything new was inserted.
    const size_t sizeBefore = this->writes.size();
    this->writes.insert(writes.begin(), writes.end());
    const size_t sizeAfter = this->writes.size();
    return sizeBefore == sizeAfter ? ChangeResult::NoChange
                                   : ChangeResult::Change;
  }

  /// Merge `other` into this lattice by adding all of its recorded names.
  ///
  /// `other` is treated as a `WrittenTo`; `WrittenTo` derives from
  /// `AbstractSparseLattice`, so a `static_cast` is the appropriate downcast.
  ChangeResult meet(const AbstractSparseLattice &other) override {
    const auto &rhs = static_cast<const WrittenTo &>(other);
    return addWrites(rhs.writes);
  }

  /// Names of the resources recorded for the value this lattice is attached to.
  SetVector<StringAttr> writes;
};
/// A sparse backward dataflow analysis over `WrittenTo` lattices: walking the
/// dataflow graph in reverse, it annotates each value with every memory
/// resource that the value (or anything derived from it) is eventually written
/// to.
class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> {
public:
  /// Build the analysis on top of `solver` and `symbolTable`.
  /// `assumeFuncWrites` selects how `visitExternalCall` treats call operands.
  WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
                    bool assumeFuncWrites)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable),
        assumeFuncWrites(assumeFuncWrites) {}

  /// Transfer function for a single operation.
  void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands,
                      ArrayRef<const WrittenTo *> results) override;

  /// Tags the value of a branch operand with "brancharg<N>".
  void visitBranchOperand(OpOperand &operand) override;

  /// Tags the value of a call operand with "callarg<N>".
  void visitCallOperand(OpOperand &operand) override;

  /// Handles a call whose operands/results are given as lattices; behaviour
  /// depends on `assumeFuncWrites`.
  void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands,
                         ArrayRef<const WrittenTo *> results) override;

  /// The exit state is the empty set of writes.
  void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); }

private:
  /// When true, `visitExternalCall` records a write on every call operand
  /// instead of deferring to the base-class implementation.
  bool assumeFuncWrites;
};
/// Transfer function for `op`.
///
/// For `memref.store`, records a write on the lattice of operand #0 (the
/// stored value, per the op's operand order), named after the op's "tag_name"
/// string attribute. If that attribute is absent, the operation name is used
/// instead, mirroring `visitExternalCall`, so that a null `StringAttr` is never
/// stored in the lattice (printing one would dereference a null attribute).
///
/// For every other operation, each operand lattice is met with each result
/// lattice, i.e. every result is assumed to depend on every operand.
void WrittenToAnalysis::visitOperation(Operation *op,
                                       ArrayRef<WrittenTo *> operands,
                                       ArrayRef<const WrittenTo *> results) {
  if (isa<memref::StoreOp>(op)) {
    StringAttr name = op->getAttrOfType<StringAttr>("tag_name");
    if (!name) {
      name = StringAttr::get(op->getContext(), op->getName().getStringRef());
    }
    SetVector<StringAttr> newWrites;
    newWrites.insert(name);
    propagateIfChanged(operands[0], operands[0]->addWrites(newWrites));
    return;
  }

  // By default, every result of an op depends on every operand.
  for (const WrittenTo *r : results) {
    for (WrittenTo *operand : operands) {
      meet(operand, *r);
    }
    // Register `op` as a dependent of the result lattice. The const_cast is
    // needed only because addDependency takes a mutable state pointer.
    addDependency(const_cast<WrittenTo *>(r), op);
  }
}
/// Record on the lattice of `operand`'s value a pseudo-resource named
/// "brancharg<N>", where <N> is the operand's index on its owning operation,
/// and propagate if the lattice changed.
void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) {
  WrittenTo *state = getLatticeElement(operand.get());

  // Mark branch operands as "brancharg%d", with %d the operand number.
  MLIRContext *ctx = operand.getOwner()->getContext();
  StringAttr tag =
      StringAttr::get(ctx, "brancharg" + Twine(operand.getOperandNumber()));

  SetVector<StringAttr> tagSet;
  tagSet.insert(tag);
  ChangeResult changed = state->addWrites(tagSet);
  propagateIfChanged(state, changed);
}
/// Record on the lattice of `operand`'s value a pseudo-resource named
/// "callarg<N>", where <N> is the operand's index on its owning operation,
/// and propagate if the lattice changed.
void WrittenToAnalysis::visitCallOperand(OpOperand &operand) {
  WrittenTo *state = getLatticeElement(operand.get());

  // Mark call operands as "callarg%d", with %d the operand number.
  MLIRContext *ctx = operand.getOwner()->getContext();
  StringAttr tag =
      StringAttr::get(ctx, "callarg" + Twine(operand.getOperandNumber()));

  SetVector<StringAttr> tagSet;
  tagSet.insert(tag);
  ChangeResult changed = state->addWrites(tagSet);
  propagateIfChanged(state, changed);
}
/// Handle a call given the lattices of its operands and results.
///
/// When `assumeFuncWrites` is false, this defers to the base-class
/// implementation. Otherwise every operand lattice receives a single write
/// named after the call's "tag_name" string attribute or, if that attribute is
/// absent, after the call operation's name.
void WrittenToAnalysis::visitExternalCall(CallOpInterface call,
                                          ArrayRef<WrittenTo *> operands,
                                          ArrayRef<const WrittenTo *> results) {
  if (!assumeFuncWrites) {
    SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, results);
    return;
  }

  // Nothing to annotate; bail out before building the tag so this case does
  // exactly as little work as before.
  if (operands.empty())
    return;

  // The tag depends only on the call, not on the operand, so compute it (and
  // the one-element set holding it) once rather than once per operand.
  StringAttr name = call->getAttrOfType<StringAttr>("tag_name");
  if (!name) {
    name = StringAttr::get(call->getContext(),
                           call.getOperation()->getName().getStringRef());
  }
  SetVector<StringAttr> newWrites;
  newWrites.insert(name);

  for (WrittenTo *lattice : operands) {
    propagateIfChanged(lattice, lattice->addWrites(newWrites));
  }
}
} // end anonymous namespace
namespace {
/// Test pass (argument: "test-written-to") that runs dead-code analysis,
/// sparse constant propagation and `WrittenToAnalysis` on the current
/// operation, then prints to stdout the `WrittenTo` lattice of every operand
/// and result of each nested operation carrying a "tag" string attribute.
struct TestWrittenToPass
    : public PassWrapper<TestWrittenToPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass)

  TestWrittenToPass() = default;

  /// Copy the option values explicitly from `other`.
  TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) {
    interprocedural = other.interprocedural;
    assumeFuncWrites = other.assumeFuncWrites;
  }

  StringRef getArgument() const override { return "test-written-to"; }

  /// Forwarded to the solver configuration; defaults to true.
  Option<bool> interprocedural{
      *this, "interprocedural", llvm::cl::init(true),
      llvm::cl::desc("perform interprocedural analysis")};

  /// Forwarded to `WrittenToAnalysis`; defaults to false.
  Option<bool> assumeFuncWrites{
      *this, "assume-func-writes", llvm::cl::init(false),
      llvm::cl::desc(
          "assume external functions have write effect on all arguments")};

  void runOnOperation() override {
    Operation *root = getOperation();

    // Configure and run the solver; signal failure if it does not succeed.
    SymbolTableCollection symbolTable;
    DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
    solver.load<DeadCodeAnalysis>();
    solver.load<SparseConstantPropagation>();
    solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites);
    if (failed(solver.initializeAndRun(root))) {
      signalPassFailure();
      return;
    }

    raw_ostream &os = llvm::outs();

    // Print one line per value: `prefix`, the value's index, then its lattice.
    auto printLattices = [&](const char *prefix, ValueRange values) {
      for (auto [index, value] : llvm::enumerate(values)) {
        const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(value);
        assert(writtenTo && "expected a sparse lattice");
        os << prefix << index << ": ";
        writtenTo->print(os);
        os << "\n";
      }
    };

    root->walk([&](Operation *nested) {
      // Only operations explicitly tagged for the test are reported.
      auto tag = nested->getAttrOfType<StringAttr>("tag");
      if (!tag)
        return;
      os << "test_tag: " << tag.getValue() << ":\n";
      printLattices(" operand #", nested->getOperands());
      printLattices(" result #", nested->getResults());
    });
  }
};
} // end anonymous namespace
namespace mlir::test {
/// Entry point used to make `TestWrittenToPass` available: it constructs a
/// `PassRegistration` for the pass type.
void registerTestWrittenToPass() {
  PassRegistration<TestWrittenToPass>();
}
} // end namespace mlir::test