2570 lines
94 KiB
C++
2570 lines
94 KiB
C++
//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <cassert>
|
|
#include <cstdint>
|
|
#include <functional>
|
|
#include <utility>
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/CommonFolders.h"
|
|
#include "mlir/Dialect/UB/IR/UBOps.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributeInterfaces.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
#include "llvm/ADT/APInt.h"
|
|
#include "llvm/ADT/APSInt.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::arith;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Applies `binFn` to the values of two integer attributes and wraps the
/// outcome in an IntegerAttr whose type is the type of `res`.
///
/// Both `lhs` and `rhs` must be IntegerAttrs (enforced through llvm::cast).
/// `builder` is not used by this helper.
static IntegerAttr
applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs,
                    Attribute rhs,
                    function_ref<APInt(const APInt &, const APInt &)> binFn) {
  auto lhsAttr = llvm::cast<IntegerAttr>(lhs);
  auto rhsAttr = llvm::cast<IntegerAttr>(rhs);
  APInt leftOperand = lhsAttr.getValue();
  APInt rightOperand = rhsAttr.getValue();
  return IntegerAttr::get(res.getType(), binFn(leftOperand, rightOperand));
}
|
|
|
|
/// Returns an IntegerAttr, typed like `res`, holding `lhs + rhs`.
static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  std::plus<APInt> sum;
  return applyToIntegerAttrs(builder, res, lhs, rhs, sum);
}
|
|
|
|
/// Returns an IntegerAttr, typed like `res`, holding `lhs - rhs`.
static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  std::minus<APInt> difference;
  return applyToIntegerAttrs(builder, res, lhs, rhs, difference);
}
|
|
|
|
/// Returns an IntegerAttr, typed like `res`, holding `lhs * rhs`.
static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
                                   Attribute lhs, Attribute rhs) {
  std::multiplies<APInt> product;
  return applyToIntegerAttrs(builder, res, lhs, rhs, product);
}
|
|
|
|
/// Returns the overflow-flags attribute that carries
/// `IntegerOverflowFlags::none`, created in the builder's context.
static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) {
  MLIRContext *ctx = builder.getContext();
  return IntegerOverflowFlagsAttr::get(ctx, IntegerOverflowFlags::none);
}
|
|
|
|
/// Invert an integer comparison predicate.
|
|
arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
|
|
switch (pred) {
|
|
case arith::CmpIPredicate::eq:
|
|
return arith::CmpIPredicate::ne;
|
|
case arith::CmpIPredicate::ne:
|
|
return arith::CmpIPredicate::eq;
|
|
case arith::CmpIPredicate::slt:
|
|
return arith::CmpIPredicate::sge;
|
|
case arith::CmpIPredicate::sle:
|
|
return arith::CmpIPredicate::sgt;
|
|
case arith::CmpIPredicate::sgt:
|
|
return arith::CmpIPredicate::sle;
|
|
case arith::CmpIPredicate::sge:
|
|
return arith::CmpIPredicate::slt;
|
|
case arith::CmpIPredicate::ult:
|
|
return arith::CmpIPredicate::uge;
|
|
case arith::CmpIPredicate::ule:
|
|
return arith::CmpIPredicate::ugt;
|
|
case arith::CmpIPredicate::ugt:
|
|
return arith::CmpIPredicate::ule;
|
|
case arith::CmpIPredicate::uge:
|
|
return arith::CmpIPredicate::ult;
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Attribute flavour of predicate inversion: unwraps `pred`, inverts the
/// enum value and rewraps it in the same context.
static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
  arith::CmpIPredicate inverted = invertPredicate(pred.getValue());
  return arith::CmpIPredicateAttr::get(pred.getContext(), inverted);
}
|
|
|
|
/// Returns the bit width of `type` (or of its element type, as resolved by
/// getElementTypeOrSelf) when that is an integer or float type; returns -1
/// otherwise.
static int64_t getScalarOrElementWidth(Type type) {
  Type scalarTy = getElementTypeOrSelf(type);
  if (!scalarTy.isIntOrFloat())
    return -1;
  return scalarTy.getIntOrFloatBitWidth();
}
|
|
|
|
/// Convenience overload: same as the Type overload, applied to the type of
/// `value`.
static int64_t getScalarOrElementWidth(Value value) {
  Type valueTy = value.getType();
  return getScalarOrElementWidth(valueTy);
}
|
|
|
|
/// Returns the integer value that `m_ConstantInt` extracts from `attr`, or
/// failure when the attribute does not match that pattern.
static FailureOr<APInt> getIntOrSplatIntValue(Attribute attr) {
  APInt matched;
  if (!matchPattern(attr, m_ConstantInt(&matched)))
    return failure();
  return matched;
}
|
|
|
|
/// Builds an attribute representing the boolean `value` for `type`: a dense
/// elements attribute filled with `value` when `type` is a shaped type, and a
/// plain BoolAttr otherwise.
static Attribute getBoolAttribute(Type type, bool value) {
  auto scalarAttr = BoolAttr::get(type.getContext(), value);
  if (ShapedType shaped = llvm::dyn_cast_or_null<ShapedType>(type))
    return DenseElementsAttr::get(shaped, scalarAttr);
  return scalarAttr;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd canonicalization patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
#include "ArithCanonicalization.inc"
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Common helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
|
static Type getI1SameShape(Type type) {
|
|
auto i1Type = IntegerType::get(type.getContext(), 1);
|
|
if (auto shapedType = llvm::dyn_cast<ShapedType>(type))
|
|
return shapedType.cloneWith(std::nullopt, i1Type);
|
|
if (llvm::isa<UnrankedTensorType>(type))
|
|
return UnrankedTensorType::get(i1Type);
|
|
return i1Type;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ConstantOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Suggests an SSA name for the constant's result:
///   * "cst" for any non-integer attribute;
///   * "true" / "false" for i1 integer constants;
///   * "c<value>" for other integer attributes, with "_<type>" appended when
///     the result type is an IntegerType.
void arith::ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  auto intAttr = llvm::dyn_cast<IntegerAttr>(getValue());
  if (!intAttr) {
    setNameFn(getResult(), "cst");
    return;
  }

  auto resultTy = getType();
  auto intTy = llvm::dyn_cast<IntegerType>(resultTy);

  // Sugar i1 constants with 'true' and 'false'.
  if (intTy && intTy.getWidth() == 1) {
    setNameFn(getResult(), (intAttr.getInt() ? "true" : "false"));
    return;
  }

  // Otherwise, build a complex name with the value and type.
  SmallString<32> nameStorage;
  llvm::raw_svector_ostream nameStream(nameStorage);
  nameStream << 'c' << intAttr.getValue();
  if (intTy)
    nameStream << '_' << resultTy;
  setNameFn(getResult(), nameStream.str());
}
|
|
|
|
/// TODO: disallow arith.constant to return anything other than signless integer
|
|
/// or float like.
|
|
LogicalResult arith::ConstantOp::verify() {
|
|
auto type = getType();
|
|
// The value's type must match the return type.
|
|
if (getValue().getType() != type) {
|
|
return emitOpError() << "value type " << getValue().getType()
|
|
<< " must match return type: " << type;
|
|
}
|
|
// Integer values must be signless.
|
|
if (llvm::isa<IntegerType>(type) &&
|
|
!llvm::cast<IntegerType>(type).isSignless())
|
|
return emitOpError("integer return type must be signless");
|
|
// Any float or elements attribute are acceptable.
|
|
if (!llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(getValue())) {
|
|
return emitOpError(
|
|
"value must be an integer, float, or elements attribute");
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Returns true when a constant op of type `type` can be built from `value`:
/// `value` must be a typed attribute of exactly that type, an IntegerType
/// must be signless, and the attribute must be an integer, float or elements
/// attribute.
bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
  // The value's type must be the same as the provided type.
  auto typed = llvm::dyn_cast<TypedAttr>(value);
  const bool typeMatches = typed && typed.getType() == type;
  if (!typeMatches)
    return false;

  // Integer values must be signless.
  if (auto intTy = llvm::dyn_cast<IntegerType>(type);
      intTy && !intTy.isSignless())
    return false;

  // Integer, float, and element attributes are buildable.
  return llvm::isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
}
|
|
|
|
/// Creates a constant op for `value` at `loc` when `isBuildableWith(value,
/// type)` holds; returns a null op otherwise.
ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value,
                                          Type type, Location loc) {
  if (!isBuildableWith(value, type))
    return nullptr;
  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(value));
}
|
|
|
|
/// A constant folds to the attribute it holds.
OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) {
  return getValue();
}
|
|
|
|
/// Builds an integer constant of `width` bits holding `value`. The type is
/// the one returned by `builder.getIntegerType(width)`.
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, unsigned width) {
  auto intTy = builder.getIntegerType(width);
  auto valueAttr = builder.getIntegerAttr(intTy, value);
  arith::ConstantOp::build(builder, result, intTy, valueAttr);
}
|
|
|
|
/// Builds an integer constant of the given `type` holding `value`. `type`
/// must be a signless integer type (asserted).
void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
                                 int64_t value, Type type) {
  assert(type.isSignlessInteger() &&
         "ConstantIntOp can only have signless integer type values");
  auto valueAttr = builder.getIntegerAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
|
|
|
|
/// An operation is a ConstantIntOp when it is a (non-null) arith::ConstantOp
/// whose result type is a signless integer.
bool arith::ConstantIntOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  if (!constOp)
    return false;
  return constOp.getType().isSignlessInteger();
}
|
|
|
|
/// Builds a floating-point constant of `type` holding `value`.
void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
                                   const APFloat &value, FloatType type) {
  auto valueAttr = builder.getFloatAttr(type, value);
  arith::ConstantOp::build(builder, result, type, valueAttr);
}
|
|
|
|
/// An operation is a ConstantFloatOp when it is a (non-null)
/// arith::ConstantOp whose result type is a FloatType.
bool arith::ConstantFloatOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  if (!constOp)
    return false;
  return llvm::isa<FloatType>(constOp.getType());
}
|
|
|
|
/// Builds an index-typed constant holding `value`.
void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
                                   int64_t value) {
  auto indexTy = builder.getIndexType();
  auto valueAttr = builder.getIndexAttr(value);
  arith::ConstantOp::build(builder, result, indexTy, valueAttr);
}
|
|
|
|
/// An operation is a ConstantIndexOp when it is a (non-null)
/// arith::ConstantOp whose result type is `index`.
bool arith::ConstantIndexOp::classof(Operation *op) {
  auto constOp = dyn_cast_or_null<arith::ConstantOp>(op);
  if (!constOp)
    return false;
  return constOp.getType().isIndex();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds integer addition. Tried in order:
///   addi(x, 0)          -> x
///   addi(subi(a, b), b) -> a
///   addi(b, subi(a, b)) -> a
/// and finally constant folding of both operands.
OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) {
  // addi(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  // addi(subi(a, b), b) -> a
  if (auto lhsSub = getLhs().getDefiningOp<SubIOp>();
      lhsSub && getRhs() == lhsSub.getRhs())
    return lhsSub.getLhs();

  // addi(b, subi(a, b)) -> a
  if (auto rhsSub = getRhs().getDefiningOp<SubIOp>();
      rhsSub && getLhs() == rhsSub.getRhs())
    return rhsSub.getLhs();

  auto addFn = [](APInt a, const APInt &b) { return std::move(a) + b; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), addFn);
}
|
|
|
|
/// Registers the AddIOp canonicalization patterns in `patterns`, in the same
/// order as a single combined `add` call would.
void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
      context);
  patterns.add<AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddUIExtendedOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the shape of the first result when it is a vector type, and
