//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir::arith { #define GEN_PASS_DEF_ARITHINTRANGEOPTS #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" } // namespace mlir::arith using namespace mlir; using namespace mlir::arith; using namespace mlir::dataflow; /// Returns true if 2 integer ranges have intersection. static bool intersects(const ConstantIntRanges &lhs, const ConstantIntRanges &rhs) { return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) && (lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax()))); } static FailureOr handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) { if (!intersects(lhs, rhs)) return false; return failure(); } static FailureOr handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) { if (!intersects(lhs, rhs)) return true; return failure(); } static FailureOr handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) { if (lhs.smax().slt(rhs.smin())) return true; if (lhs.smin().sge(rhs.smax())) return false; return failure(); } static FailureOr handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) { if (lhs.smax().sle(rhs.smin())) return true; if (lhs.smin().sgt(rhs.smax())) return false; return failure(); } static FailureOr handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) { return handleSlt(std::move(rhs), std::move(lhs)); } static FailureOr handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) { return handleSle(std::move(rhs), std::move(lhs)); } static FailureOr handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) { if (lhs.umax().ult(rhs.umin())) return true; if (lhs.umin().uge(rhs.umax())) return false; return failure(); } static FailureOr handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) { if (lhs.umax().ule(rhs.umin())) return true; if (lhs.umin().ugt(rhs.umax())) return false; return failure(); } static FailureOr handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) { return handleUlt(std::move(rhs), std::move(lhs)); } static FailureOr handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) { return handleUle(std::move(rhs), std::move(lhs)); } namespace { struct ConvertCmpOp : public OpRewritePattern { ConvertCmpOp(MLIRContext *context, DataFlowSolver &s) : OpRewritePattern(context), solver(s) {} LogicalResult matchAndRewrite(arith::CmpIOp op, PatternRewriter &rewriter) const override { auto *lhsResult = solver.lookupState(op.getLhs()); if (!lhsResult || lhsResult->getValue().isUninitialized()) return failure(); auto *rhsResult = solver.lookupState(op.getRhs()); if (!rhsResult || rhsResult->getValue().isUninitialized()) return failure(); using HandlerFunc = FailureOr (*)(ConstantIntRanges, ConstantIntRanges); std::array handlers{}; using Pred = arith::CmpIPredicate; handlers[static_cast(Pred::eq)] = &handleEq; handlers[static_cast(Pred::ne)] = &handleNe; handlers[static_cast(Pred::slt)] = &handleSlt; handlers[static_cast(Pred::sle)] = &handleSle; handlers[static_cast(Pred::sgt)] = &handleSgt; handlers[static_cast(Pred::sge)] = &handleSge; handlers[static_cast(Pred::ult)] = &handleUlt; handlers[static_cast(Pred::ule)] = &handleUle; handlers[static_cast(Pred::ugt)] = &handleUgt; handlers[static_cast(Pred::uge)] = &handleUge; HandlerFunc handler = handlers[static_cast(op.getPredicate())]; if (!handler) return failure(); ConstantIntRanges lhsValue = lhsResult->getValue().getValue(); ConstantIntRanges rhsValue = rhsResult->getValue().getValue(); FailureOr result = handler(lhsValue, rhsValue); if (failed(result)) return failure(); rewriter.replaceOpWithNewOp( op, static_cast(*result), /*width*/ 1); return success(); } private: DataFlowSolver &solver; }; struct IntRangeOptimizationsPass : public arith::impl::ArithIntRangeOptsBase { void runOnOperation() override { Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); DataFlowSolver solver; solver.load(); solver.load(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); RewritePatternSet patterns(ctx); populateIntRangeOptimizationsPatterns(patterns, solver); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } }; } // namespace void mlir::arith::populateIntRangeOptimizationsPatterns( RewritePatternSet &patterns, DataFlowSolver &solver) { patterns.add(patterns.getContext(), solver); } std::unique_ptr mlir::arith::createIntRangeOptimizationsPass() { return std::make_unique(); }