//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/Utils/InferIntRangeCommon.h" #include "llvm/Support/Debug.h" #include #define DEBUG_TYPE "int-range-analysis" using namespace mlir; using namespace mlir::arith; using namespace mlir::intrange; //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { auto constAttr = llvm::dyn_cast_or_null(getValue()); if (constAttr) { const APInt &value = constAttr.getValue(); setResultRange(getResult(), ConstantIntRanges::constant(value)); } } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// void arith::AddIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferAdd(argRanges)); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// void arith::SubIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferSub(argRanges)); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// void arith::MulIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferMul(argRanges)); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// void arith::DivUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferDivU(argRanges)); } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// void arith::DivSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferDivS(argRanges)); } //===----------------------------------------------------------------------===// // CeilDivUIOp //===----------------------------------------------------------------------===// void arith::CeilDivUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferCeilDivU(argRanges)); } //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// void arith::CeilDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferCeilDivS(argRanges)); } //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// void arith::FloorDivSIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { return setResultRange(getResult(), inferFloorDivS(argRanges)); } //===----------------------------------------------------------------------===// // RemUIOp //===----------------------------------------------------------------------===// void arith::RemUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferRemU(argRanges)); } //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// void arith::RemSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferRemS(argRanges)); } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// void arith::AndIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferAnd(argRanges)); } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// void arith::OrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferOr(argRanges)); } //===----------------------------------------------------------------------===// // XOrIOp //===----------------------------------------------------------------------===// void arith::XOrIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferXor(argRanges)); } //===----------------------------------------------------------------------===// // MaxSIOp //===----------------------------------------------------------------------===// void arith::MaxSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferMaxS(argRanges)); } //===----------------------------------------------------------------------===// // MaxUIOp //===----------------------------------------------------------------------===// void arith::MaxUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferMaxU(argRanges)); } //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// void arith::MinSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferMinS(argRanges)); } //===----------------------------------------------------------------------===// // MinUIOp //===----------------------------------------------------------------------===// void arith::MinUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferMinU(argRanges)); } //===----------------------------------------------------------------------===// // ExtUIOp //===----------------------------------------------------------------------===// void arith::ExtUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { unsigned destWidth = ConstantIntRanges::getStorageBitwidth(getResult().getType()); setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // ExtSIOp //===----------------------------------------------------------------------===// void arith::ExtSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { unsigned destWidth = ConstantIntRanges::getStorageBitwidth(getResult().getType()); setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // TruncIOp //===----------------------------------------------------------------------===// void arith::TruncIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { unsigned destWidth = ConstantIntRanges::getStorageBitwidth(getResult().getType()); setResultRange(getResult(), truncRange(argRanges[0], destWidth)); } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// void arith::IndexCastOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { Type sourceType = getOperand().getType(); Type destType = getResult().getType(); unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) setResultRange(getResult(), extSIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } //===----------------------------------------------------------------------===// // IndexCastUIOp //===----------------------------------------------------------------------===// void arith::IndexCastUIOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRange) { Type sourceType = getOperand().getType(); Type destType = getResult().getType(); unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType); unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType); if (srcWidth < destWidth) setResultRange(getResult(), extUIRange(argRanges[0], destWidth)); else if (srcWidth > destWidth) setResultRange(getResult(), truncRange(argRanges[0], destWidth)); else setResultRange(getResult(), argRanges[0]); } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// void arith::CmpIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { arith::CmpIPredicate arithPred = getPredicate(); intrange::CmpPredicate pred = static_cast(arithPred); const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1]; APInt min = APInt::getZero(1); APInt max = APInt::getAllOnes(1); std::optional truthValue = intrange::evaluatePred(pred, lhs, rhs); if (truthValue.has_value() && *truthValue) min = max; else if (truthValue.has_value() && !(*truthValue)) max = min; setResultRange(getResult(), ConstantIntRanges::fromUnsigned(min, max)); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// void arith::SelectOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { std::optional mbCondVal = argRanges[0].getConstantValue(); if (mbCondVal) { if (mbCondVal->isZero()) setResultRange(getResult(), argRanges[2]); else setResultRange(getResult(), argRanges[1]); return; } setResultRange(getResult(), argRanges[1].rangeUnion(argRanges[2])); } //===----------------------------------------------------------------------===// // ShLIOp //===----------------------------------------------------------------------===// void arith::ShLIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferShl(argRanges)); } //===----------------------------------------------------------------------===// // ShRUIOp //===----------------------------------------------------------------------===// void arith::ShRUIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferShrU(argRanges)); } //===----------------------------------------------------------------------===// // ShRSIOp //===----------------------------------------------------------------------===// void arith::ShRSIOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { setResultRange(getResult(), inferShrS(argRanges)); }