253 lines
10 KiB
C++
253 lines
10 KiB
C++
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for arith -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
|
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
|
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
|
|
|
|
#include "llvm/Support/Debug.h"
|
|
#include <optional>
|
|
|
|
#define DEBUG_TYPE "int-range-analysis"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::index;
|
|
using namespace mlir::intrange;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Constants
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Reports a single-value range for the result, built from the constant
/// returned by `getValue()`. `argRanges` is not consulted.
void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  ConstantIntRanges constantRange = ConstantIntRanges::constant(getValue());
  setResultRange(getResult(), constantRange);
}
|
|
|
|
/// Reports a single-value range for the result, holding the boolean returned
/// by `getValue()` as a 1-bit integer. `argRanges` is not consulted.
void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  // Materialize the boolean as a 1-bit APInt before building the range.
  const APInt boolBit(/*numBits=*/1, getValue());
  setResultRange(getResult(), ConstantIntRanges::constant(boolBit));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Arithmetic operations. All of these operations will have their results
// inferred using both the 64-bit values and truncated 32-bit values of their
// inputs, with the results being the union of those inferences, except where
// the truncation of the 64-bit result is equal to the 32-bit result (in which
// case we take the 64-bit result).
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Reports the result range obtained by running `inferAdd` through
/// `inferIndexOp` with `CmpMode::Both`.
void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges sumRange =
      inferIndexOp(inferAdd, argRanges, CmpMode::Both);
  setResultRange(getResult(), sumRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferSub` through
/// `inferIndexOp` with `CmpMode::Both`.
void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges differenceRange =
      inferIndexOp(inferSub, argRanges, CmpMode::Both);
  setResultRange(getResult(), differenceRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferMul` through
/// `inferIndexOp` with `CmpMode::Both`.
void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges productRange =
      inferIndexOp(inferMul, argRanges, CmpMode::Both);
  setResultRange(getResult(), productRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferDivU` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges quotientRange =
      inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), quotientRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferDivS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges quotientRange =
      inferIndexOp(inferDivS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), quotientRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferCeilDivU` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  ConstantIntRanges quotientRange =
      inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), quotientRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferCeilDivS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  ConstantIntRanges quotientRange =
      inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), quotientRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferFloorDivS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
  ConstantIntRanges quotientRange =
      inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), quotientRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferRemS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges remainderRange =
      inferIndexOp(inferRemS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), remainderRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferRemU` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges remainderRange =
      inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), remainderRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferMaxS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges maxRange =
      inferIndexOp(inferMaxS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), maxRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferMaxU` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges maxRange =
      inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), maxRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferMinS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges minRange =
      inferIndexOp(inferMinS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), minRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferMinU` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges minRange =
      inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), minRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferShl` through