/// std::nullopt otherwise.
std::optional<SmallVector<int64_t, 4>>
arith::AddUIExtendedOp::getShapeForUnroll() {
  auto vecTy = llvm::dyn_cast<VectorType>(getType(0));
  if (!vecTy)
    return std::nullopt;
  return llvm::to_vector<4>(vecTy.getShape());
}
|
|
|
|
// Returns the overflow bit, assuming that `sum` is the result of unsigned
|
|
// addition of `operand` and another number.
|
|
static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) {
|
|
return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
|
|
}
|
|
|
|
/// Folds the extended unsigned addition into (sum, overflow):
///   addui_extended(x, 0)                   -> x, false
///   addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
/// Returns failure when neither applies.
LogicalResult
arith::AddUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Type overflowTy = getOverflow().getType();

  // addui_extended(x, 0) -> x, false
  if (matchPattern(getRhs(), m_Zero())) {
    auto noOverflow = Builder(getContext()).getZeroAttr(overflowTy);
    results.push_back(getLhs());
    results.push_back(noOverflow);
    return success();
  }

  // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry
  // First let the `constFoldBinaryOp` utility attempt to fold the sum of both
  // operands.
  Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
  if (!sumAttr)
    return failure();

  // The sum folded: derive the overflow bit from the sum and the first
  // (constant) operand, `lhs`.
  Type carryTy = getI1SameShape(llvm::cast<TypedAttr>(sumAttr).getType());
  Attribute overflowAttr = constFoldBinaryOp<IntegerAttr>(
      ArrayRef({sumAttr, adaptor.getLhs()}), carryTy,
      calculateUnsignedOverflow);
  if (!overflowAttr)
    return failure();

  results.push_back(sumAttr);
  results.push_back(overflowAttr);
  return success();
}
|
|
|
|
/// Registers the single AddUIExtendedOp canonicalization pattern,
/// `AddUIExtendedToAddI`, in `patterns`.
void arith::AddUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<AddUIExtendedToAddI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds integer subtraction. Tried in order:
///   subi(x, x)          -> 0
///   subi(x, 0)          -> x
///   subi(addi(a, b), b) -> a
///   subi(addi(a, b), a) -> b
/// and finally constant folding of both operands.
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
  // subi(x,x) -> 0
  const bool sameOperands = getOperand(0) == getOperand(1);
  if (sameOperands)
    return Builder(getContext()).getZeroAttr(getType());

  // subi(x,0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  auto lhsAdd = getLhs().getDefiningOp<AddIOp>();
  // subi(addi(a, b), b) -> a
  if (lhsAdd && getRhs() == lhsAdd.getRhs())
    return lhsAdd.getLhs();
  // subi(addi(a, b), a) -> b
  if (lhsAdd && getRhs() == lhsAdd.getLhs())
    return lhsAdd.getRhs();

  auto subFn = [](APInt a, const APInt &b) { return std::move(a) - b; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), subFn);
}
|
|
|
|
/// Registers the SubIOp canonicalization patterns in `patterns`, in the same
/// order as a single combined `add` call would.
void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<SubIRHSAddConstant, SubILHSAddConstant>(context);
  patterns.add<SubIRHSSubConstantRHS, SubIRHSSubConstantLHS,
               SubILHSSubConstantRHS, SubILHSSubConstantLHS>(context);
  patterns.add<SubISubILHSRHSLHS>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds integer multiplication. Tried in order:
///   muli(x, 0) -> 0   (returns the rhs value itself)
///   muli(x, 1) -> x
/// and finally constant folding of both operands.
OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) {
  Attribute rhsAttr = adaptor.getRhs();

  // muli(x, 0) -> 0
  if (matchPattern(rhsAttr, m_Zero()))
    return getRhs();

  // muli(x, 1) -> x
  if (matchPattern(rhsAttr, m_One()))
    return getLhs();

  // TODO: Handle the overflow case.

  // default folder
  auto mulFn = [](const APInt &a, const APInt &b) { return a * b; };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), mulFn);
}
|
|
|
|
/// Registers the single MulIOp canonicalization pattern, `MulIMulIConstant`,
/// in `patterns`.
void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulIMulIConstant>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulSIExtendedOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the shape of the first result when it is a vector type, and
/// std::nullopt otherwise.
std::optional<SmallVector<int64_t, 4>>
arith::MulSIExtendedOp::getShapeForUnroll() {
  auto vecTy = llvm::dyn_cast<VectorType>(getType(0));
  if (!vecTy)
    return std::nullopt;
  return llvm::to_vector<4>(vecTy.getShape());
}
|
|
|
|
/// Folds the extended signed multiplication into (low, high):
///   mulsi_extended(x, 0)         -> 0, 0
///   mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
/// Returns failure when neither applies.
LogicalResult
arith::MulSIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  // mulsi_extended(x, 0) -> 0, 0
  Attribute rhsAttr = adaptor.getRhs();
  if (matchPattern(rhsAttr, m_Zero())) {
    // The zero constant itself serves as both results.
    results.push_back(rhsAttr);
    results.push_back(rhsAttr);
    return success();
  }

  // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high
  Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
  if (!lowAttr)
    return failure();

  // Invoke the constant fold helper again to calculate the 'high' result:
  // sign-extend both operands to twice the width, multiply, and keep the
  // upper half.
  auto highHalf = [](const APInt &a, const APInt &b) {
    unsigned width = a.getBitWidth();
    APInt wideProduct = a.sext(width * 2) * b.sext(width * 2);
    return wideProduct.extractBits(width, width);
  };
  Attribute highAttr =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), highHalf);
  assert(highAttr && "Unexpected constant-folding failure");

  results.push_back(lowAttr);
  results.push_back(highAttr);
  return success();
}
|
|
|
|
/// Registers the MulSIExtendedOp canonicalization patterns in `patterns`, in
/// the same order as a single combined `add` call would.
void arith::MulSIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulSIExtendedToMulI>(context);
  patterns.add<MulSIExtendedRHSOne>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulUIExtendedOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the shape of the first result when it is a vector type, and
/// std::nullopt otherwise.
std::optional<SmallVector<int64_t, 4>>
arith::MulUIExtendedOp::getShapeForUnroll() {
  auto vecTy = llvm::dyn_cast<VectorType>(getType(0));
  if (!vecTy)
    return std::nullopt;
  return llvm::to_vector<4>(vecTy.getShape());
}
|
|
|
|
/// Folds the extended unsigned multiplication into (low, high):
///   mului_extended(x, 0)         -> 0, 0
///   mului_extended(x, 1)         -> x, 0
///   mului_extended(cst_a, cst_b) -> cst_low, cst_high
/// Returns failure when none applies.
LogicalResult
arith::MulUIExtendedOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
  Attribute rhsAttr = adaptor.getRhs();

  // mului_extended(x, 0) -> 0, 0
  if (matchPattern(rhsAttr, m_Zero())) {
    // The zero constant itself serves as both results.
    results.push_back(rhsAttr);
    results.push_back(rhsAttr);
    return success();
  }

  // mului_extended(x, 1) -> x, 0
  if (matchPattern(rhsAttr, m_One())) {
    Attribute highZero = Builder(getContext()).getZeroAttr(getLhs().getType());
    results.push_back(getLhs());
    results.push_back(highZero);
    return success();
  }

  // mului_extended(cst_a, cst_b) -> cst_low, cst_high
  Attribute lowAttr = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](const APInt &a, const APInt &b) { return a * b; });
  if (!lowAttr)
    return failure();

  // Invoke the constant fold helper again to calculate the 'high' result:
  // zero-extend both operands to twice the width, multiply, and keep the
  // upper half.
  auto highHalf = [](const APInt &a, const APInt &b) {
    unsigned width = a.getBitWidth();
    APInt wideProduct = a.zext(width * 2) * b.zext(width * 2);
    return wideProduct.extractBits(width, width);
  };
  Attribute highAttr =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), highHalf);
  assert(highAttr && "Unexpected constant-folding failure");

  results.push_back(lowAttr);
  results.push_back(highAttr);
  return success();
}
|
|
|
|
/// Registers the single MulUIExtendedOp canonicalization pattern,
/// `MulUIExtendedToMulI`, in `patterns`.
void arith::MulUIExtendedOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<MulUIExtendedToMulI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds unsigned division:
///   divui(x, 1) -> x
/// and constant folding of both operands, which is abandoned as soon as any
/// divisor is zero.
OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
  // divui (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would require a division by zero. Once a zero divisor
  // has been seen, the remaining elements are passed through untouched and
  // the whole result is discarded below.
  bool sawZeroDivisor = false;
  auto divide = [&sawZeroDivisor](APInt a, const APInt &b) {
    if (sawZeroDivisor || b.isZero()) {
      sawZeroDivisor = true;
      return a;
    }
    return a.udiv(b);
  };
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), divide);

  return sawZeroDivisor ? Attribute() : result;
}
|
|
|
|
/// The division is speculatable only when the divisor matches `m_NonZero`;
/// X / 0 is undefined behavior.
Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
  if (matchPattern(getRhs(), m_NonZero()))
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds signed division:
///   divsi(x, 1) -> x
/// and constant folding of both operands, which is abandoned when a divisor
/// is zero or the division overflows.
OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
  // divsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  // Once the flag is set, remaining elements are passed through untouched
  // and the whole result is discarded below.
  bool overflowOrDiv0 = false;
  auto divide = [&overflowOrDiv0](APInt a, const APInt &b) {
    if (overflowOrDiv0 || b.isZero()) {
      overflowOrDiv0 = true;
      return a;
    }
    return a.sdiv_ov(b, overflowOrDiv0);
  };
  auto result = constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), divide);

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
/// The division is speculatable only when the divisor is a known constant
/// that is neither zero (X / 0 => UB) nor all-ones, i.e. -1
/// (INT_MIN / -1 => UB).
Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
  APInt divisor;
  // Unknown divisor: assume it may trigger UB.
  if (!matchPattern(getRhs(), m_ConstantInt(&divisor)))
    return Speculation::NotSpeculatable;

  const bool mayHaveUB = divisor.isZero() || divisor.isAllOnes();
  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Ceil and floor division folding helpers
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Computes the signed ceiling division of `a` by `b` as (a - 1) / b + 1.
///
/// The formula is meant for positive inputs; the callers visible in this file
/// only reach it after ruling out zero operands.
///
/// `overflow` is set to true when any of the three intermediate operations
/// (subtract, divide, add) overflows. A flag that is already true on entry is
/// left true, so an overflow recorded by the caller is never cleared here.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  // APInt's *_ov methods assign their flag argument instead of accumulating
  // into it. Sharing one flag across the chain would let a later step erase
  // an earlier overflow, so every step gets its own flag.
  bool subOverflow = false;
  bool divOverflow = false;
  bool addOverflow = false;

  // Returns (a-1)/b + 1
  APInt one(a.getBitWidth(), 1);
  APInt val = a.ssub_ov(one, subOverflow).sdiv_ov(b, divOverflow);
  APInt res = val.sadd_ov(one, addOverflow);

  overflow = overflow || subOverflow || divOverflow || addOverflow;
  return res;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CeilDivUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds unsigned ceiling division:
///   ceildivui(x, 1) -> x
/// and constant folding of both operands, which is abandoned when a divisor
/// is zero or the final increment overflows.
OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
  // ceildivui (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  bool overflowOrDiv0 = false;
  auto ceilDivide = [&overflowOrDiv0](APInt a, const APInt &b) {
    if (overflowOrDiv0 || b.isZero()) {
      overflowOrDiv0 = true;
      return a;
    }
    APInt quotient = a.udiv(b);
    // Exact division: the truncated quotient already is the ceiling.
    if (a.urem(b).isZero())
      return quotient;
    // Inexact division: round up by one.
    APInt one(a.getBitWidth(), 1, true);
    return quotient.uadd_ov(one, overflowOrDiv0);
  };
  auto result =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), ceilDivide);

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
/// The division is speculatable only when the divisor matches `m_NonZero`;
/// X / 0 is undefined behavior.
Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
  if (matchPattern(getRhs(), m_NonZero()))
    return Speculation::Speculatable;
  return Speculation::NotSpeculatable;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CeilDivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds signed ceiling division:
