bolt/deps/llvm-18.1.8/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp

310 lines
12 KiB
C++
Raw Normal View History

2025-02-14 19:21:04 +01:00
//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Test pass for backward dense dataflow analysis.
//
//===----------------------------------------------------------------------===//
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::dataflow;
using namespace mlir::dataflow::test;
namespace {
class NextAccess : public AbstractDenseLattice, public AccessLatticeBase {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess)
using dataflow::AbstractDenseLattice::AbstractDenseLattice;
ChangeResult meet(const AbstractDenseLattice &lattice) override {
return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(
static_cast<const NextAccess &>(lattice)));
}
void print(raw_ostream &os) const override {
return AccessLatticeBase::print(os);
}
};
class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> {
public:
NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable,
bool assumeFuncReads = false)
: DenseBackwardDataFlowAnalysis(solver, symbolTable),
assumeFuncReads(assumeFuncReads) {}
void visitOperation(Operation *op, const NextAccess &after,
NextAccess *before) override;
void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
const NextAccess &after,
NextAccess *before) override;
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
RegionBranchPoint regionFrom,
RegionBranchPoint regionTo,
const NextAccess &after,
NextAccess *before) override;
// TODO: this isn't ideal for the analysis. When there is no next access, it
// means "we don't know what the next access is" rather than "there is no next
// access". But it's unclear how to differentiate the two cases...
void setToExitState(NextAccess *lattice) override {
propagateIfChanged(lattice, lattice->setKnownToUnknown());
}
const bool assumeFuncReads;
};
} // namespace
void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after,
NextAccess *before) {
auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, conservatively assume we can't
// say anything about the next access.
if (!memory)
return setToExitState(before);
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
// First, check if all underlying values are already known. Otherwise, avoid
// propagating and stay in the "undefined" state to avoid incorrectly
// propagating values that may be overwritten later on as that could be
// problematic for convergence based on monotonicity of lattice updates.
SmallVector<Value> underlyingValues;
underlyingValues.reserve(effects.size());
for (const MemoryEffects::EffectInstance &effect : effects) {
Value value = effect.getValue();
// Effects with unspecified value are treated conservatively and we cannot
// assume anything about the next access.
if (!value)
return setToExitState(before);
// If cannot find the most underlying value, we cannot assume anything about
// the next accesses.
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
value, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(op, value);
});
// If the underlying value is not known yet, don't propagate.
if (!underlyingValue)
return;
underlyingValues.push_back(*underlyingValue);
}
// Update the state if all underlying values are known.
ChangeResult result = before->meet(after);
for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
// If the underlying value is known to be unknown, set to fixpoint.
if (!value)
return setToExitState(before);
result |= before->set(value, op);
}
propagateIfChanged(before, result);
}
void NextAccessAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action, const NextAccess &after,
NextAccess *before) {
if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) {
SmallVector<Value> underlyingValues;
underlyingValues.reserve(call->getNumOperands());
for (Value operand : call.getArgOperands()) {
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
operand, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(
call.getOperation(), value);
});
if (!underlyingValue)
return;
underlyingValues.push_back(*underlyingValue);
}
ChangeResult result = before->meet(after);
for (Value operand : underlyingValues) {
result |= before->set(operand, call);
}
return propagateIfChanged(before, result);
}
auto testCallAndStore =
dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
testCallAndStore.getStoreBeforeCall()) ||
(action == CallControlFlowAction::ExitCallee &&
!testCallAndStore.getStoreBeforeCall()))) {
visitOperation(call, after, before);
} else {
AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer(
call, action, after, before);
}
}
void NextAccessAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, RegionBranchPoint regionFrom,
RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) {
auto testStoreWithARegion =
dyn_cast<::test::TestStoreWithARegion>(branch.getOperation());
if (testStoreWithARegion &&
((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) ||
(regionFrom.isParent() &&
testStoreWithARegion.getStoreBeforeRegion()))) {
visitOperation(branch, static_cast<const NextAccess &>(after),
static_cast<NextAccess *>(before));
} else {
propagateIfChanged(before, before->meet(after));
}
}
namespace {
struct TestNextAccessPass
: public PassWrapper<TestNextAccessPass, OperationPass<>> {
TestNextAccessPass() = default;
TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) {
interprocedural = other.interprocedural;
assumeFuncReads = other.assumeFuncReads;
}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass)
StringRef getArgument() const override { return "test-next-access"; }
Option<bool> interprocedural{
*this, "interprocedural", llvm::cl::init(true),
llvm::cl::desc("perform interprocedural analysis")};
Option<bool> assumeFuncReads{
*this, "assume-func-reads", llvm::cl::init(false),
llvm::cl::desc(
"assume external functions have read effect on all arguments")};
static constexpr llvm::StringLiteral kTagAttrName = "name";
static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access";
static constexpr llvm::StringLiteral kAtEntryPointAttrName =
"next_at_entry_point";
static Attribute makeNextAccessAttribute(Operation *op,
const DataFlowSolver &solver,
const NextAccess *nextAccess) {
if (!nextAccess)
return StringAttr::get(op->getContext(), "not computed");
// Note that if the underlying value could not be computed or is unknown, we
// conservatively treat the result also unknown.
SmallVector<Attribute> attrs;
for (Value operand : op->getOperands()) {
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
operand, [&](Value value) {
return solver.lookupState<UnderlyingValueLattice>(value);
});
if (!underlyingValue) {
attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
continue;
}
Value value = *underlyingValue;
const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value);
if (!nextAcc || !nextAcc->isKnown()) {
attrs.push_back(StringAttr::get(op->getContext(), "unknown"));
continue;
}
SmallVector<Attribute> innerAttrs;
innerAttrs.reserve(nextAcc->get().size());
for (Operation *nextAccOp : nextAcc->get()) {
if (auto nextAccTag =
nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) {
innerAttrs.push_back(nextAccTag);
continue;
}
std::string repr;
llvm::raw_string_ostream os(repr);
nextAccOp->print(os);
innerAttrs.push_back(StringAttr::get(op->getContext(), os.str()));
}
attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs));
}
return ArrayAttr::get(op->getContext(), attrs);
}
void runOnOperation() override {
Operation *op = getOperation();
SymbolTableCollection symbolTable;
auto config = DataFlowConfig().setInterprocedural(interprocedural);
DataFlowSolver solver(config);
solver.load<DeadCodeAnalysis>();
solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads);
solver.load<SparseConstantPropagation>();
solver.load<UnderlyingValueAnalysis>();
if (failed(solver.initializeAndRun(op))) {
emitError(op->getLoc(), "dataflow solver failed");
return signalPassFailure();
}
op->walk([&](Operation *op) {
auto tag = op->getAttrOfType<StringAttr>(kTagAttrName);
if (!tag)
return;
const NextAccess *nextAccess = solver.lookupState<NextAccess>(
op->getNextNode() == nullptr ? ProgramPoint(op->getBlock())
: op->getNextNode());
op->setAttr(kNextAccessAttrName,
makeNextAccessAttribute(op, solver, nextAccess));
auto iface = dyn_cast<RegionBranchOpInterface>(op);
if (!iface)
return;
SmallVector<Attribute> entryPointNextAccess;
SmallVector<RegionSuccessor> regionSuccessors;
iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors);
for (const RegionSuccessor &successor : regionSuccessors) {
if (!successor.getSuccessor() || successor.getSuccessor()->empty())
continue;
Block &successorBlock = successor.getSuccessor()->front();
ProgramPoint successorPoint = successorBlock.empty()
? ProgramPoint(&successorBlock)
: &successorBlock.front();
entryPointNextAccess.push_back(makeNextAccessAttribute(
op, solver, solver.lookupState<NextAccess>(successorPoint)));
}
op->setAttr(kAtEntryPointAttrName,
ArrayAttr::get(op->getContext(), entryPointNextAccess));
});
}
};
} // namespace
namespace mlir::test {
void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); }
} // namespace mlir::test