bolt/deps/llvm-18.1.8/mlir/lib/Dialect/Index/IR/InferIntRangeInterfaceImpls.cpp
2025-02-14 19:21:04 +01:00

253 lines
10 KiB
C++

//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for index -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "int-range-analysis"
using namespace mlir;
using namespace mlir::index;
using namespace mlir::intrange;
//===----------------------------------------------------------------------===//
// Constants
//===----------------------------------------------------------------------===//
/// Reports the result of this ConstantOp as the single-value range built
/// from the op's own value. `argRanges` is not consulted.
void ConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  ConstantIntRanges exactRange = ConstantIntRanges::constant(getValue());
  setResultRange(getResult(), exactRange);
}
/// Reports the result of this BoolConstantOp as a single-value range over a
/// 1-bit integer holding the op's boolean value. `argRanges` is not consulted.
void BoolConstantOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                       SetIntRangeFn setResultRange) {
  // Widen the boolean into a 1-bit APInt so it can seed a constant range.
  const APInt oneBitValue(/*numBits=*/1, getValue());
  setResultRange(getResult(), ConstantIntRanges::constant(oneBitValue));
}
//===----------------------------------------------------------------------===//
// Arithmetic operations. All of these operations will have their results
// inferred using both the 64-bit values and truncated 32-bit values of their
// inputs, with the results being the union of those inferences, except where
// the truncation of the 64-bit result is equal to the 32-bit result (at which
// time we take the 64-bit result).
//===----------------------------------------------------------------------===//
/// Infers the result range of this AddOp by running `inferAdd` through
/// `inferIndexOp` with `CmpMode::Both`, then reports it for the op's result.
void AddOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferAdd, argRanges, CmpMode::Both);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this SubOp by running `inferSub` through
/// `inferIndexOp` with `CmpMode::Both`, then reports it for the op's result.
void SubOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferSub, argRanges, CmpMode::Both);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this MulOp by running `inferMul` through
/// `inferIndexOp` with `CmpMode::Both`, then reports it for the op's result.
void MulOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferMul, argRanges, CmpMode::Both);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this DivUOp by running `inferDivU` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void DivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferDivU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this DivSOp by running `inferDivS` through
/// `inferIndexOp` with `CmpMode::Signed`, then reports it for the op's result.
void DivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferDivS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this CeilDivUOp by running `inferCeilDivU`
/// through `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the
/// op's result.
void CeilDivUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferCeilDivU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this CeilDivSOp by running `inferCeilDivS`
/// through `inferIndexOp` with `CmpMode::Signed`, then reports it for the
/// op's result.
void CeilDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                   SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferCeilDivS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this FloorDivSOp by running `inferFloorDivS`
/// through `inferIndexOp` with `CmpMode::Signed`, then reports it for the
/// op's result.
void FloorDivSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                    SetIntRangeFn setResultRange) {
  // Plain call, matching the other ops in this file: this function returns
  // void, so there is no value to forward with `return`.
  setResultRange(getResult(),
                 inferIndexOp(inferFloorDivS, argRanges, CmpMode::Signed));
}
/// Infers the result range of this RemSOp by running `inferRemS` through
/// `inferIndexOp` with `CmpMode::Signed`, then reports it for the op's result.
void RemSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferRemS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this RemUOp by running `inferRemU` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void RemUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferRemU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this MaxSOp by running `inferMaxS` through
/// `inferIndexOp` with `CmpMode::Signed`, then reports it for the op's result.
void MaxSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferMaxS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this MaxUOp by running `inferMaxU` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void MaxUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferMaxU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this MinSOp by running `inferMinS` through
/// `inferIndexOp` with `CmpMode::Signed`, then reports it for the op's result.
void MinSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferMinS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this MinUOp by running `inferMinU` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void MinUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferMinU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this ShlOp by running `inferShl` through
/// `inferIndexOp` with `CmpMode::Both`, then reports it for the op's result.
void ShlOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferShl, argRanges, CmpMode::Both);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this ShrSOp by running `inferShrS` through
/// `inferIndexOp` with `CmpMode::Signed`, then reports it for the op's result.
void ShrSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferShrS, argRanges, CmpMode::Signed);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this ShrUOp by running `inferShrU` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void ShrUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                               SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferShrU, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this AndOp by running `inferAnd` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void AndOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferAnd, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this OrOp by running `inferOr` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void OrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                             SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferOr, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
/// Infers the result range of this XOrOp by running `inferXor` through
/// `inferIndexOp` with `CmpMode::Unsigned`, then reports it for the op's
/// result.
void XOrOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  ConstantIntRanges resultRange =
      inferIndexOp(inferXor, argRanges, CmpMode::Unsigned);
  setResultRange(getResult(), resultRange);
}
//===----------------------------------------------------------------------===//
// Casts
//===----------------------------------------------------------------------===//
/// Converts `range` from a `srcWidth`-bit value to a `destWidth`-bit value.
///
/// - Equal widths: `range` is returned unchanged.
/// - Narrower destination: the range is truncated with `truncRange`.
/// - Wider destination: the range is extended with `extSIRange` when
///   `isSigned` is true and with `extUIRange` otherwise.
static ConstantIntRanges makeLikeDest(const ConstantIntRanges &range,
                                      unsigned srcWidth, unsigned destWidth,
                                      bool isSigned) {
  if (srcWidth == destWidth)
    return range;
  if (srcWidth > destWidth)
    return truncRange(range, destWidth);
  // Only widening remains; `isSigned` picks the kind of extension.
  if (isSigned)
    return extSIRange(range, destWidth);
  return extUIRange(range, destWidth);
}
/// Infers the range produced by casting `range` from `sourceType` to
/// `destType`, where `isSigned` selects signed or unsigned extension.
///
/// If the source type is `index`, the result is just the conversion of
/// `range` to the destination's storage bitwidth. Otherwise the cast targets
/// `index`, and the result is the union of the possible fixed-width casts:
/// the conversion at the destination storage width, and the conversion at
/// `indexMinWidth` re-extended to that storage width.
static ConstantIntRanges inferIndexCast(const ConstantIntRanges &range,
                                        Type sourceType, Type destType,
                                        bool isSigned) {
  const unsigned srcBits = ConstantIntRanges::getStorageBitwidth(sourceType);
  const unsigned dstBits = ConstantIntRanges::getStorageBitwidth(destType);

  // Both paths need the conversion at the destination's storage width.
  ConstantIntRanges wideCast = makeLikeDest(range, srcBits, dstBits, isSigned);
  if (sourceType.isIndex())
    return wideCast;

  // Casting to index: also model the cast at the minimum index width, then
  // bring it back to the storage width so the two ranges can be unioned.
  ConstantIntRanges narrowCast =
      makeLikeDest(range, srcBits, indexMinWidth, isSigned);
  ConstantIntRanges narrowCastExt = extRange(narrowCast, dstBits);
  return wideCast.rangeUnion(narrowCastExt);
}
/// Infers the result range of this CastSOp by passing the range of its
/// operand, together with the operand and result types, to `inferIndexCast`
/// with signed semantics.
void CastSOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
  ConstantIntRanges castRange =
      inferIndexCast(argRanges[0], getOperand().getType(),
                     getResult().getType(), /*isSigned=*/true);
  setResultRange(getResult(), castRange);
}
/// Infers the result range of this CastUOp by passing the range of its
/// operand, together with the operand and result types, to `inferIndexCast`
/// with unsigned semantics.
void CastUOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                SetIntRangeFn setResultRange) {
  ConstantIntRanges castRange =
      inferIndexCast(argRanges[0], getOperand().getType(),
                     getResult().getType(), /*isSigned=*/false);
  setResultRange(getResult(), castRange);
}
//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//
/// Infers the 1-bit result range of this CmpOp.
///
/// The predicate is evaluated twice: once on the operand ranges as given and
/// once on those ranges truncated to `indexMinWidth`. The result collapses to
/// a single value only when both evaluations reach the same definite answer;
/// in every other case it stays the full 1-bit range [0, 1].
void CmpOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                              SetIntRangeFn setResultRange) {
  const auto pred = static_cast<intrange::CmpPredicate>(getPred());
  const ConstantIntRanges &lhs = argRanges[0];
  const ConstantIntRanges &rhs = argRanges[1];

  // Verdict on the ranges as given.
  const std::optional<bool> wideVerdict =
      intrange::evaluatePred(pred, lhs, rhs);

  // Verdict on the ranges truncated to the minimum index width.
  const ConstantIntRanges lhsNarrow = truncRange(lhs, indexMinWidth);
  const ConstantIntRanges rhsNarrow = truncRange(rhs, indexMinWidth);
  const std::optional<bool> narrowVerdict =
      intrange::evaluatePred(pred, lhsNarrow, rhsNarrow);

  // Start from "could be either" and narrow only on a definite agreement.
  APInt lower = APInt::getZero(1);
  APInt upper = APInt::getAllOnes(1);
  if (wideVerdict.has_value() && wideVerdict == narrowVerdict) {
    if (*wideVerdict)
      lower = upper; // always true
    else
      upper = lower; // always false
  }
  setResultRange(getResult(), ConstantIntRanges::fromUnsigned(lower, upper));
}
//===----------------------------------------------------------------------===//
// SizeOf, which is bounded between the two supported bitwidths (32 and 64).
//===----------------------------------------------------------------------===//
/// Reports the result of this SizeOfOp as the unsigned range
/// [`indexMinWidth`, `indexMaxWidth`], expressed at the storage bitwidth of
/// the result type. `argRanges` is not consulted.
void SizeOfOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
                                 SetIntRangeFn setResultRange) {
  const unsigned resultBits =
      ConstantIntRanges::getStorageBitwidth(getResult().getType());
  const APInt smallest(/*numBits=*/resultBits, indexMinWidth);
  const APInt largest(/*numBits=*/resultBits, indexMaxWidth);
  setResultRange(getResult(),
                 ConstantIntRanges::fromUnsigned(smallest, largest));
}