/// `inferIndexOp` with `CmpMode::Both`.
void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges shiftedRange =
      inferIndexOp(inferShl, argRanges, CmpMode::Both);
  setResultRange(getResult(), shiftedRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferShrS` through
/// `inferIndexOp` with `CmpMode::Signed`.
void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges shiftedRange =
      inferIndexOp(inferShrS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), shiftedRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferShrU` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges shiftedRange =
      inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), shiftedRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferAnd` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges maskedRange =
      inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), maskedRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferOr` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                             SetIntRangeFn setResultRange) {
  ConstantIntRanges combinedRange =
      inferIndexOp(inferOr, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), combinedRange);
}
|
|
|
|
/// Reports the result range obtained by running `inferXor` through
/// `inferIndexOp` with `CmpMode::Unsigned`.
void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges combinedRange =
      inferIndexOp(inferXor, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), combinedRange);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Casts
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Converts `range` from `srcWidth` bits to `destWidth` bits.
///
/// Equal widths return `range` unchanged. A wider source is narrowed with
/// `truncRange`. A narrower source is widened with `extSIRange` when
/// `isSigned` is set and with `extUIRange` otherwise.
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
                                      unsigned srcWidth, unsigned destWidth,
                                      bool isSigned) {
  if (srcWidth == destWidth)
    return range;
  if (srcWidth > destWidth)
    return truncRange(range, destWidth);
  // Only the widening case remains; signedness selects the extension.
  if (isSigned)
    return extSIRange(range, destWidth);
  return extUIRange(range, destWidth);
}
|
|
|
|
// When casting to `index`, we will take the union of the possible fixed-width
|
|
// casts.
|
|
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
|
|
Type sourceType, Type destType,
|
|
bool isSigned) {
|
|
unsigned srcWidth = ConstantIntRanges::getStorageBitwidth(sourceType);
|
|
unsigned destWidth = ConstantIntRanges::getStorageBitwidth(destType);
|
|
if (sourceType.isIndex())
|
|
return makeLikeDest(range, srcWidth, destWidth, isSigned);
|
|
// We are casting to indexs, so use the union of the 32-bit and 64-bit casts
|
|
ConstantIntRanges storageRange =
|
|
makeLikeDest(range, srcWidth, destWidth, isSigned);
|
|
ConstantIntRanges minWidthRange =
|
|
makeLikeDest(range, srcWidth, indexMinWidth, isSigned);
|
|
ConstantIntRanges minWidthExt = extRange(minWidthRange, destWidth);
|
|
ConstantIntRanges ret = storageRange.rangeUnion(minWidthExt);
|
|
return ret;
|
|
}
|
|
|
|
/// Reports the result range computed by `inferIndexCast` from the operand's
/// range and the operand and result types, treating the cast as signed.
void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
  ConstantIntRanges castRange =
      inferIndexCast(argRanges[0], getOperand().getType(),
                     getResult().getType(), /*isSigned=*/true);
  setResultRange(getResult(), castRange);
}
|
|
|
|
/// Reports the result range computed by `inferIndexCast` from the operand's
/// range and the operand and result types, treating the cast as unsigned.
void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
  ConstantIntRanges castRange =
      inferIndexCast(argRanges[0], getOperand().getType(),
                     getResult().getType(), /*isSigned=*/false);
  setResultRange(getResult(), castRange);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Infers the 1-bit result range of a comparison.
///
/// The predicate is evaluated on the operand ranges as given and again on
/// their truncations to `indexMinWidth`. The result collapses to a single
/// value only when both evaluations produce the same definite answer;
/// otherwise it stays at the full range [0, 1].
void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  const auto pred = static_cast<intrange::CmpPredicate>(getPred());
  const ConstantIntRanges &lhs = argRanges[0];
  const ConstantIntRanges &rhs = argRanges[1];

  std::optional<bool> wideTruth = intrange::evaluatePred(pred, lhs, rhs);

  ConstantIntRanges narrowLhs = truncRange(lhs, indexMinWidth);
  ConstantIntRanges narrowRhs = truncRange(rhs, indexMinWidth);
  std::optional<bool> narrowTruth =
      intrange::evaluatePred(pred, narrowLhs, narrowRhs);

  // Start from the full boolean range and narrow it only on agreement.
  APInt lower = APInt::getZero(1);
  APInt upper = APInt::getAllOnes(1);
  if (wideTruth.has_value() && wideTruth == narrowTruth) {
    if (*wideTruth)
      lower = upper; // Always true: [1, 1].
    else
      upper = lower; // Always false: [0, 0].
  }
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(lower, upper));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Reports the unsigned range [`indexMinWidth`, `indexMaxWidth`] for the
/// result, with both bounds expressed at the storage bitwidth of the result
/// type. `argRanges` is not consulted.
void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  Type resultType = getResult().getType();
  const unsigned width = ConstantIntRanges::getStorageBitwidth(resultType);
  const APInt lowerBound(/*numBits=*/width, indexMinWidth);
  const APInt upperBound(/*numBits=*/width, indexMaxWidth);
  setResultRange(getResult(),
                 ConstantIntRanges::fromUnsigned(lowerBound, upperBound));
}
|