bolt/deps/llvm-18.1.8/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp
2025-02-14 19:21:04 +01:00

282 lines
11 KiB
C++

//===- TestDenseForwardDataFlowAnalysis.cpp -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of tests passes exercising dense forward data flow analysis.
//
//===----------------------------------------------------------------------===//
#include "TestDenseDataFlowAnalysis.h"
#include "TestDialect.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
using namespace mlir::dataflow::test;
namespace {
/// This lattice represents, for a given memory resource, the potential last
/// operations that modified the resource.
class LastModification : public AbstractDenseLattice, public AccessLatticeBase {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
using AbstractDenseLattice::AbstractDenseLattice;
/// Join the last modifications.
ChangeResult join(const AbstractDenseLattice &lattice) override {
return AccessLatticeBase::merge(static_cast<AccessLatticeBase>(
static_cast<const LastModification &>(lattice)));
}
void print(raw_ostream &os) const override {
return AccessLatticeBase::print(os);
}
};
class LastModifiedAnalysis
: public DenseForwardDataFlowAnalysis<LastModification> {
public:
explicit LastModifiedAnalysis(DataFlowSolver &solver, bool assumeFuncWrites)
: DenseForwardDataFlowAnalysis(solver),
assumeFuncWrites(assumeFuncWrites) {}
/// Visit an operation. If the operation has no memory effects, then the state
/// is propagated with no change. If the operation allocates a resource, then
/// its reaching definitions is set to empty. If the operation writes to a
/// resource, then its reaching definition is set to the written value.
void visitOperation(Operation *op, const LastModification &before,
LastModification *after) override;
void visitCallControlFlowTransfer(CallOpInterface call,
CallControlFlowAction action,
const LastModification &before,
LastModification *after) override;
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch,
std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo,
const LastModification &before,
LastModification *after) override;
/// At an entry point, the last modifications of all memory resources are
/// unknown.
void setToEntryState(LastModification *lattice) override {
propagateIfChanged(lattice, lattice->reset());
}
private:
const bool assumeFuncWrites;
};
} // end anonymous namespace
void LastModifiedAnalysis::visitOperation(Operation *op,
const LastModification &before,
LastModification *after) {
auto memory = dyn_cast<MemoryEffectOpInterface>(op);
// If we can't reason about the memory effects, then conservatively assume we
// can't deduce anything about the last modifications.
if (!memory)
return setToEntryState(after);
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
// First, check if all underlying values are already known. Otherwise, avoid
// propagating and stay in the "undefined" state to avoid incorrectly
// propagating values that may be overwritten later on as that could be
// problematic for convergence based on monotonicity of lattice updates.
SmallVector<Value> underlyingValues;
underlyingValues.reserve(effects.size());
for (const auto &effect : effects) {
Value value = effect.getValue();
// If we see an effect on anything other than a value, assume we can't
// deduce anything about the last modifications.
if (!value)
return setToEntryState(after);
// If we cannot find the underlying value, we shouldn't just propagate the
// effects through, return the pessimistic state.
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
value, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(op, value);
});
// If the underlying value is not yet known, don't propagate yet.
if (!underlyingValue)
return;
underlyingValues.push_back(*underlyingValue);
}
// Update the state when all underlying values are known.
ChangeResult result = after->join(before);
for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) {
// If the underlying value is known to be unknown, set to fixpoint state.
if (!value)
return setToEntryState(after);
// Nothing to do for reads.
if (isa<MemoryEffects::Read>(effect.getEffect()))
continue;
result |= after->set(value, op);
}
propagateIfChanged(after, result);
}
void LastModifiedAnalysis::visitCallControlFlowTransfer(
CallOpInterface call, CallControlFlowAction action,
const LastModification &before, LastModification *after) {
if (action == CallControlFlowAction::ExternalCallee && assumeFuncWrites) {
SmallVector<Value> underlyingValues;
underlyingValues.reserve(call->getNumOperands());
for (Value operand : call.getArgOperands()) {
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
operand, [&](Value value) {
return getOrCreateFor<UnderlyingValueLattice>(
call.getOperation(), value);
});
if (!underlyingValue)
return;
underlyingValues.push_back(*underlyingValue);
}
ChangeResult result = after->join(before);
for (Value operand : underlyingValues)
result |= after->set(operand, call);
return propagateIfChanged(after, result);
}
auto testCallAndStore =
dyn_cast<::test::TestCallAndStoreOp>(call.getOperation());
if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee &&
testCallAndStore.getStoreBeforeCall()) ||
(action == CallControlFlowAction::ExitCallee &&
!testCallAndStore.getStoreBeforeCall()))) {
return visitOperation(call, before, after);
}
AbstractDenseForwardDataFlowAnalysis::visitCallControlFlowTransfer(
call, action, before, after);
}
void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer(
RegionBranchOpInterface branch, std::optional<unsigned> regionFrom,
std::optional<unsigned> regionTo, const LastModification &before,
LastModification *after) {
auto defaultHandling = [&]() {
AbstractDenseForwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer(
branch, regionFrom, regionTo, before, after);
};
TypeSwitch<Operation *>(branch.getOperation())
.Case<::test::TestStoreWithARegion, ::test::TestStoreWithALoopRegion>(
[=](auto storeWithRegion) {
if ((!regionTo && !storeWithRegion.getStoreBeforeRegion()) ||
(!regionFrom && storeWithRegion.getStoreBeforeRegion()))
visitOperation(branch, before, after);
defaultHandling();
})
.Default([=](auto) { defaultHandling(); });
}
namespace {
struct TestLastModifiedPass
: public PassWrapper<TestLastModifiedPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass)
TestLastModifiedPass() = default;
TestLastModifiedPass(const TestLastModifiedPass &other) : PassWrapper(other) {
interprocedural = other.interprocedural;
assumeFuncWrites = other.assumeFuncWrites;
}
StringRef getArgument() const override { return "test-last-modified"; }
Option<bool> interprocedural{
*this, "interprocedural", llvm::cl::init(true),
llvm::cl::desc("perform interprocedural analysis")};
Option<bool> assumeFuncWrites{
*this, "assume-func-writes", llvm::cl::init(false),
llvm::cl::desc(
"assume external functions have write effect on all arguments")};
void runOnOperation() override {
Operation *op = getOperation();
DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural));
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
solver.load<LastModifiedAnalysis>(assumeFuncWrites);
solver.load<UnderlyingValueAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
raw_ostream &os = llvm::errs();
// Note that if the underlying value could not be computed or is unknown, we
// conservatively treat the result also unknown.
op->walk([&](Operation *op) {
auto tag = op->getAttrOfType<StringAttr>("tag");
if (!tag)
return;
os << "test_tag: " << tag.getValue() << ":\n";
const LastModification *lastMods =
solver.lookupState<LastModification>(op);
assert(lastMods && "expected a dense lattice");
for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
os << " operand #" << index << "\n";
std::optional<Value> underlyingValue =
UnderlyingValueAnalysis::getMostUnderlyingValue(
operand, [&](Value value) {
return solver.lookupState<UnderlyingValueLattice>(value);
});
if (!underlyingValue) {
os << " - <unknown>\n";
continue;
}
Value value = *underlyingValue;
assert(value && "expected an underlying value");
if (const AdjacentAccess *lastMod =
lastMods->getAdjacentAccess(value)) {
if (!lastMod->isKnown()) {
os << " - <unknown>\n";
} else {
for (Operation *lastModifier : lastMod->get()) {
if (auto tagName =
lastModifier->getAttrOfType<StringAttr>("tag_name")) {
os << " - " << tagName.getValue() << "\n";
} else {
os << " - " << lastModifier->getName() << "\n";
}
}
}
} else {
os << " - <unknown>\n";
}
}
});
}
};
} // end anonymous namespace
namespace mlir {
namespace test {
void registerTestLastModifiedPass() {
PassRegistration<TestLastModifiedPass>();
}
} // end namespace test
} // end namespace mlir