702 lines
25 KiB
C++
702 lines
25 KiB
C++
|
//===- IndexOps.cpp - Index operation definitions --------------------------==//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
||
|
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
|
||
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||
|
#include "mlir/IR/Builders.h"
|
||
|
#include "mlir/IR/Matchers.h"
|
||
|
#include "mlir/IR/OpImplementation.h"
|
||
|
#include "mlir/IR/PatternMatch.h"
|
||
|
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
|
||
|
#include "llvm/ADT/SmallString.h"
|
||
|
#include "llvm/ADT/TypeSwitch.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::index;
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// IndexDialect
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
void IndexDialect::registerOperations() {
|
||
|
addOperations<
|
||
|
#define GET_OP_LIST
|
||
|
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
|
||
|
>();
|
||
|
}
|
||
|
|
||
|
Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
|
||
|
Type type, Location loc) {
|
||
|
// Materialize bool constants as `i1`.
|
||
|
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
|
||
|
if (!type.isSignlessInteger(1))
|
||
|
return nullptr;
|
||
|
return b.create<BoolConstantOp>(loc, type, boolValue);
|
||
|
}
|
||
|
|
||
|
// Materialize integer attributes as `index`.
|
||
|
if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
|
||
|
if (!llvm::isa<IndexType>(indexValue.getType()) ||
|
||
|
!llvm::isa<IndexType>(type))
|
||
|
return nullptr;
|
||
|
assert(indexValue.getValue().getBitWidth() ==
|
||
|
IndexType::kInternalStorageBitWidth);
|
||
|
return b.create<ConstantOp>(loc, indexValue);
|
||
|
}
|
||
|
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Fold Utilities
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Fold an index operation irrespective of the target bitwidth. The
|
||
|
/// operation must satisfy the property:
|
||
|
///
|
||
|
/// ```
|
||
|
/// trunc(f(a, b)) = f(trunc(a), trunc(b))
|
||
|
/// ```
|
||
|
///
|
||
|
/// For all values of `a` and `b`. The function accepts a lambda that computes
|
||
|
/// the integer result, which in turn must satisfy the above property.
|
||
|
static OpFoldResult foldBinaryOpUnchecked(
|
||
|
ArrayRef<Attribute> operands,
|
||
|
function_ref<std::optional<APInt>(const APInt &, const APInt &)>
|
||
|
calculate) {
|
||
|
assert(operands.size() == 2 && "binary operation expected 2 operands");
|
||
|
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
|
||
|
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
|
||
|
if (!lhs || !rhs)
|
||
|
return {};
|
||
|
|
||
|
std::optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
|
||
|
if (!result)
|
||
|
return {};
|
||
|
assert(result->trunc(32) ==
|
||
|
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
|
||
|
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result);
|
||
|
}
|
||
|
|
||
|
/// Fold an index operation only if the truncated 64-bit result matches the
|
||
|
/// 32-bit result for operations that don't satisfy the above property. These
|
||
|
/// are operations where the upper bits of the operands can affect the lower
|
||
|
/// bits of the results.
|
||
|
///
|
||
|
/// The function accepts a lambda that computes the integer result in both
|
||
|
/// 64-bit and 32-bit. If either call returns `std::nullopt`, the operation is
|
||
|
/// not folded.
|
||
|
static OpFoldResult foldBinaryOpChecked(
|
||
|
ArrayRef<Attribute> operands,
|
||
|
function_ref<std::optional<APInt>(const APInt &, const APInt &lhs)>
|
||
|
calculate) {
|
||
|
assert(operands.size() == 2 && "binary operation expected 2 operands");
|
||
|
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
|
||
|
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
|
||
|
// Only fold index operands.
|
||
|
if (!lhs || !rhs)
|
||
|
return {};
|
||
|
|
||
|
// Compute the 64-bit result and the 32-bit result.
|
||
|
std::optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
|
||
|
if (!result64)
|
||
|
return {};
|
||
|
std::optional<APInt> result32 =
|
||
|
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
|
||
|
if (!result32)
|
||
|
return {};
|
||
|
// Compare the truncated 64-bit result to the 32-bit result.
|
||
|
if (result64->trunc(32) != *result32)
|
||
|
return {};
|
||
|
// The operation can be folded for these particular operands.
|
||
|
return IntegerAttr::get(IndexType::get(lhs.getContext()), *result64);
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// AddOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
|
||
|
if (OpFoldResult result = foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) { return lhs + rhs; }))
|
||
|
return result;
|
||
|
|
||
|
if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
|
||
|
// Fold `add(x, 0) -> x`.
|
||
|
if (rhs.getValue().isZero())
|
||
|
return getLhs();
|
||
|
}
|
||
|
|
||
|
return {};
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// SubOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
|
||
|
if (OpFoldResult result = foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) { return lhs - rhs; }))
|
||
|
return result;
|
||
|
|
||
|
if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
|
||
|
// Fold `sub(x, 0) -> x`.
|
||
|
if (rhs.getValue().isZero())
|
||
|
return getLhs();
|
||
|
}
|
||
|
|
||
|
return {};
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MulOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
|
||
|
if (OpFoldResult result = foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) { return lhs * rhs; }))
|
||
|
return result;
|
||
|
|
||
|
if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
|
||
|
// Fold `mul(x, 1) -> x`.
|
||
|
if (rhs.getValue().isOne())
|
||
|
return getLhs();
|
||
|
// Fold `mul(x, 0) -> 0`.
|
||
|
if (rhs.getValue().isZero())
|
||
|
return rhs;
|
||
|
}
|
||
|
|
||
|
return {};
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// DivSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// Don't fold division by zero.
|
||
|
if (rhs.isZero())
|
||
|
return std::nullopt;
|
||
|
return lhs.sdiv(rhs);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// DivUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// Don't fold division by zero.
|
||
|
if (rhs.isZero())
|
||
|
return std::nullopt;
|
||
|
return lhs.udiv(rhs);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// CeilDivSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
|
||
|
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
|
||
|
static std::optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
|
||
|
// Don't fold division by zero.
|
||
|
if (m.isZero())
|
||
|
return std::nullopt;
|
||
|
// Short-circuit the zero case.
|
||
|
if (n.isZero())
|
||
|
return n;
|
||
|
|
||
|
bool mGtZ = m.sgt(0);
|
||
|
if (n.sgt(0) != mGtZ) {
|
||
|
// If the operands have different signs, compute the negative result. Signed
|
||
|
// division overflow is not possible, since if `m == -1`, `n` can be at most
|
||
|
// `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
|
||
|
return -(-n).sdiv(m);
|
||
|
}
|
||
|
// Otherwise, compute the positive result. Signed division overflow is not
|
||
|
// possible since if `m == -1`, `x` will be `1`.
|
||
|
int64_t x = mGtZ ? -1 : 1;
|
||
|
return (n + x).sdiv(m) + 1;
|
||
|
}
|
||
|
|
||
|
OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// CeilDivUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
|
||
|
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &n, const APInt &m) -> std::optional<APInt> {
|
||
|
// Don't fold division by zero.
|
||
|
if (m.isZero())
|
||
|
return std::nullopt;
|
||
|
// Short-circuit the zero case.
|
||
|
if (n.isZero())
|
||
|
return n;
|
||
|
|
||
|
return (n - 1).udiv(m) + 1;
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// FloorDivSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
|
||
|
/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
|
||
|
static std::optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
|
||
|
// Don't fold division by zero.
|
||
|
if (m.isZero())
|
||
|
return std::nullopt;
|
||
|
// Short-circuit the zero case.
|
||
|
if (n.isZero())
|
||
|
return n;
|
||
|
|
||
|
bool mLtZ = m.slt(0);
|
||
|
if (n.slt(0) == mLtZ) {
|
||
|
// If the operands have the same sign, compute the positive result.
|
||
|
return n.sdiv(m);
|
||
|
}
|
||
|
// If the operands have different signs, compute the negative result. Signed
|
||
|
// division overflow is not possible since if `m == -1`, `x` will be 1 and
|
||
|
// `n` can be at most `INT_MAX`.
|
||
|
int64_t x = mLtZ ? 1 : -1;
|
||
|
return -1 - (x - n).sdiv(m);
|
||
|
}
|
||
|
|
||
|
OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// RemSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// Don't fold division by zero.
|
||
|
if (rhs.isZero())
|
||
|
return std::nullopt;
|
||
|
return lhs.srem(rhs);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// RemUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// Don't fold division by zero.
|
||
|
if (rhs.isZero())
|
||
|
return std::nullopt;
|
||
|
return lhs.urem(rhs);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MaxSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) {
|
||
|
return lhs.sgt(rhs) ? lhs : rhs;
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MaxUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) {
|
||
|
return lhs.ugt(rhs) ? lhs : rhs;
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MinSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) {
|
||
|
return lhs.slt(rhs) ? lhs : rhs;
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// MinUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) {
|
||
|
return lhs.ult(rhs) ? lhs : rhs;
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ShlOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// We cannot fold if the RHS is greater than or equal to 32 because
|
||
|
// this would be UB in 32-bit systems but not on 64-bit systems. RHS is
|
||
|
// already treated as unsigned.
|
||
|
if (rhs.uge(32))
|
||
|
return {};
|
||
|
return lhs << rhs;
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ShrSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// Don't fold if RHS is greater than or equal to 32.
|
||
|
if (rhs.uge(32))
|
||
|
return {};
|
||
|
return lhs.ashr(rhs);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ShrUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpChecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) -> std::optional<APInt> {
|
||
|
// Don't fold if RHS is greater than or equal to 32.
|
||
|
if (rhs.uge(32))
|
||
|
return {};
|
||
|
return lhs.lshr(rhs);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// AndOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// OrOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// XOrOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldBinaryOpUnchecked(
|
||
|
adaptor.getOperands(),
|
||
|
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// CastSOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
static OpFoldResult
|
||
|
foldCastOp(Attribute input, Type type,
|
||
|
function_ref<APInt(const APInt &, unsigned)> extFn,
|
||
|
function_ref<APInt(const APInt &, unsigned)> extOrTruncFn) {
|
||
|
auto attr = dyn_cast_if_present<IntegerAttr>(input);
|
||
|
if (!attr)
|
||
|
return {};
|
||
|
const APInt &value = attr.getValue();
|
||
|
|
||
|
if (isa<IndexType>(type)) {
|
||
|
// When casting to an index type, perform the cast assuming a 64-bit target.
|
||
|
// The result can be truncated to 32 bits as needed and always be correct.
|
||
|
// This is because `cast32(cast64(value)) == cast32(value)`.
|
||
|
APInt result = extOrTruncFn(value, 64);
|
||
|
return IntegerAttr::get(type, result);
|
||
|
}
|
||
|
|
||
|
// When casting from an index type, we must ensure the results respect
|
||
|
// `cast_t(value) == cast_t(trunc32(value))`.
|
||
|
auto intType = cast<IntegerType>(type);
|
||
|
unsigned width = intType.getWidth();
|
||
|
|
||
|
// If the result type is at most 32 bits, then the cast can always be folded
|
||
|
// because it is always a truncation.
|
||
|
if (width <= 32) {
|
||
|
APInt result = value.trunc(width);
|
||
|
return IntegerAttr::get(type, result);
|
||
|
}
|
||
|
|
||
|
// If the result type is at least 64 bits, then the cast is always a
|
||
|
// extension. The results will differ if `trunc32(value) != value)`.
|
||
|
if (width >= 64) {
|
||
|
if (extFn(value.trunc(32), 64) != value)
|
||
|
return {};
|
||
|
APInt result = extFn(value, width);
|
||
|
return IntegerAttr::get(type, result);
|
||
|
}
|
||
|
|
||
|
// Otherwise, we just have to check the property directly.
|
||
|
APInt result = value.trunc(width);
|
||
|
if (result != extFn(value.trunc(32), width))
|
||
|
return {};
|
||
|
return IntegerAttr::get(type, result);
|
||
|
}
|
||
|
|
||
|
bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
|
||
|
return llvm::isa<IndexType>(lhsTypes.front()) !=
|
||
|
llvm::isa<IndexType>(rhsTypes.front());
|
||
|
}
|
||
|
|
||
|
OpFoldResult CastSOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldCastOp(
|
||
|
adaptor.getInput(), getType(),
|
||
|
[](const APInt &x, unsigned width) { return x.sext(width); },
|
||
|
[](const APInt &x, unsigned width) { return x.sextOrTrunc(width); });
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// CastUOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
|
||
|
return llvm::isa<IndexType>(lhsTypes.front()) !=
|
||
|
llvm::isa<IndexType>(rhsTypes.front());
|
||
|
}
|
||
|
|
||
|
OpFoldResult CastUOp::fold(FoldAdaptor adaptor) {
|
||
|
return foldCastOp(
|
||
|
adaptor.getInput(), getType(),
|
||
|
[](const APInt &x, unsigned width) { return x.zext(width); },
|
||
|
[](const APInt &x, unsigned width) { return x.zextOrTrunc(width); });
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// CmpOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Compare two integers according to the comparison predicate.
|
||
|
bool compareIndices(const APInt &lhs, const APInt &rhs,
|
||
|
IndexCmpPredicate pred) {
|
||
|
switch (pred) {
|
||
|
case IndexCmpPredicate::EQ:
|
||
|
return lhs.eq(rhs);
|
||
|
case IndexCmpPredicate::NE:
|
||
|
return lhs.ne(rhs);
|
||
|
case IndexCmpPredicate::SGE:
|
||
|
return lhs.sge(rhs);
|
||
|
case IndexCmpPredicate::SGT:
|
||
|
return lhs.sgt(rhs);
|
||
|
case IndexCmpPredicate::SLE:
|
||
|
return lhs.sle(rhs);
|
||
|
case IndexCmpPredicate::SLT:
|
||
|
return lhs.slt(rhs);
|
||
|
case IndexCmpPredicate::UGE:
|
||
|
return lhs.uge(rhs);
|
||
|
case IndexCmpPredicate::UGT:
|
||
|
return lhs.ugt(rhs);
|
||
|
case IndexCmpPredicate::ULE:
|
||
|
return lhs.ule(rhs);
|
||
|
case IndexCmpPredicate::ULT:
|
||
|
return lhs.ult(rhs);
|
||
|
}
|
||
|
llvm_unreachable("unhandled IndexCmpPredicate predicate");
|
||
|
}
|
||
|
|
||
|
/// `cmp(max/min(x, cstA), cstB)` can be folded to a constant depending on the
|
||
|
/// values of `cstA` and `cstB`, the max or min operation, and the comparison
|
||
|
/// predicate. Check whether the value folds in both 32-bit and 64-bit
|
||
|
/// arithmetic and to the same value.
|
||
|
static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
|
||
|
const APInt &cstA,
|
||
|
const APInt &cstB, unsigned width,
|
||
|
IndexCmpPredicate pred) {
|
||
|
ConstantIntRanges lhsRange = TypeSwitch<Operation *, ConstantIntRanges>(lhsOp)
|
||
|
.Case([&](MinSOp op) {
|
||
|
return ConstantIntRanges::fromSigned(
|
||
|
APInt::getSignedMinValue(width), cstA);
|
||
|
})
|
||
|
.Case([&](MinUOp op) {
|
||
|
return ConstantIntRanges::fromUnsigned(
|
||
|
APInt::getMinValue(width), cstA);
|
||
|
})
|
||
|
.Case([&](MaxSOp op) {
|
||
|
return ConstantIntRanges::fromSigned(
|
||
|
cstA, APInt::getSignedMaxValue(width));
|
||
|
})
|
||
|
.Case([&](MaxUOp op) {
|
||
|
return ConstantIntRanges::fromUnsigned(
|
||
|
cstA, APInt::getMaxValue(width));
|
||
|
});
|
||
|
return intrange::evaluatePred(static_cast<intrange::CmpPredicate>(pred),
|
||
|
lhsRange, ConstantIntRanges::constant(cstB));
|
||
|
}
|
||
|
|
||
|
/// Return the result of `cmp(pred, x, x)`
|
||
|
static bool compareSameArgs(IndexCmpPredicate pred) {
|
||
|
switch (pred) {
|
||
|
case IndexCmpPredicate::EQ:
|
||
|
case IndexCmpPredicate::SGE:
|
||
|
case IndexCmpPredicate::SLE:
|
||
|
case IndexCmpPredicate::UGE:
|
||
|
case IndexCmpPredicate::ULE:
|
||
|
return true;
|
||
|
case IndexCmpPredicate::NE:
|
||
|
case IndexCmpPredicate::SGT:
|
||
|
case IndexCmpPredicate::SLT:
|
||
|
case IndexCmpPredicate::UGT:
|
||
|
case IndexCmpPredicate::ULT:
|
||
|
return false;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
|
||
|
// Attempt to fold if both inputs are constant.
|
||
|
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
|
||
|
auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
|
||
|
if (lhs && rhs) {
|
||
|
// Perform the comparison in 64-bit and 32-bit.
|
||
|
bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
|
||
|
bool result32 = compareIndices(lhs.getValue().trunc(32),
|
||
|
rhs.getValue().trunc(32), getPred());
|
||
|
if (result64 == result32)
|
||
|
return BoolAttr::get(getContext(), result64);
|
||
|
}
|
||
|
|
||
|
// Fold `cmp(max/min(x, cstA), cstB)`.
|
||
|
Operation *lhsOp = getLhs().getDefiningOp();
|
||
|
IntegerAttr cstA;
|
||
|
if (isa_and_nonnull<MinSOp, MinUOp, MaxSOp, MaxUOp>(lhsOp) &&
|
||
|
matchPattern(lhsOp->getOperand(1), m_Constant(&cstA)) && rhs) {
|
||
|
std::optional<bool> result64 = foldCmpOfMaxOrMin(
|
||
|
lhsOp, cstA.getValue(), rhs.getValue(), 64, getPred());
|
||
|
std::optional<bool> result32 =
|
||
|
foldCmpOfMaxOrMin(lhsOp, cstA.getValue().trunc(32),
|
||
|
rhs.getValue().trunc(32), 32, getPred());
|
||
|
// Fold if the 32-bit and 64-bit results are the same.
|
||
|
if (result64 && result32 && *result64 == *result32)
|
||
|
return BoolAttr::get(getContext(), *result64);
|
||
|
}
|
||
|
|
||
|
// Fold `cmp(x, x)`
|
||
|
if (getLhs() == getRhs())
|
||
|
return BoolAttr::get(getContext(), compareSameArgs(getPred()));
|
||
|
|
||
|
return {};
|
||
|
}
|
||
|
|
||
|
/// Canonicalize
|
||
|
/// `x - y cmp 0` to `x cmp y`. or `x - y cmp 0` to `x cmp y`.
|
||
|
/// `0 cmp x - y` to `y cmp x`. or `0 cmp x - y` to `y cmp x`.
|
||
|
LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
|
||
|
IntegerAttr cmpRhs;
|
||
|
IntegerAttr cmpLhs;
|
||
|
|
||
|
bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
|
||
|
cmpRhs.getValue().isZero();
|
||
|
bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
|
||
|
cmpLhs.getValue().isZero();
|
||
|
if (!rhsIsZero && !lhsIsZero)
|
||
|
return rewriter.notifyMatchFailure(op.getLoc(),
|
||
|
"cmp is not comparing something with 0");
|
||
|
SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
|
||
|
: op.getRhs().getDefiningOp<index::SubOp>();
|
||
|
if (!subOp)
|
||
|
return rewriter.notifyMatchFailure(
|
||
|
op.getLoc(), "non-zero operand is not a result of subtraction");
|
||
|
|
||
|
index::CmpOp newCmp;
|
||
|
if (rhsIsZero)
|
||
|
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
|
||
|
subOp.getLhs(), subOp.getRhs());
|
||
|
else
|
||
|
newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
|
||
|
subOp.getRhs(), subOp.getLhs());
|
||
|
rewriter.replaceOp(op, newCmp);
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ConstantOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
void ConstantOp::getAsmResultNames(
|
||
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||
|
SmallString<32> specialNameBuffer;
|
||
|
llvm::raw_svector_ostream specialName(specialNameBuffer);
|
||
|
specialName << "idx" << getValueAttr().getValue();
|
||
|
setNameFn(getResult(), specialName.str());
|
||
|
}
|
||
|
|
||
|
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
|
||
|
|
||
|
void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
|
||
|
build(b, state, b.getIndexType(), b.getIndexAttr(value));
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// BoolConstantOp
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
|
||
|
return getValueAttr();
|
||
|
}
|
||
|
|
||
|
void BoolConstantOp::getAsmResultNames(
|
||
|
function_ref<void(Value, StringRef)> setNameFn) {
|
||
|
setNameFn(getResult(), getValue() ? "true" : "false");
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// ODS-Generated Definitions
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#define GET_OP_CLASSES
|
||
|
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
|