///   ceildivsi(x, 1) -> x
/// and constant folding of both operands. Constant folding is abandoned when
/// a divisor is zero or when any intermediate arithmetic step overflows.
///
/// Operands equal to the minimum signed value are not folded through the
/// negating branches, because negating them overflows.
OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
  // ceildivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        if (!a)
          return a;
        // After this point we know that neither a or b are zero.
        unsigned bits = a.getBitWidth();
        APInt zero = APInt::getZero(bits);
        bool aGtZero = a.sgt(zero);
        bool bGtZero = b.sgt(zero);
        if (aGtZero && bGtZero) {
          // Both positive, return ceil(a, b).
          return signedCeilNonnegInputs(a, b, overflowOrDiv0);
        }

        // APInt's *_ov methods assign their flag argument instead of
        // accumulating into it. Each intermediate step therefore records
        // into its own flag; reusing a single flag would let a later step
        // erase an earlier overflow (e.g. the negation of the minimum signed
        // value) and produce a wrong folded constant.
        bool overflowNegA = false;
        bool overflowNegB = false;
        bool overflowDiv = false;
        bool overflowNegRes = false;
        if (!aGtZero && !bGtZero) {
          // Both negative, return ceil(-a, -b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt posB = zero.ssub_ov(b, overflowNegB);
          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
          overflowOrDiv0 = overflowNegA || overflowNegB || overflowDiv;
          return res;
        }
        if (!aGtZero && bGtZero) {
          // A is negative, b is positive, return - ( -a / b).
          APInt posA = zero.ssub_ov(a, overflowNegA);
          APInt div = posA.sdiv_ov(b, overflowDiv);
          APInt res = zero.ssub_ov(div, overflowNegRes);
          overflowOrDiv0 = overflowNegA || overflowDiv || overflowNegRes;
          return res;
        }
        // A is positive, b is negative, return - (a / -b).
        APInt posB = zero.ssub_ov(b, overflowNegB);
        APInt div = a.sdiv_ov(posB, overflowDiv);
        APInt res = zero.ssub_ov(div, overflowNegRes);
        overflowOrDiv0 = overflowNegB || overflowDiv || overflowNegRes;
        return res;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
/// The division is speculatable only when the divisor is a known constant
/// that is neither zero (X / 0 => UB) nor all-ones, i.e. -1
/// (INT_MIN / -1 => UB).
Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
  APInt divisor;
  // Unknown divisor: assume it may trigger UB.
  if (!matchPattern(getRhs(), m_ConstantInt(&divisor)))
    return Speculation::NotSpeculatable;

  const bool mayHaveUB = divisor.isZero() || divisor.isAllOnes();
  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FloorDivSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds signed floor division:
///   floordivsi(x, 1) -> x
/// and constant folding of both operands. Constant folding is abandoned when
/// a divisor is zero or when the division itself overflows (minimum signed
/// value divided by -1).
OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
  // floordivsi (x, 1) -> x.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return getLhs();

  // Don't fold if it would overflow or if it requires a division by zero.
  bool overflowOrDiv0 = false;
  auto result = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](APInt a, const APInt &b) {
        if (overflowOrDiv0 || !b) {
          overflowOrDiv0 = true;
          return a;
        }
        // Start from the truncating (round-toward-zero) quotient. This is the
        // only step that can overflow, so the flag it reports is final and
        // cannot be overwritten by a later operation.
        APInt quotient = a.sdiv_ov(b, overflowOrDiv0);
        if (overflowOrDiv0)
          return a;
        // Truncation and floor differ only when the division is inexact and
        // the exact quotient is negative, i.e. the operands have opposite
        // signs; the floor is then one below the truncated quotient.
        if (quotient * b != a && a.isNegative() != b.isNegative())
          return quotient - 1;
        return quotient;
      });

  return overflowOrDiv0 ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds unsigned remainder:
///   remui(x, 1) -> 0
/// and constant folding of both operands, which is abandoned as soon as any
/// divisor is zero.
OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) {
  // remui (x, 1) -> 0.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Don't fold if it would require a division by zero.
  bool sawZeroDivisor = false;
  auto remainder = [&sawZeroDivisor](APInt a, const APInt &b) {
    if (sawZeroDivisor || b.isZero()) {
      sawZeroDivisor = true;
      return a;
    }
    return a.urem(b);
  };
  auto result =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), remainder);

  return sawZeroDivisor ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds signed remainder:
///   remsi(x, 1) -> 0
/// and constant folding of both operands, which is abandoned as soon as any
/// divisor is zero.
OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) {
  // remsi (x, 1) -> 0.
  if (matchPattern(adaptor.getRhs(), m_One()))
    return Builder(getContext()).getZeroAttr(getType());

  // Don't fold if it would require a division by zero.
  bool sawZeroDivisor = false;
  auto remainder = [&sawZeroDivisor](APInt a, const APInt &b) {
    if (sawZeroDivisor || b.isZero()) {
      sawZeroDivisor = true;
      return a;
    }
    return a.srem(b);
  };
  auto result =
      constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), remainder);

  return sawZeroDivisor ? Attribute() : result;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AndIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Fold `and(a, and(a, b))` to `and(a, b)`
|
|
static Value foldAndIofAndI(arith::AndIOp op) {
|
|
for (bool reversePrev : {false, true}) {
|
|
auto prev = (reversePrev ? op.getRhs() : op.getLhs())
|
|
.getDefiningOp<arith::AndIOp>();
|
|
if (!prev)
|
|
continue;
|
|
|
|
Value other = (reversePrev ? op.getLhs() : op.getRhs());
|
|
if (other != prev.getLhs() && other != prev.getRhs())
|
|
continue;
|
|
|
|
return prev.getResult();
|
|
}
|
|
return {};
|
|
}
|
|
|
|
/// Folds bitwise and. Tried in order:
///   and(x, 0)         -> 0   (returns the rhs value itself)
///   and(x, allOnes)   -> x
///   and(x, not(x))    -> 0
///   and(not(x), x)    -> 0
///   and(a, and(a, b)) -> and(a, b)
/// and finally constant folding of both operands. Here `not(x)` means
/// `xor(x, allOnes)` with `x` as the xor's first operand.
OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) {
  /// and(x, 0) -> 0
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getRhs();

  /// and(x, allOnes) -> x
  if (APInt rhsConst;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsConst)) &&
      rhsConst.isAllOnes())
    return getLhs();

  // True when `candidate` is `xor(x, allOnes)`.
  auto isNotOf = [](Value candidate, Value x) {
    APInt mask;
    return matchPattern(candidate, m_Op<XOrIOp>(matchers::m_Val(x),
                                                m_ConstantInt(&mask))) &&
           mask.isAllOnes();
  };

  /// and(x, not(x)) -> 0
  if (isNotOf(getRhs(), getLhs()))
    return Builder(getContext()).getZeroAttr(getType());
  /// and(not(x), x) -> 0
  if (isNotOf(getLhs(), getRhs()))
    return Builder(getContext()).getZeroAttr(getType());

  /// and(a, and(a, b)) -> and(a, b)
  if (Value reused = foldAndIofAndI(*this))
    return reused;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) & b; });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds bitwise or. Tried in order:
///   or(x, 0)               -> x
///   or(x, allOnes)         -> allOnes (the rhs constant attribute)
///   or(x, xor(x, allOnes)) -> allOnes (the xor's rhs value)
///   or(xor(x, allOnes), x) -> allOnes (the xor's rhs value)
/// and finally constant folding of both operands.
OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) {
  if (APInt rhsConst;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsConst))) {
    /// or(x, 0) -> x
    if (rhsConst.isZero())
      return getLhs();
    /// or(x, <all ones>) -> <all ones>
    if (rhsConst.isAllOnes())
      return adaptor.getRhs();
  }

  // When `candidate` is `xor(x, allOnes)`, returns the xor's all-ones
  // operand; returns a null Value otherwise.
  auto allOnesOfNot = [](Value candidate, Value x) -> Value {
    APInt mask;
    if (matchPattern(candidate, m_Op<XOrIOp>(matchers::m_Val(x),
                                             m_ConstantInt(&mask))) &&
        mask.isAllOnes())
      return candidate.getDefiningOp<XOrIOp>().getRhs();
    return {};
  };

  /// or(x, xor(x, <all ones>)) -> <all ones>
  if (Value ones = allOnesOfNot(getRhs(), getLhs()))
    return ones;
  /// or(xor(x, <all ones>), x) -> <all ones>
  if (Value ones = allOnesOfNot(getLhs(), getRhs()))
    return ones;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) | b; });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XOrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds bitwise xor. Tried in order:
///   xor(x, 0)         -> x
///   xor(x, x)         -> 0
///   xor(xor(x, a), a) -> x      xor(xor(a, x), a) -> x
///   xor(a, xor(x, a)) -> x      xor(a, xor(a, x)) -> x
/// and finally constant folding of both operands.
OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
  /// xor(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  /// xor(x, x) -> 0
  if (getLhs() == getRhs())
    return Builder(getContext()).getZeroAttr(getType());

  // When `nested` is an xor that has `other` as one of its operands, the two
  // occurrences of `other` cancel and the remaining operand of the nested xor
  // is returned (its rhs is compared first). Returns a null Value otherwise.
  auto cancelAgainst = [](Value nested, Value other) -> Value {
    auto inner = nested.getDefiningOp<arith::XOrIOp>();
    if (!inner)
      return {};
    if (inner.getRhs() == other)
      return inner.getLhs();
    if (inner.getLhs() == other)
      return inner.getRhs();
    return {};
  };

  /// xor(xor(x, a), a) -> x
  /// xor(xor(a, x), a) -> x
  if (Value remaining = cancelAgainst(getLhs(), getRhs()))
    return remaining;
  /// xor(a, xor(x, a)) -> x
  /// xor(a, xor(a, x)) -> x
  if (Value remaining = cancelAgainst(getRhs(), getLhs()))
    return remaining;

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) ^ b; });
}
|
|
|
|
/// Registers the XOrIOp canonicalization patterns in `patterns`, in the same
/// order as a single combined `add` call would.
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<XOrINotCmpI>(context);
  patterns.add<XOrIOfExtUI, XOrIOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// NegFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds floating-point negation:
///   negf(negf(x)) -> x
/// and constant folding of the operand.
OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) {
  /// negf(negf(x)) -> x
  auto innerNeg = this->getOperand().getDefiningOp<arith::NegFOp>();
  if (innerNeg)
    return innerNeg.getOperand();

  auto negate = [](const APFloat &a) { return -a; };
  return constFoldUnaryOp<FloatAttr>(adaptor.getOperands(), negate);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AddFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.addf`: adding a constant negative zero on the right is the
/// identity; otherwise constant operands are summed at compile time.
OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) {
  // addf(x, -0) -> x
  const bool rhsIsNegZero = matchPattern(adaptor.getRhs(), m_NegZeroFloat());
  if (rhsIsNegZero)
    return getLhs();

  auto addConstants = [](const APFloat &a, const APFloat &b) { return a + b; };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), addConstants);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SubFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.subf`: subtracting a constant positive zero is the identity;
/// otherwise constant operands are subtracted at compile time.
OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x
  const bool rhsIsPosZero = matchPattern(adaptor.getRhs(), m_PosZeroFloat());
  if (rhsIsPosZero)
    return getLhs();

  auto subConstants = [](const APFloat &a, const APFloat &b) { return a - b; };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), subConstants);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaximumFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.maximumf`: identical operands and a constant -inf right-hand
/// side are identities; constant operands are folded with `llvm::maximum`.
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // maximumf(x,x) -> x
  if (lhs == rhs)
    return rhs;

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return lhs;

  auto foldMax = [](const APFloat &a, const APFloat &b) {
    return llvm::maximum(a, b);
  };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), foldMax);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxNumFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.maxnumf`: identical operands and a constant -inf right-hand
