//===- ArithOps.cpp - MLIR Arith dialect ops implementation -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include #include #include #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::arith; //===----------------------------------------------------------------------===// // Pattern helpers //===----------------------------------------------------------------------===// static IntegerAttr applyToIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs, function_ref binFn) { APInt lhsVal = llvm::cast(lhs).getValue(); APInt rhsVal = llvm::cast(rhs).getValue(); APInt value = binFn(lhsVal, rhsVal); return IntegerAttr::get(res.getType(), value); } static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return applyToIntegerAttrs(builder, res, lhs, rhs, std::plus()); } static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return applyToIntegerAttrs(builder, res, lhs, rhs, std::minus()); } static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res, Attribute lhs, Attribute rhs) { return applyToIntegerAttrs(builder, res, lhs, rhs, 
std::multiplies()); } static IntegerOverflowFlagsAttr getDefOverflowFlags(OpBuilder &builder) { return IntegerOverflowFlagsAttr::get(builder.getContext(), IntegerOverflowFlags::none); } /// Invert an integer comparison predicate. arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) { switch (pred) { case arith::CmpIPredicate::eq: return arith::CmpIPredicate::ne; case arith::CmpIPredicate::ne: return arith::CmpIPredicate::eq; case arith::CmpIPredicate::slt: return arith::CmpIPredicate::sge; case arith::CmpIPredicate::sle: return arith::CmpIPredicate::sgt; case arith::CmpIPredicate::sgt: return arith::CmpIPredicate::sle; case arith::CmpIPredicate::sge: return arith::CmpIPredicate::slt; case arith::CmpIPredicate::ult: return arith::CmpIPredicate::uge; case arith::CmpIPredicate::ule: return arith::CmpIPredicate::ugt; case arith::CmpIPredicate::ugt: return arith::CmpIPredicate::ule; case arith::CmpIPredicate::uge: return arith::CmpIPredicate::ult; } llvm_unreachable("unknown cmpi predicate kind"); } static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) { return arith::CmpIPredicateAttr::get(pred.getContext(), invertPredicate(pred.getValue())); } static int64_t getScalarOrElementWidth(Type type) { Type elemTy = getElementTypeOrSelf(type); if (elemTy.isIntOrFloat()) return elemTy.getIntOrFloatBitWidth(); return -1; } static int64_t getScalarOrElementWidth(Value value) { return getScalarOrElementWidth(value.getType()); } static FailureOr getIntOrSplatIntValue(Attribute attr) { APInt value; if (matchPattern(attr, m_ConstantInt(&value))) return value; return failure(); } static Attribute getBoolAttribute(Type type, bool value) { auto boolAttr = BoolAttr::get(type.getContext(), value); ShapedType shapedType = llvm::dyn_cast_or_null(type); if (!shapedType) return boolAttr; return DenseElementsAttr::get(shapedType, boolAttr); } //===----------------------------------------------------------------------===// // TableGen'd 
canonicalization patterns //===----------------------------------------------------------------------===// namespace { #include "ArithCanonicalization.inc" } // namespace //===----------------------------------------------------------------------===// // Common helpers //===----------------------------------------------------------------------===// /// Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); if (auto shapedType = llvm::dyn_cast(type)) return shapedType.cloneWith(std::nullopt, i1Type); if (llvm::isa(type)) return UnrankedTensorType::get(i1Type); return i1Type; } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// void arith::ConstantOp::getAsmResultNames( function_ref setNameFn) { auto type = getType(); if (auto intCst = llvm::dyn_cast(getValue())) { auto intType = llvm::dyn_cast(type); // Sugar i1 constants with 'true' and 'false'. if (intType && intType.getWidth() == 1) return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); // Otherwise, build a complex name with the value and type. SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << 'c' << intCst.getValue(); if (intType) specialName << '_' << type; setNameFn(getResult(), specialName.str()); } else { setNameFn(getResult(), "cst"); } } /// TODO: disallow arith.constant to return anything other than signless integer /// or float like. LogicalResult arith::ConstantOp::verify() { auto type = getType(); // The value's type must match the return type. if (getValue().getType() != type) { return emitOpError() << "value type " << getValue().getType() << " must match return type: " << type; } // Integer values must be signless. 
if (llvm::isa(type) && !llvm::cast(type).isSignless()) return emitOpError("integer return type must be signless"); // Any float or elements attribute are acceptable. if (!llvm::isa(getValue())) { return emitOpError( "value must be an integer, float, or elements attribute"); } return success(); } bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { // The value's type must be the same as the provided type. auto typedAttr = llvm::dyn_cast(value); if (!typedAttr || typedAttr.getType() != type) return false; // Integer values must be signless. if (llvm::isa(type) && !llvm::cast(type).isSignless()) return false; // Integer, float, and element attributes are buildable. return llvm::isa(value); } ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, Type type, Location loc) { if (isBuildableWith(value, type)) return builder.create(loc, cast(value)); return nullptr; } OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width) { auto type = builder.getIntegerType(width); arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, Type type) { assert(type.isSignlessInteger() && "ConstantIntOp can only have signless integer type values"); arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } bool arith::ConstantIntOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isSignlessInteger(); return false; } void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type) { arith::ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value)); } bool arith::ConstantFloatOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return 
llvm::isa(constOp.getType()); return false; } void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, int64_t value) { arith::ConstantOp::build(builder, result, builder.getIndexType(), builder.getIndexAttr(value)); } bool arith::ConstantIndexOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isIndex(); return false; } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// OpFoldResult arith::AddIOp::fold(FoldAdaptor adaptor) { // addi(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); // addi(subi(a, b), b) -> a if (auto sub = getLhs().getDefiningOp()) if (getRhs() == sub.getRhs()) return sub.getLhs(); // addi(b, subi(a, b)) -> a if (auto sub = getRhs().getDefiningOp()) if (getLhs() == sub.getRhs()) return sub.getLhs(); return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; }); } void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // AddUIExtendedOp //===----------------------------------------------------------------------===// std::optional> arith::AddUIExtendedOp::getShapeForUnroll() { if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } // Returns the overflow bit, assuming that `sum` is the result of unsigned // addition of `operand` and another number. static APInt calculateUnsignedOverflow(const APInt &sum, const APInt &operand) { return sum.ult(operand) ? 
APInt::getAllOnes(1) : APInt::getZero(1); } LogicalResult arith::AddUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { Type overflowTy = getOverflow().getType(); // addui_extended(x, 0) -> x, false if (matchPattern(getRhs(), m_Zero())) { Builder builder(getContext()); auto falseValue = builder.getZeroAttr(overflowTy); results.push_back(getLhs()); results.push_back(falseValue); return success(); } // addui_extended(constant_a, constant_b) -> constant_sum, constant_carry // Let the `constFoldBinaryOp` utility attempt to fold the sum of both // operands. If that succeeds, calculate the overflow bit based on the sum // and the first (constant) operand, `lhs`. if (Attribute sumAttr = constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) + b; })) { Attribute overflowAttr = constFoldBinaryOp( ArrayRef({sumAttr, adaptor.getLhs()}), getI1SameShape(llvm::cast(sumAttr).getType()), calculateUnsignedOverflow); if (!overflowAttr) return failure(); results.push_back(sumAttr); results.push_back(overflowAttr); return success(); } return failure(); } void arith::AddUIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); // subi(x,0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); if (auto add = getLhs().getDefiningOp()) { // subi(addi(a, b), b) -> a if (getRhs() == add.getRhs()) return add.getLhs(); // subi(addi(a, b), a) -> b if (getRhs() == add.getLhs()) return add.getRhs(); } return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) - b; }); } void 
arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// OpFoldResult arith::MulIOp::fold(FoldAdaptor adaptor) { // muli(x, 0) -> 0 if (matchPattern(adaptor.getRhs(), m_Zero())) return getRhs(); // muli(x, 1) -> x if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // TODO: Handle the overflow case. // default folder return constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; }); } void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // MulSIExtendedOp //===----------------------------------------------------------------------===// std::optional> arith::MulSIExtendedOp::getShapeForUnroll() { if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } LogicalResult arith::MulSIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // mulsi_extended(x, 0) -> 0, 0 if (matchPattern(adaptor.getRhs(), m_Zero())) { Attribute zero = adaptor.getRhs(); results.push_back(zero); results.push_back(zero); return success(); } // mulsi_extended(cst_a, cst_b) -> cst_low, cst_high if (Attribute lowAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; })) { // Invoke the constant fold helper again to calculate the 'high' result. 
Attribute highAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt fullProduct = a.sext(bitWidth * 2) * b.sext(bitWidth * 2); return fullProduct.extractBits(bitWidth, bitWidth); }); assert(highAttr && "Unexpected constant-folding failure"); results.push_back(lowAttr); results.push_back(highAttr); return success(); } return failure(); } void arith::MulSIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // MulUIExtendedOp //===----------------------------------------------------------------------===// std::optional> arith::MulUIExtendedOp::getShapeForUnroll() { if (auto vt = llvm::dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } LogicalResult arith::MulUIExtendedOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // mului_extended(x, 0) -> 0, 0 if (matchPattern(adaptor.getRhs(), m_Zero())) { Attribute zero = adaptor.getRhs(); results.push_back(zero); results.push_back(zero); return success(); } // mului_extended(x, 1) -> x, 0 if (matchPattern(adaptor.getRhs(), m_One())) { Builder builder(getContext()); Attribute zero = builder.getZeroAttr(getLhs().getType()); results.push_back(getLhs()); results.push_back(zero); return success(); } // mului_extended(cst_a, cst_b) -> cst_low, cst_high if (Attribute lowAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { return a * b; })) { // Invoke the constant fold helper again to calculate the 'high' result. 
Attribute highAttr = constFoldBinaryOp( adaptor.getOperands(), [](const APInt &a, const APInt &b) { unsigned bitWidth = a.getBitWidth(); APInt fullProduct = a.zext(bitWidth * 2) * b.zext(bitWidth * 2); return fullProduct.extractBits(bitWidth, bitWidth); }); assert(highAttr && "Unexpected constant-folding failure"); results.push_back(lowAttr); results.push_back(highAttr); return success(); } return failure(); } void arith::MulUIExtendedOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // DivUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) { // divui (x, 1) -> x. if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp(adaptor.getOperands(), [&](APInt a, const APInt &b) { if (div0 || !b) { div0 = true; return a; } return a.udiv(b); }); return div0 ? Attribute() : result; } Speculation::Speculatability arith::DivUIOp::getSpeculatability() { // X / 0 => UB return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable : Speculation::NotSpeculatable; } //===----------------------------------------------------------------------===// // DivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) { // divsi (x, 1) -> x. if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } return a.sdiv_ov(b, overflowOrDiv0); }); return overflowOrDiv0 ? 
Attribute() : result; } Speculation::Speculatability arith::DivSIOp::getSpeculatability() { bool mayHaveUB = true; APInt constRHS; // X / 0 => UB // INT_MIN / -1 => UB if (matchPattern(getRhs(), m_ConstantInt(&constRHS))) mayHaveUB = constRHS.isAllOnes() || constRHS.isZero(); return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable; } //===----------------------------------------------------------------------===// // Ceil and floor division folding helpers //===----------------------------------------------------------------------===// static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b, bool &overflow) { // Returns (a-1)/b + 1 APInt one(a.getBitWidth(), 1, true); // Signed value 1. APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); return val.sadd_ov(one, overflow); } //===----------------------------------------------------------------------===// // CeilDivUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) { // ceildivui (x, 1) -> x. if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); bool overflowOrDiv0 = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } APInt quotient = a.udiv(b); if (!a.urem(b)) return quotient; APInt one(a.getBitWidth(), 1, true); return quotient.uadd_ov(one, overflowOrDiv0); }); return overflowOrDiv0 ? Attribute() : result; } Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() { // X / 0 => UB return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable : Speculation::NotSpeculatable; } //===----------------------------------------------------------------------===// // CeilDivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) { // ceildivsi (x, 1) -> x. 
if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } if (!a) return a; // After this point we know that neither a or b are zero. unsigned bits = a.getBitWidth(); APInt zero = APInt::getZero(bits); bool aGtZero = a.sgt(zero); bool bGtZero = b.sgt(zero); if (aGtZero && bGtZero) { // Both positive, return ceil(a, b). return signedCeilNonnegInputs(a, b, overflowOrDiv0); } if (!aGtZero && !bGtZero) { // Both negative, return ceil(-a, -b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt posB = zero.ssub_ov(b, overflowOrDiv0); return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); } if (!aGtZero && bGtZero) { // A is negative, b is positive, return - ( -a / b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt div = posA.sdiv_ov(b, overflowOrDiv0); return zero.ssub_ov(div, overflowOrDiv0); } // A is positive, b is negative, return - (a / -b). APInt posB = zero.ssub_ov(b, overflowOrDiv0); APInt div = a.sdiv_ov(posB, overflowOrDiv0); return zero.ssub_ov(div, overflowOrDiv0); }); return overflowOrDiv0 ? Attribute() : result; } Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() { bool mayHaveUB = true; APInt constRHS; // X / 0 => UB // INT_MIN / -1 => UB if (matchPattern(getRhs(), m_ConstantInt(&constRHS))) mayHaveUB = constRHS.isAllOnes() || constRHS.isZero(); return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable; } //===----------------------------------------------------------------------===// // FloorDivSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) { // floordivsi (x, 1) -> x. 
if (matchPattern(adaptor.getRhs(), m_One())) return getLhs(); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp( adaptor.getOperands(), [&](APInt a, const APInt &b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } if (!a) return a; // After this point we know that neither a or b are zero. unsigned bits = a.getBitWidth(); APInt zero = APInt::getZero(bits); bool aGtZero = a.sgt(zero); bool bGtZero = b.sgt(zero); if (aGtZero && bGtZero) { // Both positive, return a / b. return a.sdiv_ov(b, overflowOrDiv0); } if (!aGtZero && !bGtZero) { // Both negative, return -a / -b. APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt posB = zero.ssub_ov(b, overflowOrDiv0); return posA.sdiv_ov(posB, overflowOrDiv0); } if (!aGtZero && bGtZero) { // A is negative, b is positive, return - ceil(-a, b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); return zero.ssub_ov(ceil, overflowOrDiv0); } // A is positive, b is negative, return - ceil(a, -b). APInt posB = zero.ssub_ov(b, overflowOrDiv0); APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); return zero.ssub_ov(ceil, overflowOrDiv0); }); return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // RemUIOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemUIOp::fold(FoldAdaptor adaptor) { // remui (x, 1) -> 0. if (matchPattern(adaptor.getRhs(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp(adaptor.getOperands(), [&](APInt a, const APInt &b) { if (div0 || b.isZero()) { div0 = true; return a; } return a.urem(b); }); return div0 ? 
Attribute() : result; } //===----------------------------------------------------------------------===// // RemSIOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemSIOp::fold(FoldAdaptor adaptor) { // remsi (x, 1) -> 0. if (matchPattern(adaptor.getRhs(), m_One())) return Builder(getContext()).getZeroAttr(getType()); // Don't fold if it would require a division by zero. bool div0 = false; auto result = constFoldBinaryOp(adaptor.getOperands(), [&](APInt a, const APInt &b) { if (div0 || b.isZero()) { div0 = true; return a; } return a.srem(b); }); return div0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // AndIOp //===----------------------------------------------------------------------===// /// Fold `and(a, and(a, b))` to `and(a, b)` static Value foldAndIofAndI(arith::AndIOp op) { for (bool reversePrev : {false, true}) { auto prev = (reversePrev ? op.getRhs() : op.getLhs()) .getDefiningOp(); if (!prev) continue; Value other = (reversePrev ? 
op.getLhs() : op.getRhs()); if (other != prev.getLhs() && other != prev.getRhs()) continue; return prev.getResult(); } return {}; } OpFoldResult arith::AndIOp::fold(FoldAdaptor adaptor) { /// and(x, 0) -> 0 if (matchPattern(adaptor.getRhs(), m_Zero())) return getRhs(); /// and(x, allOnes) -> x APInt intValue; if (matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes()) return getLhs(); /// and(x, not(x)) -> 0 if (matchPattern(getRhs(), m_Op(matchers::m_Val(getLhs()), m_ConstantInt(&intValue))) && intValue.isAllOnes()) return Builder(getContext()).getZeroAttr(getType()); /// and(not(x), x) -> 0 if (matchPattern(getLhs(), m_Op(matchers::m_Val(getRhs()), m_ConstantInt(&intValue))) && intValue.isAllOnes()) return Builder(getContext()).getZeroAttr(getType()); /// and(a, and(a, b)) -> and(a, b) if (Value result = foldAndIofAndI(*this)) return result; return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) & b; }); } //===----------------------------------------------------------------------===// // OrIOp //===----------------------------------------------------------------------===// OpFoldResult arith::OrIOp::fold(FoldAdaptor adaptor) { if (APInt rhsVal; matchPattern(adaptor.getRhs(), m_ConstantInt(&rhsVal))) { /// or(x, 0) -> x if (rhsVal.isZero()) return getLhs(); /// or(x, ) -> if (rhsVal.isAllOnes()) return adaptor.getRhs(); } APInt intValue; /// or(x, xor(x, 1)) -> 1 if (matchPattern(getRhs(), m_Op(matchers::m_Val(getLhs()), m_ConstantInt(&intValue))) && intValue.isAllOnes()) return getRhs().getDefiningOp().getRhs(); /// or(xor(x, 1), x) -> 1 if (matchPattern(getLhs(), m_Op(matchers::m_Val(getRhs()), m_ConstantInt(&intValue))) && intValue.isAllOnes()) return getLhs().getDefiningOp().getRhs(); return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) | b; }); } //===----------------------------------------------------------------------===// // XOrIOp 
//===----------------------------------------------------------------------===// OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) { /// xor(x, 0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); /// xor(x, x) -> 0 if (getLhs() == getRhs()) return Builder(getContext()).getZeroAttr(getType()); /// xor(xor(x, a), a) -> x /// xor(xor(a, x), a) -> x if (arith::XOrIOp prev = getLhs().getDefiningOp()) { if (prev.getRhs() == getRhs()) return prev.getLhs(); if (prev.getLhs() == getRhs()) return prev.getRhs(); } /// xor(a, xor(x, a)) -> x /// xor(a, xor(a, x)) -> x if (arith::XOrIOp prev = getRhs().getDefiningOp()) { if (prev.getRhs() == getLhs()) return prev.getLhs(); if (prev.getLhs() == getLhs()) return prev.getRhs(); } return constFoldBinaryOp( adaptor.getOperands(), [](APInt a, const APInt &b) { return std::move(a) ^ b; }); } void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // NegFOp //===----------------------------------------------------------------------===// OpFoldResult arith::NegFOp::fold(FoldAdaptor adaptor) { /// negf(negf(x)) -> x if (auto op = this->getOperand().getDefiningOp()) return op.getOperand(); return constFoldUnaryOp(adaptor.getOperands(), [](const APFloat &a) { return -a; }); } //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// OpFoldResult arith::AddFOp::fold(FoldAdaptor adaptor) { // addf(x, -0) -> x if (matchPattern(adaptor.getRhs(), m_NegZeroFloat())) return getLhs(); return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a + b; }); } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// 
//===----------------------------------------------------------------------===//
// SubFOp
//===----------------------------------------------------------------------===//

OpFoldResult arith::SubFOp::fold(FoldAdaptor adaptor) {
  // subf(x, +0) -> x
  if (matchPattern(adaptor.getRhs(), m_PosZeroFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return a - b; });
}

//===----------------------------------------------------------------------===//
// MaximumFOp
//===----------------------------------------------------------------------===//

/// maximumf propagates NaN, so `llvm::maximum` is the matching folder and
/// -inf is a true identity (maximum(NaN, -inf) is NaN == x).
OpFoldResult arith::MaximumFOp::fold(FoldAdaptor adaptor) {
  // maximumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maximumf(x, -inf) -> x
  if (matchPattern(adaptor.getRhs(), m_NegInfFloat()))
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
}

//===----------------------------------------------------------------------===//
// MaxNumFOp
//===----------------------------------------------------------------------===//

/// maxnumf returns the other operand when exactly one operand is NaN. Hence:
///  * constants fold with `llvm::maxnum` (not `llvm::maximum`, which would
///    propagate the NaN — that is maximumf's semantics);
///  * NaN, not -inf, is the identity: maxnumf(NaN, -inf) must be -inf, so
///    folding maxnumf(x, -inf) to x would be wrong when x is NaN.
OpFoldResult arith::MaxNumFOp::fold(FoldAdaptor adaptor) {
  // maxnumf(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  // maxnumf(x, NaN) -> x (scalar constant only).
  if (auto rhs = llvm::dyn_cast_or_null<FloatAttr>(adaptor.getRhs());
      rhs && rhs.getValue().isNaN())
    return getLhs();

  return constFoldBinaryOp<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &a, const APFloat &b) { return llvm::maxnum(a, b); });
}

