bolt/deps/llvm-18.1.8/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
2025-02-14 19:21:04 +01:00

186 lines
5.7 KiB
C++

//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <utility>
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTRANGEOPTS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith
using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;
/// Conservative overlap test for two integer ranges.
///
/// Returns false only when the bounds of `lhs` and `rhs` are disjoint in
/// BOTH the signed and the unsigned interpretation; in every other case the
/// ranges are treated as possibly sharing a value and true is returned.
static bool intersects(const ConstantIntRanges &lhs,
                       const ConstantIntRanges &rhs) {
  const bool signedDisjoint =
      lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax());
  const bool unsignedDisjoint =
      lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax());
  return !(signedDisjoint && unsignedDisjoint);
}
/// Tries to fold `lhs == rhs` from the operand ranges.
///
/// Yields `false` when `intersects` reports the ranges as disjoint; otherwise
/// the outcome cannot be decided and `failure()` is returned.
static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  const bool mayOverlap = intersects(lhs, rhs);
  if (mayOverlap)
    return failure();
  return false;
}
/// Tries to fold `lhs != rhs` from the operand ranges.
///
/// Yields `true` when `intersects` reports the ranges as disjoint; otherwise
/// the outcome cannot be decided and `failure()` is returned.
static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  const bool mayOverlap = intersects(lhs, rhs);
  if (mayOverlap)
    return failure();
  return true;
}
/// Tries to fold the signed comparison `lhs < rhs` from the operand ranges.
///
/// - `true`      when the signed maximum of `lhs` is below the signed
///               minimum of `rhs`;
/// - `false`     when the signed minimum of `lhs` is at or above the signed
///               maximum of `rhs`;
/// - `failure()` when neither bound check settles the comparison.
static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  const bool alwaysLess = lhs.smax().slt(rhs.smin());
  if (alwaysLess)
    return true;

  const bool neverLess = lhs.smin().sge(rhs.smax());
  if (neverLess)
    return false;

  return failure();
}
/// Tries to fold the signed comparison `lhs <= rhs` from the operand ranges.
///
/// - `true`      when the signed maximum of `lhs` does not exceed the signed
///               minimum of `rhs`;
/// - `false`     when the signed minimum of `lhs` is above the signed maximum
///               of `rhs`;
/// - `failure()` when neither bound check settles the comparison.
static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  const bool alwaysLessOrEqual = lhs.smax().sle(rhs.smin());
  if (alwaysLessOrEqual)
    return true;

  const bool neverLessOrEqual = lhs.smin().sgt(rhs.smax());
  if (neverLessOrEqual)
    return false;

  return failure();
}
/// Tries to fold the signed comparison `lhs > rhs`.
///
/// Implemented by swapping the operands and delegating to `handleSlt`.
static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  FailureOr<bool> swappedResult = handleSlt(std::move(rhs), std::move(lhs));
  return swappedResult;
}
/// Tries to fold the signed comparison `lhs >= rhs`.
///
/// Implemented by swapping the operands and delegating to `handleSle`.
static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  FailureOr<bool> swappedResult = handleSle(std::move(rhs), std::move(lhs));
  return swappedResult;
}
/// Tries to fold the unsigned comparison `lhs < rhs` from the operand ranges.
///
/// - `true`      when the unsigned maximum of `lhs` is below the unsigned
///               minimum of `rhs`;
/// - `false`     when the unsigned minimum of `lhs` is at or above the
///               unsigned maximum of `rhs`;
/// - `failure()` when neither bound check settles the comparison.
static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  const bool alwaysLess = lhs.umax().ult(rhs.umin());
  if (alwaysLess)
    return true;

  const bool neverLess = lhs.umin().uge(rhs.umax());
  if (neverLess)
    return false;

  return failure();
}
/// Tries to fold the unsigned comparison `lhs <= rhs` from the operand ranges.
///
/// - `true`      when the unsigned maximum of `lhs` does not exceed the
///               unsigned minimum of `rhs`;
/// - `false`     when the unsigned minimum of `lhs` is above the unsigned
///               maximum of `rhs`;
/// - `failure()` when neither bound check settles the comparison.
static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  const bool alwaysLessOrEqual = lhs.umax().ule(rhs.umin());
  if (alwaysLessOrEqual)
    return true;

  const bool neverLessOrEqual = lhs.umin().ugt(rhs.umax());
  if (neverLessOrEqual)
    return false;

  return failure();
}
/// Tries to fold the unsigned comparison `lhs > rhs`.
///
/// Implemented by swapping the operands and delegating to `handleUlt`.
static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  FailureOr<bool> swappedResult = handleUlt(std::move(rhs), std::move(lhs));
  return swappedResult;
}
/// Tries to fold the unsigned comparison `lhs >= rhs`.
///
/// Implemented by swapping the operands and delegating to `handleUle`.
static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
  FailureOr<bool> swappedResult = handleUle(std::move(rhs), std::move(lhs));
  return swappedResult;
}
namespace {
/// Rewrite pattern that replaces an `arith.cmpi` with a 1-bit constant when
/// the integer ranges recorded in the data-flow solver for its two operands
/// are enough to decide the comparison.
struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
  /// `s` is stored by reference and queried on every match attempt, so it has
  /// to outlive the pattern.
  ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
      : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}

  /// Fails (leaving `op` untouched) when either operand has no initialized
  /// range state, when the predicate has no handler, or when the handler
  /// cannot decide the comparison. Otherwise replaces `op` with an
  /// `arith::ConstantIntOp` of width 1 holding the folded result.
  LogicalResult matchAndRewrite(arith::CmpIOp op,
                                PatternRewriter &rewriter) const override {
    // Fetches the range lattice of `operand`; yields nullptr when the solver
    // has no state for it or the state is still uninitialized.
    auto lookupRange =
        [&](Value operand) -> const dataflow::IntegerValueRangeLattice * {
      const auto *state =
          solver.lookupState<dataflow::IntegerValueRangeLattice>(operand);
      if (!state || state->getValue().isUninitialized())
        return nullptr;
      return state;
    };

    const auto *lhsState = lookupRange(op.getLhs());
    if (!lhsState)
      return failure();
    const auto *rhsState = lookupRange(op.getRhs());
    if (!rhsState)
      return failure();

    using HandlerFunc =
        FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
    using Pred = arith::CmpIPredicate;

    // Select the folding routine that matches the comparison predicate.
    HandlerFunc handler = nullptr;
    switch (op.getPredicate()) {
    case Pred::eq:
      handler = &handleEq;
      break;
    case Pred::ne:
      handler = &handleNe;
      break;
    case Pred::slt:
      handler = &handleSlt;
      break;
    case Pred::sle:
      handler = &handleSle;
      break;
    case Pred::sgt:
      handler = &handleSgt;
      break;
    case Pred::sge:
      handler = &handleSge;
      break;
    case Pred::ult:
      handler = &handleUlt;
      break;
    case Pred::ule:
      handler = &handleUle;
      break;
    case Pred::ugt:
      handler = &handleUgt;
      break;
    case Pred::uge:
      handler = &handleUge;
      break;
    }
    if (!handler)
      return failure();

    FailureOr<bool> folded = handler(lhsState->getValue().getValue(),
                                     rhsState->getValue().getValue());
    if (failed(folded))
      return failure();

    const int64_t constant = static_cast<int64_t>(*folded);
    rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(op, constant,
                                                      /*width*/ 1);
    return success();
  }