/// side are treated as identities; constant operands are folded with
/// `llvm::maxnum`.
OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  // maxnumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maxnumf(x, -inf) -> x
  // NOTE(review): if maxnumf is specified to return the non-NaN operand, this
  // identity looks unsound for x == NaN (the op would yield -inf, the fold
  // yields NaN) -- confirm against the op definition before relying on it.
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  // Use the "num" flavour here: `llvm::maximum` propagates NaN and therefore
  // computes arith.maximumf, not arith.maxnumf. This mirrors MinNumFOp, which
  // folds with `llvm::minnum`.
  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
}
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.maxsi`: identical operands, a saturating constant (signed max)
/// and a neutral constant (signed min) are resolved directly; otherwise
/// constant operands are folded with a signed maximum.
OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
  // maxsi(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  APInt rhsConst;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsConst))) {
    // maxsi(x,MAX_INT) -> MAX_INT
    if (rhsConst.isMaxSignedValue())
      return getRhs();
    // maxsi(x, MIN_INT) -> x
    if (rhsConst.isMinSignedValue())
      return getLhs();
  }

  auto signedMax = [](const APInt &a, const APInt &b) {
    return llvm::APIntOps::smax(a, b);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), signedMax);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MaxUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.maxui`: identical operands, a saturating constant (unsigned
/// max) and a neutral constant (unsigned min) are resolved directly; otherwise
/// constant operands are folded with an unsigned maximum.
OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) {
  // maxui(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  APInt rhsConst;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsConst))) {
    // maxui(x,MAX_INT) -> MAX_INT
    if (rhsConst.isMaxValue())
      return getRhs();
    // maxui(x, MIN_INT) -> x
    if (rhsConst.isMinValue())
      return getLhs();
  }

  auto unsignedMax = [](const APInt &a, const APInt &b) {
    return llvm::APIntOps::umax(a, b);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), unsignedMax);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinimumFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.minimumf`: identical operands and a constant +inf right-hand
/// side are identities; constant operands are folded with `llvm::minimum`.
OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minimumf(x,x) -> x
  if (lhs == rhs)
    return rhs;

  // minimumf(x, +inf) -> x
  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    return lhs;

  auto foldMin = [](const APFloat &a, const APFloat &b) {
    return llvm::minimum(a, b);
  };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), foldMin);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinNumFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.minnumf`: identical operands and a constant +inf right-hand
/// side are treated as identities; constant operands are folded with
/// `llvm::minnum`.
OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // minnumf(x,x) -> x
  if (lhs == rhs)
    return rhs;

  // minnumf(x, +inf) -> x
  // NOTE(review): if minnumf is specified to return the non-NaN operand, this
  // identity looks unsound for x == NaN -- confirm against the op definition.
  if (matchPattern(adaptor.getRhs(), m_PosInfFloat()))
    return lhs;

  auto foldMinNum = [](const APFloat &a, const APFloat &b) {
    return llvm::minnum(a, b);
  };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), foldMinNum);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.minsi`: identical operands, a saturating constant (signed min)
/// and a neutral constant (signed max) are resolved directly; otherwise
/// constant operands are folded with a signed minimum.
OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) {
  // minsi(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  APInt rhsConst;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsConst))) {
    // minsi(x,MIN_INT) -> MIN_INT
    if (rhsConst.isMinSignedValue())
      return getRhs();
    // minsi(x, MAX_INT) -> x
    if (rhsConst.isMaxSignedValue())
      return getLhs();
  }

  auto signedMin = [](const APInt &a, const APInt &b) {
    return llvm::APIntOps::smin(a, b);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), signedMin);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MinUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.minui`: identical operands, a saturating constant (unsigned
/// min) and a neutral constant (unsigned max) are resolved directly; otherwise
/// constant operands are folded with an unsigned minimum.
OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) {
  // minui(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  APInt rhsConst;
  if (matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsConst))) {
    // minui(x,MIN_INT) -> MIN_INT
    if (rhsConst.isMinValue())
      return getRhs();
    // minui(x, MAX_INT) -> x
    if (rhsConst.isMaxValue())
      return getLhs();
  }

  auto unsignedMin = [](const APInt &a, const APInt &b) {
    return llvm::APIntOps::umin(a, b);
  };
  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(), unsignedMin);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// MulFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.mulf`: multiplying by a constant one on the right is the
