//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
//===----------------------------------------------------------------------===//
// Common Helpers
//===----------------------------------------------------------------------===//

/// The base for integer bitwidth narrowing patterns.
template struct NarrowingPattern : OpRewritePattern { NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(ctx, benefit), supportedBitwidths(options.bitwidthsSupported.begin(), options.bitwidthsSupported.end()) { assert(!supportedBitwidths.empty() && "Invalid options"); assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth"); llvm::sort(supportedBitwidths); } FailureOr getNarrowestCompatibleBitwidth(unsigned bitsRequired) const { for (unsigned candidate : supportedBitwidths) if (candidate >= bitsRequired) return candidate; return failure(); } /// Returns the narrowest supported type that fits `bitsRequired`. FailureOr getNarrowType(unsigned bitsRequired, Type origTy) const { assert(origTy); FailureOr bestBitwidth = getNarrowestCompatibleBitwidth(bitsRequired); if (failed(bestBitwidth)) return failure(); Type elemTy = getElementTypeOrSelf(origTy); if (!isa(elemTy)) return failure(); auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth); if (newElemTy == elemTy) return failure(); if (origTy == elemTy) return newElemTy; if (auto shapedTy = dyn_cast(origTy)) if (dyn_cast(shapedTy.getElementType())) return shapedTy.clone(shapedTy.getShape(), newElemTy); return failure(); } private: // Supported integer bitwidths in the ascending order. llvm::SmallVector supportedBitwidths; }; /// Returns the integer bitwidth required to represent `type`. FailureOr calculateBitsRequired(Type type) { assert(type); if (auto intTy = dyn_cast(getElementTypeOrSelf(type))) return intTy.getWidth(); return failure(); } enum class ExtensionKind { Sign, Zero }; /// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts away /// the exact op type. Exposes helper functions to query the types, operands, /// and the result. This is so that we can handle both extension kinds without /// needing to use templates or branching. 
class ExtensionOp { public: /// Attemps to create a new extension op from `op`. Returns an extension op /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure /// otherwise. static FailureOr from(Operation *op) { if (dyn_cast_or_null(op)) return ExtensionOp{op, ExtensionKind::Sign}; if (dyn_cast_or_null(op)) return ExtensionOp{op, ExtensionKind::Zero}; return failure(); } ExtensionOp(const ExtensionOp &) = default; ExtensionOp &operator=(const ExtensionOp &) = default; /// Creates a new extension op of the same kind. Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType, Value in) { if (kind == ExtensionKind::Sign) return rewriter.create(loc, newType, in); return rewriter.create(loc, newType, in); } /// Replaces `toReplace` with a new extension op of the same kind. void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace, Value in) { assert(toReplace->getNumResults() == 1); Type newType = toReplace->getResult(0).getType(); Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in); rewriter.replaceOp(toReplace, newOp->getResult(0)); } ExtensionKind getKind() { return kind; } Value getResult() { return op->getResult(0); } Value getIn() { return op->getOperand(0); } Type getType() { return getResult().getType(); } Type getElementType() { return getElementTypeOrSelf(getType()); } Type getInType() { return getIn().getType(); } Type getInElementType() { return getElementTypeOrSelf(getInType()); } private: ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) { assert(op); assert((isa(op)) && "Not an extension op"); } Operation *op = nullptr; ExtensionKind kind = {}; }; /// Returns the integer bitwidth required to represent `value`. unsigned calculateBitsRequired(const APInt &value, ExtensionKind lookThroughExtension) { // For unsigned values, we only need the active bits. As a special case, zero // requires one bit. 
if (lookThroughExtension == ExtensionKind::Zero) return std::max(value.getActiveBits(), 1u); // If a signed value is nonnegative, we need one extra bit for the sign. if (value.isNonNegative()) return value.getActiveBits() + 1; // For the signed min, we need all the bits. if (value.isMinSignedValue()) return value.getBitWidth(); // For negative values, we need all the non-sign bits and one extra bit for // the sign. return value.getBitWidth() - value.getNumSignBits() + 1; } /// Returns the integer bitwidth required to represent `value`. /// Looks through either sign- or zero-extension as specified by /// `lookThroughExtension`. FailureOr calculateBitsRequired(Value value, ExtensionKind lookThroughExtension) { // Handle constants. if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) { if (auto intAttr = dyn_cast(attr)) return calculateBitsRequired(intAttr.getValue(), lookThroughExtension); if (auto elemsAttr = dyn_cast(attr)) { if (elemsAttr.getElementType().isIntOrIndex()) { if (elemsAttr.isSplat()) return calculateBitsRequired(elemsAttr.getSplatValue(), lookThroughExtension); unsigned maxBits = 1; for (const APInt &elemValue : elemsAttr.getValues()) maxBits = std::max( maxBits, calculateBitsRequired(elemValue, lookThroughExtension)); return maxBits; } } } if (lookThroughExtension == ExtensionKind::Sign) { if (auto sext = value.getDefiningOp()) return calculateBitsRequired(sext.getIn().getType()); } else if (lookThroughExtension == ExtensionKind::Zero) { if (auto zext = value.getDefiningOp()) return calculateBitsRequired(zext.getIn().getType()); } // If nothing else worked, return the type requirements for this element type. return calculateBitsRequired(value.getType()); } /// Base pattern for arith binary ops. 
/// Example: /// ``` /// %lhs = arith.extsi %a : i8 to i32 /// %rhs = arith.extsi %b : i8 to i32 /// %r = arith.addi %lhs, %rhs : i32 /// ==> /// %lhs = arith.extsi %a : i8 to i16 /// %rhs = arith.extsi %b : i8 to i16 /// %add = arith.addi %lhs, %rhs : i16 /// %r = arith.extsi %add : i16 to i32 /// ``` template struct BinaryOpNarrowingPattern : NarrowingPattern { using NarrowingPattern::NarrowingPattern; /// Returns the number of bits required to represent the full result, assuming /// that both operands are `operandBits`-wide. Derived classes must implement /// this, taking into account `BinaryOp` semantics. virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0; /// Customization point for patterns that should only apply with /// zero/sign-extension ops as arguments. virtual bool isSupported(ExtensionOp) const { return true; } LogicalResult matchAndRewrite(BinaryOp op, PatternRewriter &rewriter) const final { Type origTy = op.getType(); FailureOr resultBits = calculateBitsRequired(origTy); if (failed(resultBits)) return failure(); // For the optimization to apply, we expect the lhs to be an extension op, // and for the rhs to either be the same extension op or a constant. FailureOr ext = ExtensionOp::from(op.getLhs().getDefiningOp()); if (failed(ext) || !isSupported(*ext)) return failure(); FailureOr lhsBitsRequired = calculateBitsRequired(ext->getIn(), ext->getKind()); if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits) return failure(); FailureOr rhsBitsRequired = calculateBitsRequired(op.getRhs(), ext->getKind()); if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits) return failure(); // Negotiate a common bit requirements for both lhs and rhs, accounting for // the result requiring more bits than the operands. 
unsigned commonBitsRequired = getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired)); FailureOr narrowTy = this->getNarrowType(commonBitsRequired, origTy); if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits) return failure(); Location loc = op.getLoc(); Value newLhs = rewriter.createOrFold(loc, *narrowTy, op.getLhs()); Value newRhs = rewriter.createOrFold(loc, *narrowTy, op.getRhs()); Value newAdd = rewriter.create(loc, newLhs, newRhs); ext->recreateAndReplace(rewriter, op, newAdd); return success(); } }; //===----------------------------------------------------------------------===// // AddIOp Pattern //===----------------------------------------------------------------------===// struct AddIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; // Addition may require one extra bit for the result. // Example: `UINT8_MAX + 1 == 255 + 1 == 256`. unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits + 1; } }; //===----------------------------------------------------------------------===// // SubIOp Pattern //===----------------------------------------------------------------------===// struct SubIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; // This optimization only applies to signed arguments. bool isSupported(ExtensionOp ext) const override { return ext.getKind() == ExtensionKind::Sign; } // Subtraction may require one extra bit for the result. // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`. 
unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits + 1; } }; //===----------------------------------------------------------------------===// // MulIOp Pattern //===----------------------------------------------------------------------===// struct MulIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; // Multiplication may require up double the operand bits. // Example: `UNT8_MAX * UINT8_MAX == 255 * 255 == 65025`. unsigned getResultBitsProduced(unsigned operandBits) const override { return 2 * operandBits; } }; //===----------------------------------------------------------------------===// // DivSIOp Pattern //===----------------------------------------------------------------------===// struct DivSIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; // This optimization only applies to signed arguments. bool isSupported(ExtensionOp ext) const override { return ext.getKind() == ExtensionKind::Sign; } // Unlike multiplication, signed division requires only one more result bit. // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`. unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits + 1; } }; //===----------------------------------------------------------------------===// // DivUIOp Pattern //===----------------------------------------------------------------------===// struct DivUIPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; // This optimization only applies to unsigned arguments. bool isSupported(ExtensionOp ext) const override { return ext.getKind() == ExtensionKind::Zero; } // Unsigned division does not require any extra result bits. 
unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits; } }; //===----------------------------------------------------------------------===// // Min/Max Patterns //===----------------------------------------------------------------------===// template struct MinMaxPattern final : BinaryOpNarrowingPattern { using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern; bool isSupported(ExtensionOp ext) const override { return ext.getKind() == Kind; } // Min/max returns one of the arguments and does not require any extra result // bits. unsigned getResultBitsProduced(unsigned operandBits) const override { return operandBits; } }; using MaxSIPattern = MinMaxPattern; using MaxUIPattern = MinMaxPattern; using MinSIPattern = MinMaxPattern; using MinUIPattern = MinMaxPattern; //===----------------------------------------------------------------------===// // *IToFPOp Patterns //===----------------------------------------------------------------------===// template struct IToFPPattern final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(IToFPOp op, PatternRewriter &rewriter) const override { FailureOr narrowestWidth = calculateBitsRequired(op.getIn(), Extension); if (failed(narrowestWidth)) return failure(); FailureOr narrowTy = this->getNarrowType(*narrowestWidth, op.getIn().getType()); if (failed(narrowTy)) return failure(); Value newIn = rewriter.createOrFold(op.getLoc(), *narrowTy, op.getIn()); rewriter.replaceOpWithNewOp(op, op.getType(), newIn); return success(); } }; using SIToFPPattern = IToFPPattern; using UIToFPPattern = IToFPPattern; //===----------------------------------------------------------------------===// // Index Cast Patterns //===----------------------------------------------------------------------===// // These rely on the `ValueBounds` interface for index values. For example, we // can often statically tell index value bounds of loop induction variables. 
template struct IndexCastPattern final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(CastOp op, PatternRewriter &rewriter) const override { Value in = op.getIn(); // We only support scalar index -> integer casts. if (!isa(in.getType())) return failure(); // Check the lower bound in both the signed and unsigned cast case. We // conservatively assume that even unsigned casts may be performed on // negative indices. FailureOr lb = ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::LB, in); if (failed(lb)) return failure(); FailureOr ub = ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType::UB, in, /*dim=*/std::nullopt, /*stopCondition=*/nullptr, /*closedUB=*/true); if (failed(ub)) return failure(); assert(*lb <= *ub && "Invalid bounds"); unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind); unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind); unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired); IntegerType resultTy = cast(op.getType()); if (resultTy.getWidth() <= bitsRequired) return failure(); FailureOr narrowTy = this->getNarrowType(bitsRequired, resultTy); if (failed(narrowTy)) return failure(); Value newCast = rewriter.create(op.getLoc(), *narrowTy, op.getIn()); if (Kind == ExtensionKind::Sign) rewriter.replaceOpWithNewOp(op, resultTy, newCast); else rewriter.replaceOpWithNewOp(op, resultTy, newCast); return success(); } }; using IndexCastSIPattern = IndexCastPattern; using IndexCastUIPattern = IndexCastPattern; //===----------------------------------------------------------------------===// // Patterns to Commute Extension Ops //===----------------------------------------------------------------------===// struct ExtensionOverBroadcast final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { FailureOr ext = 
ExtensionOp::from(op.getSource().getDefiningOp()); if (failed(ext)) return failure(); VectorType origTy = op.getResultVectorType(); VectorType newTy = origTy.cloneWith(origTy.getShape(), ext->getInElementType()); Value newBroadcast = rewriter.create(op.getLoc(), newTy, ext->getIn()); ext->recreateAndReplace(rewriter, op, newBroadcast); return success(); } }; struct ExtensionOverExtract final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { FailureOr ext = ExtensionOp::from(op.getVector().getDefiningOp()); if (failed(ext)) return failure(); Value newExtract = rewriter.create( op.getLoc(), ext->getIn(), op.getMixedPosition()); ext->recreateAndReplace(rewriter, op, newExtract); return success(); } }; struct ExtensionOverExtractElement final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(vector::ExtractElementOp op, PatternRewriter &rewriter) const override { FailureOr ext = ExtensionOp::from(op.getVector().getDefiningOp()); if (failed(ext)) return failure(); Value newExtract = rewriter.create( op.getLoc(), ext->getIn(), op.getPosition()); ext->recreateAndReplace(rewriter, op, newExtract); return success(); } }; struct ExtensionOverExtractStridedSlice final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { FailureOr ext = ExtensionOp::from(op.getVector().getDefiningOp()); if (failed(ext)) return failure(); VectorType origTy = op.getType(); VectorType extractTy = origTy.cloneWith(origTy.getShape(), ext->getInElementType()); Value newExtract = rewriter.create( op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(), op.getStrides()); ext->recreateAndReplace(rewriter, op, newExtract); return success(); } }; /// Base pattern for `vector.insert` narrowing patterns. 
template struct ExtensionOverInsertionPattern : NarrowingPattern { using NarrowingPattern::NarrowingPattern; /// Derived classes must provide a function to create the matching insertion /// op based on the original op and new arguments. virtual InsertionOp createInsertionOp(PatternRewriter &rewriter, InsertionOp origInsert, Value narrowValue, Value narrowDest) const = 0; LogicalResult matchAndRewrite(InsertionOp op, PatternRewriter &rewriter) const final { FailureOr ext = ExtensionOp::from(op.getSource().getDefiningOp()); if (failed(ext)) return failure(); FailureOr newInsert = createNarrowInsert(op, rewriter, *ext); if (failed(newInsert)) return failure(); ext->recreateAndReplace(rewriter, op, *newInsert); return success(); } FailureOr createNarrowInsert(InsertionOp op, PatternRewriter &rewriter, ExtensionOp insValue) const { // Calculate the operand and result bitwidths. We can only apply narrowing // when the inserted source value and destination vector require fewer bits // than the result. Because the source and destination may have different // bitwidths requirements, we have to find the common narrow bitwidth that // is greater equal to the operand bitwidth requirements and still narrower // than the result. FailureOr origBitsRequired = calculateBitsRequired(op.getType()); if (failed(origBitsRequired)) return failure(); // TODO: We could relax this check by disregarding bitwidth requirements of // elements that we know will be replaced by the insertion. FailureOr destBitsRequired = calculateBitsRequired(op.getDest(), insValue.getKind()); if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired) return failure(); FailureOr insertedBitsRequired = calculateBitsRequired(insValue.getIn(), insValue.getKind()); if (failed(insertedBitsRequired) || *insertedBitsRequired >= *origBitsRequired) return failure(); // Find a narrower element type that satisfies the bitwidth requirements of // both the source and the destination values. 
unsigned newInsertionBits = std::max(*destBitsRequired, *insertedBitsRequired); FailureOr newVecTy = this->getNarrowType(newInsertionBits, op.getType()); if (failed(newVecTy) || *newVecTy == op.getType()) return failure(); FailureOr newInsertedValueTy = this->getNarrowType(newInsertionBits, insValue.getType()); if (failed(newInsertedValueTy)) return failure(); Location loc = op.getLoc(); Value narrowValue = rewriter.createOrFold( loc, *newInsertedValueTy, insValue.getResult()); Value narrowDest = rewriter.createOrFold(loc, *newVecTy, op.getDest()); return createInsertionOp(rewriter, op, narrowValue, narrowDest); } }; struct ExtensionOverInsert final : ExtensionOverInsertionPattern { using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; vector::InsertOp createInsertionOp(PatternRewriter &rewriter, vector::InsertOp origInsert, Value narrowValue, Value narrowDest) const override { return rewriter.create(origInsert.getLoc(), narrowValue, narrowDest, origInsert.getMixedPosition()); } }; struct ExtensionOverInsertElement final : ExtensionOverInsertionPattern { using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter, vector::InsertElementOp origInsert, Value narrowValue, Value narrowDest) const override { return rewriter.create( origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition()); } }; struct ExtensionOverInsertStridedSlice final : ExtensionOverInsertionPattern { using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern; vector::InsertStridedSliceOp createInsertionOp(PatternRewriter &rewriter, vector::InsertStridedSliceOp origInsert, Value narrowValue, Value narrowDest) const override { return rewriter.create( origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(), origInsert.getStrides()); } }; struct ExtensionOverShapeCast final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult 
matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { FailureOr ext = ExtensionOp::from(op.getSource().getDefiningOp()); if (failed(ext)) return failure(); VectorType origTy = op.getResultVectorType(); VectorType newTy = origTy.cloneWith(origTy.getShape(), ext->getInElementType()); Value newCast = rewriter.create(op.getLoc(), newTy, ext->getIn()); ext->recreateAndReplace(rewriter, op, newCast); return success(); } }; struct ExtensionOverTranspose final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { FailureOr ext = ExtensionOp::from(op.getVector().getDefiningOp()); if (failed(ext)) return failure(); VectorType origTy = op.getResultVectorType(); VectorType newTy = origTy.cloneWith(origTy.getShape(), ext->getInElementType()); Value newTranspose = rewriter.create( op.getLoc(), newTy, ext->getIn(), op.getPermutation()); ext->recreateAndReplace(rewriter, op, newTranspose); return success(); } }; struct ExtensionOverFlatTranspose final : NarrowingPattern { using NarrowingPattern::NarrowingPattern; LogicalResult matchAndRewrite(vector::FlatTransposeOp op, PatternRewriter &rewriter) const override { FailureOr ext = ExtensionOp::from(op.getMatrix().getDefiningOp()); if (failed(ext)) return failure(); VectorType origTy = op.getType(); VectorType newTy = origTy.cloneWith(origTy.getShape(), ext->getInElementType()); Value newTranspose = rewriter.create( op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(), op.getColumnsAttr()); ext->recreateAndReplace(rewriter, op, newTranspose); return success(); } }; //===----------------------------------------------------------------------===// // Pass Definitions //===----------------------------------------------------------------------===// struct ArithIntNarrowingPass final : impl::ArithIntNarrowingBase { using ArithIntNarrowingBase::ArithIntNarrowingBase; void runOnOperation() override { if 
(bitwidthsSupported.empty() || llvm::is_contained(bitwidthsSupported, 0)) { // Invalid pass options. return signalPassFailure(); } Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); RewritePatternSet patterns(ctx); populateArithIntNarrowingPatterns( patterns, ArithIntNarrowingOptions{bitwidthsSupported}); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // Public API //===----------------------------------------------------------------------===// void populateArithIntNarrowingPatterns( RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) { // Add commute patterns with a higher benefit. This is to expose more // optimization opportunities to narrowing patterns. patterns.add( patterns.getContext(), options, PatternBenefit(2)); patterns.add(patterns.getContext(), options); } } // namespace mlir::arith