//===----------------------------------------------------------------------===//
// MaxSIOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxSIOp::fold(FoldAdaptor adaptor) {
  // maxsi(x,x) -> x
  if (getLhs() == getRhs())
    return getRhs();

  if (APInt intValue;
      matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) {
    // maxsi(x,MAX_INT) -> MAX_INT
    if (intValue.isMaxSignedValue())
      return getRhs();
    // maxsi(x, MIN_INT) -> x
    if (intValue.isMinSignedValue())
      return getLhs();
  }

  return constFoldBinaryOp<IntegerAttr>(adaptor.getOperands(),
                                        [](const APInt &a, const APInt &b) {
                                          return llvm::APIntOps::smax(a, b);
                                        });
}
//===----------------------------------------------------------------------===// // MaxUIOp //===----------------------------------------------------------------------===// OpFoldResult MaxUIOp::fold(FoldAdaptor adaptor) { // maxui(x,x) -> x if (getLhs() == getRhs()) return getRhs(); if (APInt intValue; matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { // maxui(x,MAX_INT) -> MAX_INT if (intValue.isMaxValue()) return getRhs(); // maxui(x, MIN_INT) -> x if (intValue.isMinValue()) return getLhs(); } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::umax(a, b); }); } //===----------------------------------------------------------------------===// // MinimumFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MinimumFOp::fold(FoldAdaptor adaptor) { // minimumf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); // minimumf(x, +inf) -> x if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) return getLhs(); return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); }); } //===----------------------------------------------------------------------===// // MinNumFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MinNumFOp::fold(FoldAdaptor adaptor) { // minnumf(x,x) -> x if (getLhs() == getRhs()) return getRhs(); // minnumf(x, +inf) -> x if (matchPattern(adaptor.getRhs(), m_PosInfFloat())) return getLhs(); return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return llvm::minnum(a, b); }); } //===----------------------------------------------------------------------===// // MinSIOp //===----------------------------------------------------------------------===// OpFoldResult MinSIOp::fold(FoldAdaptor adaptor) { // minsi(x,x) -> x if (getLhs() == getRhs()) return getRhs(); if (APInt intValue; 
matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { // minsi(x,MIN_INT) -> MIN_INT if (intValue.isMinSignedValue()) return getRhs(); // minsi(x, MAX_INT) -> x if (intValue.isMaxSignedValue()) return getLhs(); } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::smin(a, b); }); } //===----------------------------------------------------------------------===// // MinUIOp //===----------------------------------------------------------------------===// OpFoldResult MinUIOp::fold(FoldAdaptor adaptor) { // minui(x,x) -> x if (getLhs() == getRhs()) return getRhs(); if (APInt intValue; matchPattern(adaptor.getRhs(), m_ConstantInt(&intValue))) { // minui(x,MIN_INT) -> MIN_INT if (intValue.isMinValue()) return getRhs(); // minui(x, MAX_INT) -> x if (intValue.isMaxValue()) return getLhs(); } return constFoldBinaryOp(adaptor.getOperands(), [](const APInt &a, const APInt &b) { return llvm::APIntOps::umin(a, b); }); } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { // mulf(x, 1) -> x if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a * b; }); } void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// OpFoldResult arith::DivFOp::fold(FoldAdaptor adaptor) { // divf(x, 1) -> x if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a / b; }); } void 
arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(context); } //===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// OpFoldResult arith::RemFOp::fold(FoldAdaptor adaptor) { return constFoldBinaryOp(adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { APFloat result(a); (void)result.remainder(b); return result; }); } //===----------------------------------------------------------------------===// // Utility functions for verifying cast ops //===----------------------------------------------------------------------===// template using type_list = std::tuple *; /// Returns a non-null type only if the provided type is one of the allowed /// types or one of the allowed shaped types of the allowed types. Returns the /// element type if a valid shaped type is provided. template static Type getUnderlyingType(Type type, type_list, type_list) { if (llvm::isa(type) && !llvm::isa(type)) return {}; auto underlyingType = getElementTypeOrSelf(type); if (!llvm::isa(underlyingType)) return {}; return underlyingType; } /// Get allowed underlying types for vectors and tensors. template static Type getTypeIfLike(Type type) { return getUnderlyingType(type, type_list(), type_list()); } /// Get allowed underlying types for vectors, tensors, and memrefs. template static Type getTypeIfLikeOrMemRef(Type type) { return getUnderlyingType(type, type_list(), type_list()); } /// Return false if both types are ranked tensor with mismatching encoding. 
static bool hasSameEncoding(Type typeA, Type typeB) { auto rankedTensorA = dyn_cast(typeA); auto rankedTensorB = dyn_cast(typeB); if (!rankedTensorA || !rankedTensorB) return true; return rankedTensorA.getEncoding() == rankedTensorB.getEncoding(); } static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) { if (inputs.size() != 1 || outputs.size() != 1) return false; if (!hasSameEncoding(inputs.front(), outputs.front())) return false; return succeeded(verifyCompatibleShapes(inputs.front(), outputs.front())); } //===----------------------------------------------------------------------===// // Verifiers for integer and floating point extension/truncation ops //===----------------------------------------------------------------------===// // Extend ops can only extend to a wider type. template static LogicalResult verifyExtOp(Op op) { Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); if (llvm::cast(srcType).getWidth() >= llvm::cast(dstType).getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } // Truncate ops can only truncate to a shorter type. template static LogicalResult verifyTruncateOp(Op op) { Type srcType = getElementTypeOrSelf(op.getIn().getType()); Type dstType = getElementTypeOrSelf(op.getType()); if (llvm::cast(srcType).getWidth() <= llvm::cast(dstType).getWidth()) return op.emitError("result type ") << dstType << " must be shorter than operand type " << srcType; return success(); } /// Validate a cast that changes the width of a type. template