126 lines
4.4 KiB
C++
126 lines
4.4 KiB
C++
|
//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TODO: This pass is needed to test integer range inference until that
|
||
|
// functionality has been integrated into SCCP.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||
|
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||
|
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
|
||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||
|
#include "mlir/Pass/Pass.h"
|
||
|
#include "mlir/Pass/PassRegistry.h"
|
||
|
#include "mlir/Support/TypeID.h"
|
||
|
#include "mlir/Transforms/FoldUtils.h"
|
||
|
#include <optional>
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::dataflow;
|
||
|
|
||
|
/// Patterned after SCCP
|
||
|
static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
|
||
|
OperationFolder &folder, Value value) {
|
||
|
auto *maybeInferredRange =
|
||
|
solver.lookupState<IntegerValueRangeLattice>(value);
|
||
|
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
|
||
|
return failure();
|
||
|
const ConstantIntRanges &inferredRange =
|
||
|
maybeInferredRange->getValue().getValue();
|
||
|
std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
|
||
|
if (!maybeConstValue.has_value())
|
||
|
return failure();
|
||
|
|
||
|
Operation *maybeDefiningOp = value.getDefiningOp();
|
||
|
Dialect *valueDialect =
|
||
|
maybeDefiningOp ? maybeDefiningOp->getDialect()
|
||
|
: value.getParentRegion()->getParentOp()->getDialect();
|
||
|
Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
|
||
|
Value constant = folder.getOrCreateConstant(
|
||
|
b.getInsertionBlock(), valueDialect, constAttr, value.getType());
|
||
|
if (!constant)
|
||
|
return failure();
|
||
|
|
||
|
value.replaceAllUsesWith(constant);
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Rewrites the IR nested under `initialRegions` using the results held by
/// `solver`.
///
/// Every block reachable through the given regions (and through the regions
/// of the operations they contain) is visited via an explicit worklist. For
/// each operation, every result is passed to `replaceWithConstant`; if all
/// results were replaced and the operation would be trivially dead, it is
/// erased and its regions are not visited. After the operations of a block
/// have been processed, the same replacement is attempted for the block's
/// arguments, with constants inserted at the start of the block.
///
/// \param solver          Data-flow solver to query for inferred ranges.
/// \param context         Context used to construct the builder and folder.
/// \param initialRegions  Regions whose blocks seed the worklist.
static void rewrite(DataFlowSolver &solver, MLIRContext *context,
                    MutableArrayRef<Region> initialRegions) {
  SmallVector<Block *> worklist;
  // Blocks are pushed in reverse so that popping from the back of the
  // worklist yields the blocks of a region in their forward order.
  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
    for (Region &region : regions)
      for (Block &block : llvm::reverse(region))
        worklist.push_back(&block);
  };

  OpBuilder builder(context);
  OperationFolder folder(context);

  addToWorklist(initialRegions);
  while (!worklist.empty()) {
    Block *block = worklist.pop_back_val();

    // Early-inc iteration is used because `op` may be erased in the body.
    for (Operation &op : llvm::make_early_inc_range(*block)) {
      builder.setInsertionPoint(&op);

      // Replace any result with constants. An op without results never
      // counts as fully replaced.
      bool replacedAll = op.getNumResults() != 0;
      for (Value res : op.getResults())
        replacedAll &=
            succeeded(replaceWithConstant(solver, builder, folder, res));

      // If all of the results of the operation were replaced, try to erase
      // the operation completely.
      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
        assert(op.use_empty() && "expected all uses to be replaced");
        op.erase();
        continue;
      }

      // Add the regions of this operation, if any, to the worklist.
      addToWorklist(op.getRegions());
    }

    // Replace any block arguments with constants. Failure here is expected
    // for arguments without a known constant value, so the result is ignored.
    builder.setInsertionPointToStart(block);
    for (BlockArgument arg : block->getArguments())
      (void)replaceWithConstant(solver, builder, folder, arg);
  }
}
|
||
|
|
||
|
namespace {
|
||
|
struct TestIntRangeInference
|
||
|
: PassWrapper<TestIntRangeInference, OperationPass<>> {
|
||
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
|
||
|
|
||
|
StringRef getArgument() const final { return "test-int-range-inference"; }
|
||
|
StringRef getDescription() const final {
|
||
|
return "Test integer range inference analysis";
|
||
|
}
|
||
|
|
||
|
void runOnOperation() override {
|
||
|
Operation *op = getOperation();
|
||
|
DataFlowSolver solver;
|
||
|
solver.load<DeadCodeAnalysis>();
|
||
|
solver.load<SparseConstantPropagation>();
|
||
|
solver.load<IntegerRangeAnalysis>();
|
||
|
if (failed(solver.initializeAndRun(op)))
|
||
|
return signalPassFailure();
|
||
|
rewrite(solver, op->getContext(), op->getRegions());
|
||
|
}
|
||
|
};
|
||
|
} // end anonymous namespace
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace test {
|
||
|
void registerTestIntRangeInference() {
|
||
|
PassRegistration<TestIntRangeInference>();
|
||
|
}
|
||
|
} // end namespace test
|
||
|
} // end namespace mlir
|