bolt/deps/llvm-18.1.8/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h
2025-02-14 19:21:04 +01:00

227 lines
6.9 KiB
C++

//===- TestDenseDataFlowAnalysis.h - Dataflow test utilities ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
namespace mlir {
namespace dataflow {
namespace test {
/// This lattice represents a single underlying value for an SSA value.
class UnderlyingValue {
public:
/// Create an underlying value state with a known underlying value.
explicit UnderlyingValue(std::optional<Value> underlyingValue = std::nullopt)
: underlyingValue(underlyingValue) {}
/// Whether the state is uninitialized.
bool isUninitialized() const { return !underlyingValue.has_value(); }
/// Returns the underlying value.
Value getUnderlyingValue() const {
assert(!isUninitialized());
return *underlyingValue;
}
/// Join two underlying values. If there are conflicting underlying values,
/// go to the pessimistic value.
static UnderlyingValue join(const UnderlyingValue &lhs,
const UnderlyingValue &rhs) {
if (lhs.isUninitialized())
return rhs;
if (rhs.isUninitialized())
return lhs;
return lhs.underlyingValue == rhs.underlyingValue
? lhs
: UnderlyingValue(Value{});
}
/// Compare underlying values.
bool operator==(const UnderlyingValue &rhs) const {
return underlyingValue == rhs.underlyingValue;
}
void print(raw_ostream &os) const { os << underlyingValue; }
private:
std::optional<Value> underlyingValue;
};
class AdjacentAccess {
public:
using DeterministicSetVector =
SetVector<Operation *, SmallVector<Operation *, 2>,
SmallPtrSet<Operation *, 2>>;
ArrayRef<Operation *> get() const { return accesses.getArrayRef(); }
bool isKnown() const { return !unknown; }
ChangeResult merge(const AdjacentAccess &other) {
if (unknown)
return ChangeResult::NoChange;
if (other.unknown) {
unknown = true;
accesses.clear();
return ChangeResult::Change;
}
size_t sizeBefore = accesses.size();
accesses.insert(other.accesses.begin(), other.accesses.end());
return accesses.size() == sizeBefore ? ChangeResult::NoChange
: ChangeResult::Change;
}
ChangeResult set(Operation *op) {
if (!unknown && accesses.size() == 1 && *accesses.begin() == op)
return ChangeResult::NoChange;
unknown = false;
accesses.clear();
accesses.insert(op);
return ChangeResult::Change;
}
ChangeResult setUnknown() {
if (unknown)
return ChangeResult::NoChange;
accesses.clear();
unknown = true;
return ChangeResult::Change;
}
bool operator==(const AdjacentAccess &other) const {
return unknown == other.unknown && accesses == other.accesses;
}
bool operator!=(const AdjacentAccess &other) const {
return !operator==(other);
}
private:
bool unknown = false;
DeterministicSetVector accesses;
};
/// This lattice represents, for a given memory resource, the potential last
/// operations that modified the resource.
class AccessLatticeBase {
public:
/// Clear all modifications.
ChangeResult reset() {
if (adjAccesses.empty())
return ChangeResult::NoChange;
adjAccesses.clear();
return ChangeResult::Change;
}
/// Join the last modifications.
ChangeResult merge(const AccessLatticeBase &rhs) {
ChangeResult result = ChangeResult::NoChange;
for (const auto &mod : rhs.adjAccesses) {
AdjacentAccess &lhsMod = adjAccesses[mod.first];
result |= lhsMod.merge(mod.second);
}
return result;
}
/// Set the last modification of a value.
ChangeResult set(Value value, Operation *op) {
AdjacentAccess &lastMod = adjAccesses[value];
return lastMod.set(op);
}
ChangeResult setKnownToUnknown() {
ChangeResult result = ChangeResult::NoChange;
for (auto &[value, adjacent] : adjAccesses)
result |= adjacent.setUnknown();
return result;
}
/// Get the adjacent accesses to a value. Returns std::nullopt if they
/// are not known.
const AdjacentAccess *getAdjacentAccess(Value value) const {
auto it = adjAccesses.find(value);
if (it == adjAccesses.end())
return nullptr;
return &it->getSecond();
}
void print(raw_ostream &os) const {
for (const auto &lastMod : adjAccesses) {
os << lastMod.first << ":\n";
if (!lastMod.second.isKnown()) {
os << " <unknown>\n";
return;
}
for (Operation *op : lastMod.second.get())
os << " " << *op << "\n";
}
}
private:
/// The potential adjacent accesses to a memory resource. Use a set vector to
/// keep the results deterministic.
DenseMap<Value, AdjacentAccess> adjAccesses;
};
/// Define the lattice class explicitly to provide a type ID.
struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
using Lattice::Lattice;
};
/// An analysis that uses forwarding of values along control-flow and callgraph
/// edges to determine single underlying values for block arguments. This
/// analysis exists so that the test analysis and pass can test the behaviour of
/// the dense data-flow analysis on the callgraph.
class UnderlyingValueAnalysis
: public SparseForwardDataFlowAnalysis<UnderlyingValueLattice> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
/// The underlying value of the results of an operation are not known.
void visitOperation(Operation *op,
ArrayRef<const UnderlyingValueLattice *> operands,
ArrayRef<UnderlyingValueLattice *> results) override {
setAllToEntryStates(results);
}
/// At an entry point, the underlying value of a value is itself.
void setToEntryState(UnderlyingValueLattice *lattice) override {
propagateIfChanged(lattice,
lattice->join(UnderlyingValue{lattice->getPoint()}));
}
/// Look for the most underlying value of a value.
static std::optional<Value>
getMostUnderlyingValue(Value value,
function_ref<const UnderlyingValueLattice *(Value)>
getUnderlyingValueFn) {
const UnderlyingValueLattice *underlying;
do {
underlying = getUnderlyingValueFn(value);
if (!underlying || underlying->getValue().isUninitialized())
return std::nullopt;
Value underlyingValue = underlying->getValue().getUnderlyingValue();
if (underlyingValue == value)
break;
value = underlyingValue;
} while (true);
return value;
}
};
} // namespace test
} // namespace dataflow
} // namespace mlir