private:
  /// Solver whose integer-range states drive the folding decisions.
  DataFlowSolver &solver;
};
/// Pass that runs dead-code and integer-range analyses over the current
/// operation and then greedily applies the integer-range optimization
/// patterns using the solver's results.
struct IntRangeOptimizationsPass
    : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
  /// Signals pass failure if either the data-flow solver or the greedy
  /// pattern driver reports failure.
  void runOnOperation() override {
    Operation *root = getOperation();

    // Both analyses are loaded into the same solver before it is run.
    DataFlowSolver solver;
    solver.load<DeadCodeAnalysis>();
    solver.load<IntegerRangeAnalysis>();
    if (failed(solver.initializeAndRun(root))) {
      signalPassFailure();
      return;
    }

    MLIRContext *context = root->getContext();
    RewritePatternSet rangePatterns(context);
    populateIntRangeOptimizationsPatterns(rangePatterns, solver);

    LogicalResult status =
        applyPatternsAndFoldGreedily(root, std::move(rangePatterns));
    if (failed(status))
      signalPassFailure();
  }
};
} // namespace
/// Adds the integer-range based rewrite patterns (currently `ConvertCmpOp`)
/// to `patterns`. The patterns are constructed with a reference to `solver`.
void mlir::arith::populateIntRangeOptimizationsPatterns(
    RewritePatternSet &patterns, DataFlowSolver &solver) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ConvertCmpOp>(context, solver);
}
/// Creates a new instance of the integer-range optimizations pass.
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<IntRangeOptimizationsPass>();
  return pass;
}