/// identity; otherwise constant operands are multiplied at compile time.
OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
  // mulf(x, 1) -> x
  const bool rhsIsOne = matchPattern(adaptor.getRhs(), m_OneFloat());
  if (rhsIsOne)
    return getLhs();

  auto mulConstants = [](const APFloat &a, const APFloat &b) { return a * b; };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), mulConstants);
}
|
|
|
|
/// Registers the canonicalization pattern for `arith.mulf` (MulFOfNegF,
/// defined outside this section).
void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<MulFOfNegF>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DivFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.divf`: dividing by a constant one is the identity; otherwise
/// constant operands are divided at compile time.
OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) {
  // divf(x, 1) -> x
  const bool rhsIsOne = matchPattern(adaptor.getRhs(), m_OneFloat());
  if (rhsIsOne)
    return getLhs();

  auto divConstants = [](const APFloat &a, const APFloat &b) { return a / b; };
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(), divConstants);
}
|
|
|
|
/// Registers the canonicalization pattern for `arith.divf` (DivFOfNegF,
/// defined outside this section).
void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<DivFOfNegF>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// RemFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Constant-folds `arith.remf` when both operands are constants.
OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) {
  return constFoldBinaryOp<FloatAttr>(adaptor.getOperands(),
                                      [](const APFloat &a, const APFloat &b) {
                                        APFloat result(a);
                                        // Use APFloat::mod (C fmod semantics:
                                        // the result carries the sign of the
                                        // dividend) rather than
                                        // APFloat::remainder, which rounds the
                                        // quotient to nearest and can flip the
                                        // sign (e.g. 5 rem 3 == -1).
                                        (void)result.mod(b);
                                        return result;
                                      });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions for verifying cast ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Tag type carrying a pack of types: a pointer to a tuple of them, so a
/// value-initialized instance is a null pointer used only for deduction.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
|
|
|
|
/// Returns a non-null type only if the provided type is one of the allowed
|
|
/// types or one of the allowed shaped types of the allowed types. Returns the
|
|
/// element type if a valid shaped type is provided.
|
|
template <typename... ShapedTypes, typename... ElementTypes>
|
|
static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
|
|
type_list<ElementTypes...>) {
|
|
if (llvm::isa<ShapedType>(type) && !llvm::isa<ShapedTypes...>(type))
|
|
return {};
|
|
|
|
auto underlyingType = getElementTypeOrSelf(type);
|
|
if (!llvm::isa<ElementTypes...>(underlyingType))
|
|
return {};
|
|
|
|
return underlyingType;
|
|
}
|
|
|
|
/// Get allowed underlying types for vectors and tensors.
|
|
template <typename... ElementTypes>
|
|
static Type getTypeIfLike(Type type) {
|
|
return getUnderlyingType(type, type_list<VectorType, TensorType>(),
|
|
type_list<ElementTypes...>());
|
|
}
|
|
|
|
/// Get allowed underlying types for vectors, tensors, and memrefs.
|
|
template <typename... ElementTypes>
|
|
static Type getTypeIfLikeOrMemRef(Type type) {
|
|
return getUnderlyingType(type,
|
|
type_list<VectorType, TensorType, MemRefType>(),
|
|
type_list<ElementTypes...>());
|
|
}
|
|
|
|
/// Return false if both types are ranked tensor with mismatching encoding.
|
|
static bool hasSameEncoding(Type typeA, Type typeB) {
|
|
auto rankedTensorA = dyn_cast<RankedTensorType>(typeA);
|
|
auto rankedTensorB = dyn_cast<RankedTensorType>(typeB);
|
|
if (!rankedTensorA || !rankedTensorB)
|
|
return true;
|
|
return rankedTensorA.getEncoding() == rankedTensorB.getEncoding();
|
|
}
|
|
|
|
/// Common precondition for cast ops: exactly one input and one output, with
/// matching ranked-tensor encodings and compatible shapes.
static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;

  Type inputType = inputs.front();
  Type outputType = outputs.front();
  return hasSameEncoding(inputType, outputType) &&
         succeeded(verifyCompatibleShapes(inputType, outputType));
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers for integer and floating point extension/truncation ops
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Extend ops can only extend to a wider type.
|
|
template <typename ValType, typename Op>
|
|
static LogicalResult verifyExtOp(Op op) {
|
|
Type srcType = getElementTypeOrSelf(op.getIn().getType());
|
|
Type dstType = getElementTypeOrSelf(op.getType());
|
|
|
|
if (llvm::cast<ValType>(srcType).getWidth() >=
|
|
llvm::cast<ValType>(dstType).getWidth())
|
|
return op.emitError("result type ")
|
|
<< dstType << " must be wider than operand type " << srcType;
|
|
|
|
return success();
|
|
}
|
|
|
|
// Truncate ops can only truncate to a shorter type.
|
|
template <typename ValType, typename Op>
|
|
static LogicalResult verifyTruncateOp(Op op) {
|
|
Type srcType = getElementTypeOrSelf(op.getIn().getType());
|
|
Type dstType = getElementTypeOrSelf(op.getType());
|
|
|
|
if (llvm::cast<ValType>(srcType).getWidth() <=
|
|
llvm::cast<ValType>(dstType).getWidth())
|
|
return op.emitError("result type ")
|
|
<< dstType << " must be shorter than operand type " << srcType;
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Validate a cast that changes the width of a type.
|
|
template <template <typename> class WidthComparator, typename... ElementTypes>
|
|
static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
|
|
if (!areValidCastInputsAndOutputs(inputs, outputs))
|
|
return false;
|
|
|
|
auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
|
|
auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
|
|
if (!srcType || !dstType)
|
|
return false;
|
|
|
|
return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
|
|
srcType.getIntOrFloatBitWidth());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.extui`: a chain of zero-extensions is collapsed by rewiring
/// this op to the inner source in place; a constant operand is zero-extended
/// at compile time.
OpFoldResult arith::ExtUIOp::fold(FoldAdaptor adaptor) {
  // extui(extui(x)) -> extui(x), done by updating this op's operand.
  if (auto innerExt = getIn().getDefiningOp<ExtUIOp>()) {
    getInMutable().assign(innerExt.getIn());
    return getResult();
  }

  const unsigned dstWidth =
      llvm::cast<IntegerType>(getElementTypeOrSelf(getType())).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [dstWidth](const APInt &value, bool & /*castStatus*/) {
        return value.zext(dstWidth);
      });
}
|
|
|
|
/// `arith.extui` is cast-compatible when both sides are integer-like and the
/// result bit width is strictly greater than the operand's.
bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool widensInteger =
      checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
  return widensInteger;
}
|
|
|
|
/// Verifies that the result integer type is wider than the operand type.
LogicalResult arith::ExtUIOp::verify() {
  ExtUIOp self = *this;
  return verifyExtOp<IntegerType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.extsi`: a chain of sign-extensions is collapsed by rewiring
/// this op to the inner source in place; a constant operand is sign-extended
/// at compile time.
OpFoldResult arith::ExtSIOp::fold(FoldAdaptor adaptor) {
  // extsi(extsi(x)) -> extsi(x), done by updating this op's operand.
  if (auto innerExt = getIn().getDefiningOp<ExtSIOp>()) {
    getInMutable().assign(innerExt.getIn());
    return getResult();
  }

  const unsigned dstWidth =
      llvm::cast<IntegerType>(getElementTypeOrSelf(getType())).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [dstWidth](const APInt &value, bool & /*castStatus*/) {
        return value.sext(dstWidth);
      });
}
|
|
|
|
/// `arith.extsi` is cast-compatible when both sides are integer-like and the
/// result bit width is strictly greater than the operand's.
bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool widensInteger =
      checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
  return widensInteger;
}
|
|
|
|
/// Registers the canonicalization pattern for `arith.extsi` (ExtSIOfExtUI,
/// defined outside this section).
void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
  patterns.add<ExtSIOfExtUI>(context);
}
|
|
|
|
/// Verifies that the result integer type is wider than the operand type.
LogicalResult arith::ExtSIOp::verify() {
  ExtSIOp self = *this;
  return verifyExtOp<IntegerType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ExtFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Always fold extension of FP constants.
|
|
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
|
|
auto constOperand = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getIn());
|
|
if (!constOperand)
|
|
return {};
|
|
|
|
// Convert to target type via 'double'.
|
|
return FloatAttr::get(getType(), constOperand.getValue().convertToDouble());
|
|
}
|
|
|
|
/// `arith.extf` is cast-compatible when both sides are float-like and the
/// result bit width is strictly greater than the operand's.
bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool widensFloat =
      checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
  return widensFloat;
}
|
|
|
|
/// Verifies that the result float type is wider than the operand type.
LogicalResult arith::ExtFOp::verify() {
  ExtFOp self = *this;
  return verifyExtOp<FloatType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TruncIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.trunci`: looks through a producing extension or truncation
/// (rewiring this op in place where needed) and truncates constant operands
/// at compile time.
OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
  Operation *producer = getOperand().getDefiningOp();

  if (producer && llvm::isa<arith::ExtUIOp, arith::ExtSIOp>(producer)) {
    Value src = producer->getOperand(0);
    Type srcElemType = getElementTypeOrSelf(src.getType());
    Type dstElemType = getElementTypeOrSelf(getType());
    const unsigned srcWidth = llvm::cast<IntegerType>(srcElemType).getWidth();
    const unsigned dstWidth = llvm::cast<IntegerType>(dstElemType).getWidth();

    // trunci(zexti(a)) -> trunci(a)
    // trunci(sexti(a)) -> trunci(a)
    // The extension is irrelevant when we truncate below the original width.
    if (srcWidth > dstWidth) {
      setOperand(src);
      return getResult();
    }

    // trunci(zexti(a)) -> a
    // trunci(sexti(a)) -> a
    if (srcElemType == dstElemType)
      return src;
  }

  // trunci(trunci(a)) -> trunci(a)
  if (producer && llvm::isa<arith::TruncIOp>(producer)) {
    setOperand(producer->getOperand(0));
    return getResult();
  }

  const unsigned resultWidth =
      llvm::cast<IntegerType>(getElementTypeOrSelf(getType())).getWidth();
  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultWidth](const APInt &value, bool & /*castStatus*/) {
        return value.trunc(resultWidth);
      });
}
|
|
|
|
/// `arith.trunci` is cast-compatible when both sides are integer-like and the
/// result bit width is strictly less than the operand's.
bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool narrowsInteger =
      checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
  return narrowsInteger;
}
|
|
|
|
/// Registers the canonicalization patterns for `arith.trunci` (all defined
/// outside this section).
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI,
               TruncIShrSIToTrunciShrUI, TruncIShrUIMulIToMulSIExtended,
               TruncIShrUIMulIToMulUIExtended>(context);
}
|
|
|
|
/// Verifies that the result integer type is narrower than the operand type.
LogicalResult arith::TruncIOp::verify() {
  TruncIOp self = *this;
  return verifyTruncateOp<IntegerType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TruncFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Perform safe const propagation for truncf, i.e. only propagate if FP value
|
|
/// can be represented without precision loss or rounding.
|
|
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
|
|
auto constOperand = adaptor.getIn();
|
|
if (!constOperand || !llvm::isa<FloatAttr>(constOperand))
|
|
return {};
|
|
|
|
// Convert to target type via 'double'.
|
|
double sourceValue =
|
|
llvm::dyn_cast<FloatAttr>(constOperand).getValue().convertToDouble();
|
|
auto targetAttr = FloatAttr::get(getType(), sourceValue);
|
|
|
|
// Propagate if constant's value does not change after truncation.
|
|
if (sourceValue == targetAttr.getValue().convertToDouble())
|
|
return targetAttr;
|
|
|
|
return {};
|
|
}
|
|
|
|
/// `arith.truncf` is cast-compatible when both sides are float-like and the
/// result bit width is strictly less than the operand's.
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool narrowsFloat =
      checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
  return narrowsFloat;
}
|
|
|
|
/// Verifies that the result float type is narrower than the operand type.
LogicalResult arith::TruncFOp::verify() {
  TruncFOp self = *this;
  return verifyTruncateOp<FloatType>(self);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AndIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registers the canonicalization patterns for `arith.andi` (AndOfExtUI and
/// AndOfExtSI, defined outside this section).
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<AndOfExtUI, AndOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// OrIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Registers the canonicalization patterns for `arith.ori` (OrOfExtUI and
/// OrOfExtSI, defined outside this section).
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  patterns.add<OrOfExtUI, OrOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Verifiers for casts between integers and floats.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename From, typename To>
|
|
static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
|
|
if (!areValidCastInputsAndOutputs(inputs, outputs))
|
|
return false;
|
|
|
|
auto srcType = getTypeIfLike<From>(inputs.front());
|
|
auto dstType = getTypeIfLike<To>(outputs.back());
|
|
|
|
return srcType && dstType;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UIToFPOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.uitofp` casts an integer-like input to a float-like output.
bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isIntToFloat =
      checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
  return isIntToFloat;
}
|
|
|
|
/// Constant-folds `arith.uitofp` by converting the integer constant, treated
/// as unsigned, to the result float semantics with round-to-nearest-even.
OpFoldResult arith::UIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [resEleType](const APInt &value, bool & /*castStatus*/) {
        auto floatTy = llvm::cast<FloatType>(resEleType);
        // Start from an all-zero bit pattern of the target width.
        APFloat converted(floatTy.getFloatSemantics(),
                          APInt::getZero(floatTy.getWidth()));
        converted.convertFromAPInt(value, /*IsSigned=*/false,
                                   APFloat::rmNearestTiesToEven);
        return converted;
      });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SIToFPOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.sitofp` casts an integer-like input to a float-like output.
bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isIntToFloat =
      checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
  return isIntToFloat;
}
|
|
|
|
/// Constant-folds `arith.sitofp` by converting the integer constant, treated
/// as signed, to the result float semantics with round-to-nearest-even.
OpFoldResult arith::SIToFPOp::fold(FoldAdaptor adaptor) {
  Type resEleType = getElementTypeOrSelf(getType());
  return constFoldCastOp<IntegerAttr, FloatAttr>(
      adaptor.getOperands(), getType(),
      [resEleType](const APInt &value, bool & /*castStatus*/) {
        auto floatTy = llvm::cast<FloatType>(resEleType);
        // Start from an all-zero bit pattern of the target width.
        APFloat converted(floatTy.getFloatSemantics(),
                          APInt::getZero(floatTy.getWidth()));
        converted.convertFromAPInt(value, /*IsSigned=*/true,
                                   APFloat::rmNearestTiesToEven);
        return converted;
      });
}
|
|
//===----------------------------------------------------------------------===//
|
|
// FPToUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.fptoui` casts a float-like input to an integer-like output.
bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isFloatToInt =
      checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
  return isFloatToInt;
}
|
|
|
|
/// Constant-folds `arith.fptoui` by rounding the float constant toward zero
/// into an unsigned integer of the result width. The fold is abandoned when
/// the conversion reports an invalid operation.
OpFoldResult arith::FPToUIOp::fold(FoldAdaptor adaptor) {
  const unsigned bitWidth =
      llvm::cast<IntegerType>(getElementTypeOrSelf(getType())).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APFloat &value, bool &castStatus) {
        bool isExact;
        APSInt converted(bitWidth, /*isUnsigned=*/true);
        const APFloat::opStatus status =
            value.convertToInteger(converted, APFloat::rmTowardZero, &isExact);
        castStatus = status != APFloat::opInvalidOp;
        return converted;
      });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FPToSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.fptosi` casts a float-like input to an integer-like output.
bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  const bool isFloatToInt =
      checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
  return isFloatToInt;
}
|
|
|
|
/// Constant-folds `arith.fptosi` by rounding the float constant toward zero
/// into a signed integer of the result width. The fold is abandoned when the
/// conversion reports an invalid operation.
OpFoldResult arith::FPToSIOp::fold(FoldAdaptor adaptor) {
  const unsigned bitWidth =
      llvm::cast<IntegerType>(getElementTypeOrSelf(getType())).getWidth();
  return constFoldCastOp<FloatAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [bitWidth](const APFloat &value, bool &castStatus) {
        bool isExact;
        APSInt converted(bitWidth, /*isUnsigned=*/false);
        const APFloat::opStatus status =
            value.convertToInteger(converted, APFloat::rmTowardZero, &isExact);
        castStatus = status != APFloat::opInvalidOp;
        return converted;
      });
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IndexCastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Shared compatibility check for the index cast ops: exactly one side must
/// be `index` and the other a signless integer (either may be wrapped in a
/// vector, tensor or memref).
static bool areIndexCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  Type srcElemType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
  Type dstElemType =
      getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
  if (!srcElemType || !dstElemType)
    return false;

  // index -> signless integer, or signless integer -> index.
  if (srcElemType.isIndex())
    return dstElemType.isSignlessInteger();
  return srcElemType.isSignlessInteger() && dstElemType.isIndex();
}
|
|
|
|
/// `arith.index_cast` converts between `index` and signless integers.
bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
                                           TypeRange outputs) {
  const bool compatible = areIndexCastCompatible(inputs, outputs);
  return compatible;
}
|
|
|
|
/// Constant-folds `arith.index_cast` by sign-extending or truncating the
/// constant to the result width.
OpFoldResult arith::IndexCastOp::fold(FoldAdaptor adaptor) {
  // index_cast(constant) -> constant
  // Integer results use their own width; otherwise (index result) the
  // 64-bit default for index integer attributes is used.
  auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType()));
  const unsigned resultBitwidth = intTy ? intTy.getWidth() : 64u;

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &value, bool & /*castStatus*/) {
        return value.sextOrTrunc(resultBitwidth);
      });
}
|
|
|
|
/// Registers the canonicalization patterns for `arith.index_cast`
/// (IndexCastOfIndexCast and IndexCastOfExtSI, defined outside this section).
void arith::IndexCastOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// IndexCastUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.index_castui` converts between `index` and signless integers.
bool arith::IndexCastUIOp::areCastCompatible(TypeRange inputs,
                                             TypeRange outputs) {
  const bool compatible = areIndexCastCompatible(inputs, outputs);
  return compatible;
}
|
|
|
|
/// Constant-folds `arith.index_castui` by zero-extending or truncating the
/// constant to the result width.
OpFoldResult arith::IndexCastUIOp::fold(FoldAdaptor adaptor) {
  // index_castui(constant) -> constant
  // Integer results use their own width; otherwise (index result) the
  // 64-bit default for index integer attributes is used.
  auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(getType()));
  const unsigned resultBitwidth = intTy ? intTy.getWidth() : 64u;

  return constFoldCastOp<IntegerAttr, IntegerAttr>(
      adaptor.getOperands(), getType(),
      [resultBitwidth](const APInt &value, bool & /*castStatus*/) {
        return value.zextOrTrunc(resultBitwidth);
      });
}
|
|
|
|
/// Registers the canonicalization patterns for `arith.index_castui`
/// (IndexCastUIOfIndexCastUI and IndexCastUIOfExtUI, defined outside this
/// section).
void arith::IndexCastUIOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  patterns.add<IndexCastUIOfIndexCastUI, IndexCastUIOfExtUI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BitcastOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// `arith.bitcast` is cast-compatible when both sides are integer or float
/// element types (optionally wrapped in a vector, tensor or memref) of the
/// same bit width.
bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (!areValidCastInputsAndOutputs(inputs, outputs))
    return false;

  // `index` is intentionally not accepted: the width comparison below uses
  // getIntOrFloatBitWidth(), which is only valid for integer and float types.
  // Rejecting index here reports an incompatible cast instead.
  auto srcType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(inputs.front());
  auto dstType = getTypeIfLikeOrMemRef<IntegerType, FloatType>(outputs.front());
  if (!srcType || !dstType)
    return false;

  return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
}
|
|
|
|
/// Constant-folds `arith.bitcast`. Dense element constants are bitcast
/// element-wise; scalar float/integer constants are reinterpreted through
/// their raw bits. Any other constant is left unfolded.
OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) {
  Attribute operand = adaptor.getIn();
  if (!operand)
    return {};

  Type resType = getType();

  /// Bitcast dense elements.
  if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(operand))
    return denseAttr.bitcast(llvm::cast<ShapedType>(resType).getElementType());
  /// Other shaped types unhandled.
  if (llvm::isa<ShapedType>(resType))
    return {};

  /// Bitcast integer or float to integer or float.
  APInt bits;
  if (auto floatAttr = llvm::dyn_cast<FloatAttr>(operand))
    bits = floatAttr.getValue().bitcastToAPInt();
  else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(operand))
    bits = intAttr.getValue();
  else
    // Neither a float nor an integer constant: do not fold rather than
    // asserting in an unconditional cast.
    return {};

  if (auto resFloatType = llvm::dyn_cast<FloatType>(resType))
    return FloatAttr::get(resType,
                          APFloat(resFloatType.getFloatSemantics(), bits));
  return IntegerAttr::get(resType, bits);
}
|
|
|
|
/// Registers the canonicalization pattern for `arith.bitcast`
/// (BitcastOfBitcast, defined outside this section).
void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add<BitcastOfBitcast>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
|
|
/// comparison predicates.
|
|
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
|
|
const APInt &lhs, const APInt &rhs) {
|
|
switch (predicate) {
|
|
case arith::CmpIPredicate::eq:
|
|
return lhs.eq(rhs);
|
|
case arith::CmpIPredicate::ne:
|
|
return lhs.ne(rhs);
|
|
case arith::CmpIPredicate::slt:
|
|
return lhs.slt(rhs);
|
|
case arith::CmpIPredicate::sle:
|
|
return lhs.sle(rhs);
|
|
case arith::CmpIPredicate::sgt:
|
|
return lhs.sgt(rhs);
|
|
case arith::CmpIPredicate::sge:
|
|
return lhs.sge(rhs);
|
|
case arith::CmpIPredicate::ult:
|
|
return lhs.ult(rhs);
|
|
case arith::CmpIPredicate::ule:
|
|
return lhs.ule(rhs);
|
|
case arith::CmpIPredicate::ugt:
|
|
return lhs.ugt(rhs);
|
|
case arith::CmpIPredicate::uge:
|
|
return lhs.uge(rhs);
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Returns true if the predicate is true for two equal operands.
|
|
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
|
|
switch (predicate) {
|
|
case arith::CmpIPredicate::eq:
|
|
case arith::CmpIPredicate::sle:
|
|
case arith::CmpIPredicate::sge:
|
|
case arith::CmpIPredicate::ule:
|
|
case arith::CmpIPredicate::uge:
|
|
return true;
|
|
case arith::CmpIPredicate::ne:
|
|
case arith::CmpIPredicate::slt:
|
|
case arith::CmpIPredicate::sgt:
|
|
case arith::CmpIPredicate::ult:
|
|
case arith::CmpIPredicate::ugt:
|
|
return false;
|
|
}
|
|
llvm_unreachable("unknown cmpi predicate kind");
|
|
}
|
|
|
|
/// Returns the bit width of an integer type, or of the integer element type
/// of a vector type. Returns std::nullopt for any other type.
static std::optional<int64_t> getIntegerWidth(Type t) {
  if (auto scalarTy = llvm::dyn_cast<IntegerType>(t))
    return scalarTy.getWidth();

  auto vectorTy = llvm::dyn_cast<VectorType>(t);
  if (!vectorTy)
    return std::nullopt;

  // Vector element types are expected to be integers here (cast asserts).
  return llvm::cast<IntegerType>(vectorTy.getElementType()).getWidth();
}
|
|
|
|
/// Folds `arith.cmpi`:
///  * identical operands are decided by the predicate alone;
///  * `ext{s,u}i(%x : i1) != 0` folds to `%x`;
///  * a lone constant on the left is moved to the right by swapping operands
///    and mirroring the predicate (the op is updated in place);
///  * two constant operands are evaluated at compile time.
OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
  // cmpi(pred, x, x)
  if (getLhs() == getRhs())
    return getBoolAttribute(getType(),
                            applyCmpPredicateToEqualOperands(getPredicate()));

  if (matchPattern(adaptor.getRhs(), m_Zero())) {
    // Shared handling for both extension kinds: returns the extension's
    // source when it is i1-wide and the predicate is `ne`, a null Value
    // otherwise (including when `extOp` is null).
    auto foldNeZeroOfI1Ext = [&](auto extOp) -> Value {
      if (!extOp)
        return Value();
      std::optional<int64_t> integerWidth =
          getIntegerWidth(extOp.getOperand().getType());
      if (integerWidth && integerWidth.value() == 1 &&
          getPredicate() == arith::CmpIPredicate::ne)
        return extOp.getOperand();
      return Value();
    };

    // extsi(%x : i1 -> iN) != 0 -> %x
    if (Value src = foldNeZeroOfI1Ext(getLhs().getDefiningOp<ExtSIOp>()))
      return src;
    // extui(%x : i1 -> iN) != 0 -> %x
    if (Value src = foldNeZeroOfI1Ext(getLhs().getDefiningOp<ExtUIOp>()))
      return src;
  }

  // Move constant to the right side.
  if (adaptor.getLhs() && !adaptor.getRhs()) {
    // Do not use invertPredicate, as it will change eq to ne and vice versa.
    // Swapping the operands only mirrors the ordering predicates.
    using Pred = CmpIPredicate;
    auto mirrorPredicate = [](Pred pred) -> Pred {
      switch (pred) {
      case Pred::slt:
        return Pred::sgt;
      case Pred::sgt:
        return Pred::slt;
      case Pred::sle:
        return Pred::sge;
      case Pred::sge:
        return Pred::sle;
      case Pred::ult:
        return Pred::ugt;
      case Pred::ugt:
        return Pred::ult;
      case Pred::ule:
        return Pred::uge;
      case Pred::uge:
        return Pred::ule;
      case Pred::eq:
        return Pred::eq;
      case Pred::ne:
        return Pred::ne;
      }
      llvm_unreachable("unknown cmpi predicate kind");
    };

    setPredicate(mirrorPredicate(getPredicate()));
    Value oldLhs = getLhs();
    Value oldRhs = getRhs();
    getLhsMutable().assign(oldRhs);
    getRhsMutable().assign(oldLhs);
    return getResult();
  }

  // We are moving constants to the right side; So if lhs is constant rhs is
  // guaranteed to be a constant.
  auto lhsAttr = llvm::dyn_cast_if_present<TypedAttr>(adaptor.getLhs());
  if (!lhsAttr)
    return {};

  return constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), getI1SameShape(lhsAttr.getType()),
      [pred = getPredicate()](const APInt &lhs, const APInt &rhs) {
        return APInt(1,
                     static_cast<int64_t>(applyCmpPredicate(pred, lhs, rhs)));
      });
}
|
|
|
|
/// Registers the canonicalization patterns for `arith.cmpi`: CmpIExtSI
/// followed by CmpIExtUI.
void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<CmpIExtSI, CmpIExtUI>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CmpFOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
|
|
/// comparison predicates.
|
|
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
|
|
const APFloat &lhs, const APFloat &rhs) {
|
|
auto cmpResult = lhs.compare(rhs);
|
|
switch (predicate) {
|
|
case arith::CmpFPredicate::AlwaysFalse:
|
|
return false;
|
|
case arith::CmpFPredicate::OEQ:
|
|
return cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::OGT:
|
|
return cmpResult == APFloat::cmpGreaterThan;
|
|
case arith::CmpFPredicate::OGE:
|
|
return cmpResult == APFloat::cmpGreaterThan ||
|
|
cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::OLT:
|
|
return cmpResult == APFloat::cmpLessThan;
|
|
case arith::CmpFPredicate::OLE:
|
|
return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ONE:
|
|
return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ORD:
|
|
return cmpResult != APFloat::cmpUnordered;
|
|
case arith::CmpFPredicate::UEQ:
|
|
return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UGT:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpGreaterThan;
|
|
case arith::CmpFPredicate::UGE:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpGreaterThan ||
|
|
cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::ULT:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpLessThan;
|
|
case arith::CmpFPredicate::ULE:
|
|
return cmpResult == APFloat::cmpUnordered ||
|
|
cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UNE:
|
|
return cmpResult != APFloat::cmpEqual;
|
|
case arith::CmpFPredicate::UNO:
|
|
return cmpResult == APFloat::cmpUnordered;
|
|
case arith::CmpFPredicate::AlwaysTrue:
|
|
return true;
|
|
}
|
|
llvm_unreachable("unknown cmpf predicate kind");
|
|
}
|
|
|
|
/// Folds `arith.cmpf` to a boolean attribute when both operands are float
/// constants, or when at least one of them is a NaN constant.
OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) {
  FloatAttr lhsAttr = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getLhs());
  FloatAttr rhsAttr = llvm::dyn_cast_if_present<FloatAttr>(adaptor.getRhs());

  // A comparison involving a NaN gives the same result as comparing that NaN
  // with itself, so a single NaN constant stands in for both operands.
  if (lhsAttr && lhsAttr.getValue().isNaN())
    rhsAttr = lhsAttr;
  else if (rhsAttr && rhsAttr.getValue().isNaN())
    lhsAttr = rhsAttr;

  // Nothing to fold unless both sides are now known.
  if (!(lhsAttr && rhsAttr))
    return {};

  const bool folded = applyCmpPredicate(getPredicate(), lhsAttr.getValue(),
                                        rhsAttr.getValue());
  return BoolAttr::get(getContext(), folded);
}
|
|
|
|
/// Canonicalizes `arith.cmpf pred, (sitofp|uitofp %x), %cst` where `%cst` is a
/// non-NaN float constant.
///
/// Depending on the constant and the predicate, the comparison is replaced
/// either by a constant i1 or by an `arith.cmpi` that compares `%x` directly
/// against an integer constant. The rewrite is abandoned when the
/// integer-to-float conversion could lose precision in a way that affects the
/// comparison.
class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
public:
  using OpRewritePattern<CmpFOp>::OpRewritePattern;

  /// Maps a relational float predicate to the integer predicate of matching
  /// signedness. The ordered and unordered flavours map to the same integer
  /// predicate. Must not be called with ORD, UNO, AlwaysTrue or AlwaysFalse.
  static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
                                                 bool isUnsigned) {
    using namespace arith;
    switch (pred) {
    case CmpFPredicate::UEQ:
    case CmpFPredicate::OEQ:
      return CmpIPredicate::eq;
    case CmpFPredicate::UGT:
    case CmpFPredicate::OGT:
      return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    case CmpFPredicate::UGE:
    case CmpFPredicate::OGE:
      return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
    case CmpFPredicate::ULT:
    case CmpFPredicate::OLT:
      return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    case CmpFPredicate::ULE:
    case CmpFPredicate::OLE:
      return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    case CmpFPredicate::UNE:
    case CmpFPredicate::ONE:
      return CmpIPredicate::ne;
    default:
      llvm_unreachable("Unexpected predicate!");
    }
  }

  LogicalResult matchAndRewrite(CmpFOp op,
                                PatternRewriter &rewriter) const override {
    FloatAttr flt;
    if (!matchPattern(op.getRhs(), m_Constant(&flt)))
      return failure();

    const APFloat &rhs = flt.getValue();

    // Don't attempt to fold a nan.
    if (rhs.isNaN())
      return failure();

    // Get the width of the mantissa.  We don't want to hack on conversions that
    // might lose information from the integer, e.g. "i64 -> float"
    FloatType floatTy = llvm::cast<FloatType>(op.getRhs().getType());
    int mantissaWidth = floatTy.getFPMantissaWidth();
    if (mantissaWidth <= 0)
      return failure();

    // The lhs must be produced by a signed or unsigned int-to-fp conversion.
    bool isUnsigned = false;
    Value intVal;
    if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
      intVal = si.getIn();
    } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
      isUnsigned = true;
      intVal = ui.getIn();
    } else {
      return failure();
    }

    // Check to see that the input is converted from an integer type that is
    // small enough that preserves all bits.
    auto intTy = llvm::cast<IntegerType>(intVal.getType());
    auto intWidth = intTy.getWidth();

    // Number of bits representing values, as opposed to the sign
    auto valueBits = isUnsigned ? intWidth : (intWidth - 1);

    // Following test does NOT adjust intWidth downwards for signed inputs,
    // because the most negative value still requires all the mantissa bits
    // to distinguish it from one less than that value.
    if (static_cast<int>(intWidth) > mantissaWidth) {
      // Conversion would lose accuracy. Check if loss can impact comparison.
      int exponent = ilogb(rhs);
      if (exponent == APFloat::IEK_Inf) {
        int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
        if (maxExponent < static_cast<int>(valueBits)) {
          // Conversion could create infinity.
          return failure();
        }
      } else {
        // Note that if rhs is zero or NaN, then Exp is negative
        // and first condition is trivially false.
        if (mantissaWidth <= exponent &&
            exponent <= static_cast<int>(valueBits)) {
          // Conversion could affect comparison.
          return failure();
        }
      }
    }

    // Convert to equivalent cmpi predicate
    CmpIPredicate pred;
    switch (op.getPredicate()) {
    case CmpFPredicate::ORD:
      // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
      return replaceWithBool(rewriter, op, true);
    case CmpFPredicate::UNO:
      // Int to fp conversion doesn't create a nan (uno checks either is a nan)
      return replaceWithBool(rewriter, op, false);
    case CmpFPredicate::AlwaysTrue:
      // Constant predicates have no integer equivalent; fold them here rather
      // than letting them reach the unreachable in convertToIntegerPredicate.
      return replaceWithBool(rewriter, op, true);
    case CmpFPredicate::AlwaysFalse:
      return replaceWithBool(rewriter, op, false);
    default:
      pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
      break;
    }

    // `pred` was derived with `isUnsigned`, so only the relational predicates
    // of the matching signedness can occur below.
    const CmpIPredicate ltPred =
        isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
    const CmpIPredicate lePred =
        isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
    const CmpIPredicate gtPred =
        isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
    const CmpIPredicate gePred =
        isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;

    // If the rhs value is > SignedMax / UnsignedMax, fold the comparison. This
    // handles +INF and large values.
    APFloat intMax(rhs.getSemantics());
    intMax.convertFromAPInt(isUnsigned ? APInt::getMaxValue(intWidth)
                                       : APInt::getSignedMaxValue(intWidth),
                            /*IsSigned=*/!isUnsigned,
                            APFloat::rmNearestTiesToEven);
    if (intMax < rhs) { // max < 13123.0
      return replaceWithBool(rewriter, op,
                             pred == CmpIPredicate::ne || pred == ltPred ||
                                 pred == lePred);
    }

    // See if the rhs value is < SignedMin / UnsignedMin.
    APFloat intMin(rhs.getSemantics());
    intMin.convertFromAPInt(isUnsigned ? APInt::getMinValue(intWidth)
                                       : APInt::getSignedMinValue(intWidth),
                            /*IsSigned=*/!isUnsigned,
                            APFloat::rmNearestTiesToEven);
    if (intMin > rhs) { // min > 12312.0
      return replaceWithBool(rewriter, op,
                             pred == CmpIPredicate::ne || pred == gtPred ||
                                 pred == gePred);
    }

    // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
    // [0, UMAX], but it may still be fractional.  See if it is fractional by
    // casting the FP value to the integer value and back, checking for
    // equality. Don't do this for zero, because -0.0 is not fractional.
    bool ignored;
    APSInt rhsInt(intWidth, isUnsigned);
    if (APFloat::opInvalidOp ==
        rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
      // Undefined behavior invoked - the destination type can't represent
      // the input constant.
      return failure();
    }

    if (!rhs.isZero()) {
      APFloat apf(floatTy.getFloatSemantics(),
                  APInt::getZero(floatTy.getWidth()));
      apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);

      bool equal = apf == rhs;
      if (!equal) {
        // If we had a comparison against a fractional value, we have to adjust
        // the compare predicate and sometimes the value.  rhsInt is rounded
        // towards zero at this point.
        switch (pred) {
        case CmpIPredicate::ne: // (float)int != 4.4   --> true
          return replaceWithBool(rewriter, op, true);
        case CmpIPredicate::eq: // (float)int == 4.4   --> false
          return replaceWithBool(rewriter, op, false);
        case CmpIPredicate::ule:
          // (float)int <= 4.4   --> int <= 4
          // (float)int <= -4.4  --> false
          if (rhs.isNegative())
            return replaceWithBool(rewriter, op, false);
          break;
        case CmpIPredicate::sle:
          // (float)int <= 4.4   --> int <= 4
          // (float)int <= -4.4  --> int < -4
          if (rhs.isNegative())
            pred = CmpIPredicate::slt;
          break;
        case CmpIPredicate::ult:
          // (float)int < -4.4   --> false
          // (float)int < 4.4    --> int <= 4
          if (rhs.isNegative())
            return replaceWithBool(rewriter, op, false);
          pred = CmpIPredicate::ule;
          break;
        case CmpIPredicate::slt:
          // (float)int < -4.4   --> int < -4
          // (float)int < 4.4    --> int <= 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sle;
          break;
        case CmpIPredicate::ugt:
          // (float)int > 4.4    --> int > 4
          // (float)int > -4.4   --> true
          if (rhs.isNegative())
            return replaceWithBool(rewriter, op, true);
          break;
        case CmpIPredicate::sgt:
          // (float)int > 4.4    --> int > 4
          // (float)int > -4.4   --> int >= -4
          if (rhs.isNegative())
            pred = CmpIPredicate::sge;
          break;
        case CmpIPredicate::uge:
          // (float)int >= -4.4   --> true
          // (float)int >= 4.4    --> int > 4
          if (rhs.isNegative())
            return replaceWithBool(rewriter, op, true);
          pred = CmpIPredicate::ugt;
          break;
        case CmpIPredicate::sge:
          // (float)int >= -4.4   --> int >= -4
          // (float)int >= 4.4    --> int > 4
          if (!rhs.isNegative())
            pred = CmpIPredicate::sgt;
          break;
        }
      }
    }

    // Lower this FP comparison into an appropriate integer version of the
    // comparison.
    rewriter.replaceOpWithNewOp<CmpIOp>(
        op, pred, intVal,
        rewriter.create<ConstantOp>(
            op.getLoc(), intVal.getType(),
            rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
    return success();
  }

private:
  /// Replaces `op` with a constant i1 holding `value` and reports success.
  static LogicalResult replaceWithBool(PatternRewriter &rewriter, CmpFOp op,
                                       bool value) {
    rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/value,
                                               /*width=*/1);
    return success();
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `arith.cmpf`.
void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<CmpFIntToFPConst>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SelectOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Transforms a select of a boolean to arithmetic operations
|
|
//
|
|
// arith.select %arg, %x, %y : i1
|
|
//
|
|
// becomes
|
|
//
|
|
// and(%arg, %x) or and(!%arg, %y)
|
|
struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
|
|
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(arith::SelectOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!op.getType().isInteger(1))
|
|
return failure();
|
|
|
|
Value falseConstant =
|
|
rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
|
|
Value notCondition = rewriter.create<arith::XOrIOp>(
|
|
op.getLoc(), op.getCondition(), falseConstant);
|
|
|
|
Value trueVal = rewriter.create<arith::AndIOp>(
|
|
op.getLoc(), op.getCondition(), op.getTrueValue());
|
|
Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
|
|
op.getFalseValue());
|
|
rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// select %arg, %c1, %c0 => extui %arg
|
|
struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
|
|
using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(arith::SelectOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
// Cannot extui i1 to i1, or i1 to f32
|
|
if (!llvm::isa<IntegerType>(op.getType()) || op.getType().isInteger(1))
|
|
return failure();
|
|
|
|
// select %x, c1, %c0 => extui %arg
|
|
if (matchPattern(op.getTrueValue(), m_One()) &&
|
|
matchPattern(op.getFalseValue(), m_Zero())) {
|
|
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
|
|
op.getCondition());
|
|
return success();
|
|
}
|
|
|
|
// select %x, c0, %c1 => extui (xor %arg, true)
|
|
if (matchPattern(op.getTrueValue(), m_Zero()) &&
|
|
matchPattern(op.getFalseValue(), m_One())) {
|
|
rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
|
|
op, op.getType(),
|
|
rewriter.create<arith::XOrIOp>(
|
|
op.getLoc(), op.getCondition(),
|
|
rewriter.create<arith::ConstantIntOp>(
|
|
op.getLoc(), 1, op.getCondition().getType())));
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Registers the canonicalization patterns for `arith.select`.
void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<RedundantSelectFalse, RedundantSelectTrue, SelectI1Simplify,
              SelectAndCond, SelectAndNotCond, SelectOrCond, SelectOrNotCond,
              SelectNotCond, SelectToExtUI>(context);
}
|
|
|
|
/// Folds `arith.select` when the result is one of its operands or when all
/// three operands are dense element constants. Returns a null result when no
/// fold applies.
OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) {
  Value trueVal = getTrueValue();
  Value falseVal = getFalseValue();

  // select %c, %x, %x => %x
  if (trueVal == falseVal)
    return trueVal;

  Value condition = getCondition();
  Attribute condAttr = adaptor.getCondition();
  Attribute trueAttr = adaptor.getTrueValue();
  Attribute falseAttr = adaptor.getFalseValue();

  // select true, %0, %1 => %0
  if (matchPattern(condAttr, m_One()))
    return trueVal;

  // select false, %0, %1 => %1
  if (matchPattern(condAttr, m_Zero()))
    return falseVal;

  // If either operand is fully poisoned, return the other.
  if (isa_and_nonnull<ub::PoisonAttr>(trueAttr))
    return falseVal;
  if (isa_and_nonnull<ub::PoisonAttr>(falseAttr))
    return trueVal;

  // select %x, true, false => %x
  if (getType().isInteger(1) && matchPattern(trueAttr, m_One()) &&
      matchPattern(falseAttr, m_Zero()))
    return condition;

  // Condition is an eq/ne comparison of exactly the two selected values:
  //
  //   %0 = arith.cmpi eq, %arg0, %arg1
  //   %1 = arith.select %0, %arg0, %arg1 => %arg1
  //
  //   %0 = arith.cmpi ne, %arg0, %arg1
  //   %1 = arith.select %0, %arg0, %arg1 => %arg0
  if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) {
    const auto pred = cmp.getPredicate();
    if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
      Value cmpLhs = cmp.getLhs();
      Value cmpRhs = cmp.getRhs();
      const bool sameOrder = cmpLhs == trueVal && cmpRhs == falseVal;
      const bool swappedOrder = cmpRhs == trueVal && cmpLhs == falseVal;
      if (sameOrder || swappedOrder) {
        if (pred == arith::CmpIPredicate::ne)
          return trueVal;
        return falseVal;
      }
    }
  }

  // Constant-fold constant operands over non-splat constant condition.
  // select %cst_vec, %cst0, %cst1 => %cst2
  auto cond = llvm::dyn_cast_if_present<DenseElementsAttr>(condAttr);
  if (!cond)
    return nullptr;
  auto lhs = llvm::dyn_cast_if_present<DenseElementsAttr>(trueAttr);
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<DenseElementsAttr>(falseAttr);
  if (!rhs)
    return nullptr;

  SmallVector<Attribute> selected;
  selected.reserve(static_cast<size_t>(cond.getNumElements()));

  auto condVals =
      llvm::make_range(cond.value_begin<BoolAttr>(), cond.value_end<BoolAttr>());
  auto lhsVals =
      llvm::make_range(lhs.value_begin<Attribute>(), lhs.value_end<Attribute>());
  auto rhsVals =
      llvm::make_range(rhs.value_begin<Attribute>(), rhs.value_end<Attribute>());

  // Pick each result element from the true or false constant according to the
  // matching condition element.
  for (auto [condVal, lhsVal, rhsVal] :
       llvm::zip_equal(condVals, lhsVals, rhsVals)) {
    if (condVal.getValue())
      selected.push_back(lhsVal);
    else
      selected.push_back(rhsVal);
  }

  return DenseElementsAttr::get(lhs.getType(), selected);
}
|
|
|
|
/// Parses `arith.select`:
///
///   operand-list(3) attr-dict? `:` (condition-type `,`)? result-type
///
/// When only one type is given the condition type is i1; when two are given
/// the first one is the condition type.
ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
  Type conditionType, resultType;

  if (parser.parseOperandList(operands, /*requiredOperandCount=*/3))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  if (parser.parseColonType(resultType))
    return failure();

  // Check for the explicit condition type if this is a masked tensor or vector.
  if (failed(parser.parseOptionalComma())) {
    conditionType = parser.getBuilder().getI1Type();
  } else {
    // The type parsed so far was the condition type; the result type follows.
    conditionType = resultType;
    if (parser.parseType(resultType))
      return failure();
  }

  result.addTypes(resultType);
  return parser.resolveOperands(operands,
                                {conditionType, resultType, resultType},
                                parser.getNameLoc(), result.operands);
}
|
|
|
|
/// Prints `arith.select`. The condition type is printed before the result
/// type only when it is a shaped type.
void arith::SelectOp::print(OpAsmPrinter &p) {
  p << " " << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : ";

  Type condType = getCondition().getType();
  if (llvm::isa<ShapedType>(condType))
    p << condType << ", ";

  p << getType();
}
|
|
|
|
/// Verifies the condition type of `arith.select`: it must be a signless i1,
/// or, for tensor and vector results, the i1 type of the same shape as the
/// result.
LogicalResult arith::SelectOp::verify() {
  Type condTy = getCondition().getType();

  // A signless i1 condition is accepted for any result type.
  if (condTy.isSignlessInteger(1))
    return success();

  // If the result type is a vector or tensor, the type can be a mask with the
  // same elements.
  Type resTy = getType();
  if (!llvm::isa<TensorType, VectorType>(resTy))
    return emitOpError() << "expected condition to be a signless i1, but got "
                         << condTy;

  Type expectedMaskTy = getI1SameShape(resTy);
  if (condTy == expectedMaskTy)
    return success();

  return emitOpError() << "expected condition type to have the same shape "
                          "as the result type, expected "
                       << expectedMaskTy << ", but got " << condTy;
}
|
|
//===----------------------------------------------------------------------===//
|
|
// ShLIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.shli`: a shift by zero yields the lhs, and constant operands
/// are folded as long as the shift amount does not exceed the bit width.
OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
  // shli(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  // Don't fold if shifting more than the bit width. The flag is written by
  // the callback, so it stays false when the callback is never invoked.
  bool withinWidth = false;
  auto folded = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &value, const APInt &amount) {
        withinWidth = amount.ule(amount.getBitWidth());
        return value.shl(amount);
      });
  if (!withinWidth)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ShRUIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.shrui`: a shift by zero yields the lhs, and constant operands
/// are folded (logical shift right) as long as the shift amount does not
/// exceed the bit width.
OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
  // shrui(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  // Don't fold if shifting more than the bit width. The flag is written by
  // the callback, so it stays false when the callback is never invoked.
  bool withinWidth = false;
  auto folded = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &value, const APInt &amount) {
        withinWidth = amount.ule(amount.getBitWidth());
        return value.lshr(amount);
      });
  if (!withinWidth)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ShRSIOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Folds `arith.shrsi`: a shift by zero yields the lhs, and constant operands
/// are folded (arithmetic shift right) as long as the shift amount does not
/// exceed the bit width.
OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
  // shrsi(x, 0) -> x
  if (matchPattern(adaptor.getRhs(), m_Zero()))
    return getLhs();

  // Don't fold if shifting more than the bit width. The flag is written by
  // the callback, so it stays false when the callback is never invoked.
  bool withinWidth = false;
  auto folded = constFoldBinaryOp<IntegerAttr>(
      adaptor.getOperands(), [&](const APInt &value, const APInt &amount) {
        withinWidth = amount.ule(amount.getBitWidth());
        return value.ashr(amount);
      });
  if (!withinWidth)
    return Attribute();
  return folded;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Atomic Enum
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the identity value attribute associated with an AtomicRMWKind op.
|
|
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
|
|
OpBuilder &builder, Location loc,
|
|
bool useOnlyFiniteValue) {
|
|
switch (kind) {
|
|
case AtomicRMWKind::maximumf: {
|
|
const llvm::fltSemantics &semantic =
|
|
llvm::cast<FloatType>(resultType).getFloatSemantics();
|
|
APFloat identity = useOnlyFiniteValue
|
|
? APFloat::getLargest(semantic, /*Negative=*/true)
|
|
: APFloat::getInf(semantic, /*Negative=*/true);
|
|
return builder.getFloatAttr(resultType, identity);
|
|
}
|
|
case AtomicRMWKind::addf:
|
|
case AtomicRMWKind::addi:
|
|
case AtomicRMWKind::maxu:
|
|
case AtomicRMWKind::ori:
|
|
return builder.getZeroAttr(resultType);
|
|
case AtomicRMWKind::andi:
|
|
return builder.getIntegerAttr(
|
|
resultType,
|
|
APInt::getAllOnes(llvm::cast<IntegerType>(resultType).getWidth()));
|
|
case AtomicRMWKind::maxs:
|
|
return builder.getIntegerAttr(
|
|
resultType, APInt::getSignedMinValue(
|
|
llvm::cast<IntegerType>(resultType).getWidth()));
|
|
case AtomicRMWKind::minimumf: {
|
|
const llvm::fltSemantics &semantic =
|
|
llvm::cast<FloatType>(resultType).getFloatSemantics();
|
|
APFloat identity = useOnlyFiniteValue
|
|
? APFloat::getLargest(semantic, /*Negative=*/false)
|
|
: APFloat::getInf(semantic, /*Negative=*/false);
|
|
|
|
return builder.getFloatAttr(resultType, identity);
|
|
}
|
|
case AtomicRMWKind::mins:
|
|
return builder.getIntegerAttr(
|
|
resultType, APInt::getSignedMaxValue(
|
|
llvm::cast<IntegerType>(resultType).getWidth()));
|
|
case AtomicRMWKind::minu:
|
|
return builder.getIntegerAttr(
|
|
resultType,
|
|
APInt::getMaxValue(llvm::cast<IntegerType>(resultType).getWidth()));
|
|
case AtomicRMWKind::muli:
|
|
return builder.getIntegerAttr(resultType, 1);
|
|
case AtomicRMWKind::mulf:
|
|
return builder.getFloatAttr(resultType, 1);
|
|
// TODO: Add remaining reduction operations.
|
|
default:
|
|
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
|
break;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// Return the identity numeric value associated to the give op.
|
|
std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
|
|
std::optional<AtomicRMWKind> maybeKind =
|
|
llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op)
|
|
// Floating-point operations.
|
|
.Case([](arith::AddFOp op) { return AtomicRMWKind::addf; })
|
|
.Case([](arith::MulFOp op) { return AtomicRMWKind::mulf; })
|
|
.Case([](arith::MaximumFOp op) { return AtomicRMWKind::maximumf; })
|
|
.Case([](arith::MinimumFOp op) { return AtomicRMWKind::minimumf; })
|
|
// Integer operations.
|
|
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
|
|
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
|
|
.Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
|
|
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
|
|
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
|
|
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
|
|
.Case([](arith::MaxSIOp op) { return AtomicRMWKind::maxs; })
|
|
.Case([](arith::MinSIOp op) { return AtomicRMWKind::mins; })
|
|
.Case([](arith::MulIOp op) { return AtomicRMWKind::muli; })
|
|
.Default([](Operation *op) { return std::nullopt; });
|
|
if (!maybeKind) {
|
|
op->emitError() << "Unknown neutral element for: " << *op;
|
|
return std::nullopt;
|
|
}
|
|
|
|
bool useOnlyFiniteValue = false;
|
|
auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op);
|
|
if (fmfOpInterface) {
|
|
arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr();
|
|
useOnlyFiniteValue =
|
|
bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf);
|
|
}
|
|
|
|
// Builder only used as helper for attribute creation.
|
|
OpBuilder b(op->getContext());
|
|
Type resultType = op->getResult(0).getType();
|
|
|
|
return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(),
|
|
useOnlyFiniteValue);
|
|
}
|
|
|
|
/// Returns the identity value associated with an AtomicRMWKind op.
|
|
Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
|
|
OpBuilder &builder, Location loc,
|
|
bool useOnlyFiniteValue) {
|
|
auto attr =
|
|
getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue);
|
|
return builder.create<arith::ConstantOp>(loc, attr);
|
|
}
|
|
|
|
/// Return the value obtained by applying the reduction operation kind
|
|
/// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
|
|
Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
|
|
Location loc, Value lhs, Value rhs) {
|
|
switch (op) {
|
|
case AtomicRMWKind::addf:
|
|
return builder.create<arith::AddFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::addi:
|
|
return builder.create<arith::AddIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::mulf:
|
|
return builder.create<arith::MulFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::muli:
|
|
return builder.create<arith::MulIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maximumf:
|
|
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::minimumf:
|
|
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maxnumf:
|
|
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::minnumf:
|
|
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maxs:
|
|
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::mins:
|
|
return builder.create<arith::MinSIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::maxu:
|
|
return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::minu:
|
|
return builder.create<arith::MinUIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::ori:
|
|
return builder.create<arith::OrIOp>(loc, lhs, rhs);
|
|
case AtomicRMWKind::andi:
|
|
return builder.create<arith::AndIOp>(loc, lhs, rhs);
|
|
// TODO: Add remaining reduction operations.
|
|
default:
|
|
(void)emitOptionalError(loc, "Reduction operation type not supported");
|
|
break;
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc"
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd enum attribute definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/ArithOpsEnums.cpp.inc"
|