//===- IntNarrowing.cpp - Integer bitwidth reduction optimizations --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstdint>

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHINTNARROWING
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

namespace mlir::arith {
namespace {
//===----------------------------------------------------------------------===//
// Common Helpers
//===----------------------------------------------------------------------===//

/// The base for integer bitwidth narrowing patterns.
template <typename SourceOp>
struct NarrowingPattern : OpRewritePattern<SourceOp> {
  NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options,
                   PatternBenefit benefit = 1)
      : OpRewritePattern<SourceOp>(ctx, benefit),
        supportedBitwidths(options.bitwidthsSupported.begin(),
                           options.bitwidthsSupported.end()) {
    assert(!supportedBitwidths.empty() && "Invalid options");
    assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth");
    llvm::sort(supportedBitwidths);
  }

  FailureOr<unsigned>
  getNarrowestCompatibleBitwidth(unsigned bitsRequired) const {
    for (unsigned candidate : supportedBitwidths)
      if (candidate >= bitsRequired)
        return candidate;

    return failure();
  }

  /// Returns the narrowest supported type that fits `bitsRequired`.
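  /// Illustrative example (the exact result depends on the configured
  /// bitwidths; here we assume supported bitwidths {8, 16, 32}): requesting
  /// 10 bits for `i32` yields `i16`, requesting 12 bits for `vector<4xi32>`
  /// yields `vector<4xi16>`, and requesting 33 bits fails because no
  /// supported bitwidth can hold the result.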
  FailureOr<Type> getNarrowType(unsigned bitsRequired, Type origTy) const {
    assert(origTy);
    FailureOr<unsigned> bestBitwidth =
        getNarrowestCompatibleBitwidth(bitsRequired);
    if (failed(bestBitwidth))
      return failure();

    Type elemTy = getElementTypeOrSelf(origTy);
    if (!isa<IntegerType>(elemTy))
      return failure();

    auto newElemTy = IntegerType::get(origTy.getContext(), *bestBitwidth);
    if (newElemTy == elemTy)
      return failure();

    if (origTy == elemTy)
      return newElemTy;

    if (auto shapedTy = dyn_cast<ShapedType>(origTy))
      if (dyn_cast<IntegerType>(shapedTy.getElementType()))
        return shapedTy.clone(shapedTy.getShape(), newElemTy);

    return failure();
  }

private:
  // Supported integer bitwidths in ascending order.
  llvm::SmallVector<unsigned, 6> supportedBitwidths;
};

/// Returns the integer bitwidth required to represent `type`.
FailureOr<unsigned> calculateBitsRequired(Type type) {
  assert(type);
  if (auto intTy = dyn_cast<IntegerType>(getElementTypeOrSelf(type)))
    return intTy.getWidth();

  return failure();
}

enum class ExtensionKind { Sign, Zero };

/// Wrapper around `arith::ExtSIOp` and `arith::ExtUIOp` ops that abstracts
/// away the exact op type. Exposes helper functions to query the types,
/// operands, and the result. This is so that we can handle both extension
/// kinds without needing to use templates or branching.
class ExtensionOp {
public:
  /// Attempts to create a new extension op from `op`. Returns an extension op
  /// wrapper when `op` is either `arith.extsi` or `arith.extui`, and failure
  /// otherwise.
  static FailureOr<ExtensionOp> from(Operation *op) {
    if (dyn_cast_or_null<arith::ExtSIOp>(op))
      return ExtensionOp{op, ExtensionKind::Sign};
    if (dyn_cast_or_null<arith::ExtUIOp>(op))
      return ExtensionOp{op, ExtensionKind::Zero};

    return failure();
  }

  ExtensionOp(const ExtensionOp &) = default;
  ExtensionOp &operator=(const ExtensionOp &) = default;

  /// Creates a new extension op of the same kind.
  Operation *recreate(PatternRewriter &rewriter, Location loc, Type newType,
                      Value in) {
    if (kind == ExtensionKind::Sign)
      return rewriter.create<arith::ExtSIOp>(loc, newType, in);

    return rewriter.create<arith::ExtUIOp>(loc, newType, in);
  }

  /// Replaces `toReplace` with a new extension op of the same kind.
  void recreateAndReplace(PatternRewriter &rewriter, Operation *toReplace,
                          Value in) {
    assert(toReplace->getNumResults() == 1);
    Type newType = toReplace->getResult(0).getType();
    Operation *newOp = recreate(rewriter, toReplace->getLoc(), newType, in);
    rewriter.replaceOp(toReplace, newOp->getResult(0));
  }

  ExtensionKind getKind() { return kind; }

  Value getResult() { return op->getResult(0); }
  Value getIn() { return op->getOperand(0); }

  Type getType() { return getResult().getType(); }
  Type getElementType() { return getElementTypeOrSelf(getType()); }
  Type getInType() { return getIn().getType(); }
  Type getInElementType() { return getElementTypeOrSelf(getInType()); }

private:
  ExtensionOp(Operation *op, ExtensionKind kind) : op(op), kind(kind) {
    assert(op);
    assert((isa<arith::ExtSIOp, arith::ExtUIOp>(op)) && "Not an extension op");
  }
  Operation *op = nullptr;
  ExtensionKind kind = {};
};

/// Returns the integer bitwidth required to represent `value`.
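/// Illustrative examples (assuming 8-bit input values): under zero-extension,
/// 255 requires 8 bits and 0 requires 1 bit; under sign-extension, 127
/// requires 8 bits, -1 requires 1 bit, and -128 (the signed minimum) requires
/// all 8 bits.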
unsigned calculateBitsRequired(const APInt &value,
                               ExtensionKind lookThroughExtension) {
  // For unsigned values, we only need the active bits. As a special case, zero
  // requires one bit.
  if (lookThroughExtension == ExtensionKind::Zero)
    return std::max(value.getActiveBits(), 1u);

  // If a signed value is nonnegative, we need one extra bit for the sign.
  if (value.isNonNegative())
    return value.getActiveBits() + 1;

  // For the signed min, we need all the bits.
  if (value.isMinSignedValue())
    return value.getBitWidth();

  // For negative values, we need all the non-sign bits and one extra bit for
  // the sign.
  return value.getBitWidth() - value.getNumSignBits() + 1;
}

/// Returns the integer bitwidth required to represent `value`.
/// Looks through either sign- or zero-extension as specified by
/// `lookThroughExtension`.
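/// Illustrative examples: a constant `arith.constant dense<[1, 127]> :
/// vector<2xi32>` requires 8 bits under sign-extension, and under
/// zero-extension the result of `arith.extui %x : i4 to i32` requires only
/// the 4 bits of the pre-extension type.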
FailureOr<unsigned> calculateBitsRequired(Value value,
                                          ExtensionKind lookThroughExtension) {
  // Handle constants.
  if (TypedAttr attr; matchPattern(value, m_Constant(&attr))) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return calculateBitsRequired(intAttr.getValue(), lookThroughExtension);

    if (auto elemsAttr = dyn_cast<DenseElementsAttr>(attr)) {
      if (elemsAttr.getElementType().isIntOrIndex()) {
        if (elemsAttr.isSplat())
          return calculateBitsRequired(elemsAttr.getSplatValue<APInt>(),
                                       lookThroughExtension);

        unsigned maxBits = 1;
        for (const APInt &elemValue : elemsAttr.getValues<APInt>())
          maxBits = std::max(
              maxBits, calculateBitsRequired(elemValue, lookThroughExtension));
        return maxBits;
      }
    }
  }

  if (lookThroughExtension == ExtensionKind::Sign) {
    if (auto sext = value.getDefiningOp<arith::ExtSIOp>())
      return calculateBitsRequired(sext.getIn().getType());
  } else if (lookThroughExtension == ExtensionKind::Zero) {
    if (auto zext = value.getDefiningOp<arith::ExtUIOp>())
      return calculateBitsRequired(zext.getIn().getType());
  }

  // If nothing else worked, return the type requirements for this element
  // type.
  return calculateBitsRequired(value.getType());
}

/// Base pattern for arith binary ops.
/// Example:
/// ```
/// %lhs = arith.extsi %a : i8 to i32
/// %rhs = arith.extsi %b : i8 to i32
/// %r = arith.addi %lhs, %rhs : i32
/// ==>
/// %lhs = arith.extsi %a : i8 to i16
/// %rhs = arith.extsi %b : i8 to i16
/// %add = arith.addi %lhs, %rhs : i16
/// %r = arith.extsi %add : i16 to i32
/// ```
template <typename BinaryOp>
struct BinaryOpNarrowingPattern : NarrowingPattern<BinaryOp> {
  using NarrowingPattern<BinaryOp>::NarrowingPattern;

  /// Returns the number of bits required to represent the full result,
  /// assuming that both operands are `operandBits`-wide. Derived classes must
  /// implement this, taking into account `BinaryOp` semantics.
  virtual unsigned getResultBitsProduced(unsigned operandBits) const = 0;

  /// Customization point for patterns that should only apply with
  /// zero/sign-extension ops as arguments.
  virtual bool isSupported(ExtensionOp) const { return true; }

  LogicalResult matchAndRewrite(BinaryOp op,
                                PatternRewriter &rewriter) const final {
    Type origTy = op.getType();
    FailureOr<unsigned> resultBits = calculateBitsRequired(origTy);
    if (failed(resultBits))
      return failure();

    // For the optimization to apply, we expect the lhs to be an extension op,
    // and for the rhs to either be the same extension op or a constant.
    FailureOr<ExtensionOp> ext = ExtensionOp::from(op.getLhs().getDefiningOp());
    if (failed(ext) || !isSupported(*ext))
      return failure();

    FailureOr<unsigned> lhsBitsRequired =
        calculateBitsRequired(ext->getIn(), ext->getKind());
    if (failed(lhsBitsRequired) || *lhsBitsRequired >= *resultBits)
      return failure();

    FailureOr<unsigned> rhsBitsRequired =
        calculateBitsRequired(op.getRhs(), ext->getKind());
    if (failed(rhsBitsRequired) || *rhsBitsRequired >= *resultBits)
      return failure();

    // Negotiate a common bit requirement for both lhs and rhs, accounting for
    // the result requiring more bits than the operands.
    unsigned commonBitsRequired =
        getResultBitsProduced(std::max(*lhsBitsRequired, *rhsBitsRequired));
    FailureOr<Type> narrowTy = this->getNarrowType(commonBitsRequired, origTy);
    if (failed(narrowTy) || calculateBitsRequired(*narrowTy) >= *resultBits)
      return failure();

    Location loc = op.getLoc();
    Value newLhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getLhs());
    Value newRhs =
        rewriter.createOrFold<arith::TruncIOp>(loc, *narrowTy, op.getRhs());
    Value newAdd = rewriter.create<BinaryOp>(loc, newLhs, newRhs);
    ext->recreateAndReplace(rewriter, op, newAdd);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// AddIOp Pattern
//===----------------------------------------------------------------------===//

struct AddIPattern final : BinaryOpNarrowingPattern<arith::AddIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Addition may require one extra bit for the result.
  // Example: `UINT8_MAX + 1 == 255 + 1 == 256`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// SubIOp Pattern
//===----------------------------------------------------------------------===//

struct SubIPattern final : BinaryOpNarrowingPattern<arith::SubIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to signed arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Sign;
  }

  // Subtraction may require one extra bit for the result.
  // Example: `INT8_MAX - (-1) == 127 - (-1) == 128`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// MulIOp Pattern
//===----------------------------------------------------------------------===//

struct MulIPattern final : BinaryOpNarrowingPattern<arith::MulIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // Multiplication may require up to double the operand bits.
  // Example: `UINT8_MAX * UINT8_MAX == 255 * 255 == 65025`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return 2 * operandBits;
  }
};

//===----------------------------------------------------------------------===//
// DivSIOp Pattern
//===----------------------------------------------------------------------===//

struct DivSIPattern final : BinaryOpNarrowingPattern<arith::DivSIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to signed arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Sign;
  }

  // Unlike multiplication, signed division requires only one more result bit.
  // Example: `INT8_MIN / (-1) == -128 / (-1) == 128`.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits + 1;
  }
};

//===----------------------------------------------------------------------===//
// DivUIOp Pattern
//===----------------------------------------------------------------------===//

struct DivUIPattern final : BinaryOpNarrowingPattern<arith::DivUIOp> {
  using BinaryOpNarrowingPattern::BinaryOpNarrowingPattern;

  // This optimization only applies to unsigned arguments.
  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == ExtensionKind::Zero;
  }

  // Unsigned division does not require any extra result bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }
};

//===----------------------------------------------------------------------===//
// Min/Max Patterns
//===----------------------------------------------------------------------===//

template <typename MinMaxOp, ExtensionKind Kind>
struct MinMaxPattern final : BinaryOpNarrowingPattern<MinMaxOp> {
  using BinaryOpNarrowingPattern<MinMaxOp>::BinaryOpNarrowingPattern;

  bool isSupported(ExtensionOp ext) const override {
    return ext.getKind() == Kind;
  }

  // Min/max returns one of the arguments and does not require any extra result
  // bits.
  unsigned getResultBitsProduced(unsigned operandBits) const override {
    return operandBits;
  }
};
using MaxSIPattern = MinMaxPattern<arith::MaxSIOp, ExtensionKind::Sign>;
using MaxUIPattern = MinMaxPattern<arith::MaxUIOp, ExtensionKind::Zero>;
using MinSIPattern = MinMaxPattern<arith::MinSIOp, ExtensionKind::Sign>;
using MinUIPattern = MinMaxPattern<arith::MinUIOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// *IToFPOp Patterns
//===----------------------------------------------------------------------===//

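/// Narrows the integer operand of an int-to-float conversion when fewer bits
/// are known to be needed. Illustrative example (assuming 16 is a supported
/// bitwidth):
/// ```
/// %ext = arith.extsi %a : i16 to i64
/// %r = arith.sitofp %ext : i64 to f32
/// ==>
/// %r = arith.sitofp %a : i16 to f32
/// ```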
template <typename IToFPOp, ExtensionKind Extension>
struct IToFPPattern final : NarrowingPattern<IToFPOp> {
  using NarrowingPattern<IToFPOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(IToFPOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<unsigned> narrowestWidth =
        calculateBitsRequired(op.getIn(), Extension);
    if (failed(narrowestWidth))
      return failure();

    FailureOr<Type> narrowTy =
        this->getNarrowType(*narrowestWidth, op.getIn().getType());
    if (failed(narrowTy))
      return failure();

    Value newIn = rewriter.createOrFold<arith::TruncIOp>(op.getLoc(), *narrowTy,
                                                         op.getIn());
    rewriter.replaceOpWithNewOp<IToFPOp>(op, op.getType(), newIn);
    return success();
  }
};
using SIToFPPattern = IToFPPattern<arith::SIToFPOp, ExtensionKind::Sign>;
using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// Index Cast Patterns
//===----------------------------------------------------------------------===//

// These rely on the `ValueBounds` interface for index values. For example, we
// can often statically tell index value bounds of loop induction variables.
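//
// Illustrative example (assuming the bounds of %idx are known to fit in 8 bits
// and 8 is a supported bitwidth):
//   %r = arith.index_cast %idx : index to i32
// ==>
//   %c = arith.index_cast %idx : index to i8
//   %r = arith.extsi %c : i8 to i32
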
template <typename CastOp, ExtensionKind Kind>
struct IndexCastPattern final : NarrowingPattern<CastOp> {
  using NarrowingPattern<CastOp>::NarrowingPattern;

  LogicalResult matchAndRewrite(CastOp op,
                                PatternRewriter &rewriter) const override {
    Value in = op.getIn();
    // We only support scalar index -> integer casts.
    if (!isa<IndexType>(in.getType()))
      return failure();

    // Check the lower bound in both the signed and unsigned cast case. We
    // conservatively assume that even unsigned casts may be performed on
    // negative indices.
    FailureOr<int64_t> lb = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::LB, in);
    if (failed(lb))
      return failure();

    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::UB, in, /*dim=*/std::nullopt,
        /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(ub))
      return failure();

    assert(*lb <= *ub && "Invalid bounds");
    unsigned lbBitsRequired = calculateBitsRequired(APInt(64, *lb), Kind);
    unsigned ubBitsRequired = calculateBitsRequired(APInt(64, *ub), Kind);
    unsigned bitsRequired = std::max(lbBitsRequired, ubBitsRequired);

    IntegerType resultTy = cast<IntegerType>(op.getType());
    if (resultTy.getWidth() <= bitsRequired)
      return failure();

    FailureOr<Type> narrowTy = this->getNarrowType(bitsRequired, resultTy);
    if (failed(narrowTy))
      return failure();

    Value newCast = rewriter.create<CastOp>(op.getLoc(), *narrowTy, op.getIn());

    if (Kind == ExtensionKind::Sign)
      rewriter.replaceOpWithNewOp<arith::ExtSIOp>(op, resultTy, newCast);
    else
      rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, resultTy, newCast);
    return success();
  }
};
using IndexCastSIPattern =
    IndexCastPattern<arith::IndexCastOp, ExtensionKind::Sign>;
using IndexCastUIPattern =
    IndexCastPattern<arith::IndexCastUIOp, ExtensionKind::Zero>;

//===----------------------------------------------------------------------===//
// Patterns to Commute Extension Ops
//===----------------------------------------------------------------------===//

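// These patterns swap a vector shape-manipulation op with the extension that
// feeds it, so the shape op runs on the narrower element type and the
// extension is applied to its result. Illustrative example for the broadcast
// case:
//   %e = arith.extsi %src : vector<4xi8> to vector<4xi32>
//   %r = vector.broadcast %e : vector<4xi32> to vector<2x4xi32>
// ==>
//   %b = vector.broadcast %src : vector<4xi8> to vector<2x4xi8>
//   %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
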
struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newBroadcast =
        rewriter.create<vector::BroadcastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newBroadcast);
    return success();
  }
};

struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    Value newExtract = rewriter.create<vector::ExtractOp>(
        op.getLoc(), ext->getIn(), op.getMixedPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

struct ExtensionOverExtractElement final
    : NarrowingPattern<vector::ExtractElementOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractElementOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    Value newExtract = rewriter.create<vector::ExtractElementOp>(
        op.getLoc(), ext->getIn(), op.getPosition());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

struct ExtensionOverExtractStridedSlice final
    : NarrowingPattern<vector::ExtractStridedSliceOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType extractTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newExtract = rewriter.create<vector::ExtractStridedSliceOp>(
        op.getLoc(), extractTy, ext->getIn(), op.getOffsets(), op.getSizes(),
        op.getStrides());
    ext->recreateAndReplace(rewriter, op, newExtract);
    return success();
  }
};

/// Base pattern for `vector.insert` narrowing patterns.
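/// Unlike the commute patterns above, insertion also requires the destination
/// vector to be narrowable. Illustrative example (assuming 8 is a supported
/// bitwidth and %dest is known to need at most 8 bits per element):
/// ```
/// %ext = arith.extsi %val : i8 to i32
/// %r = vector.insert %ext, %dest [0] : i32 into vector<4xi32>
/// ==>
/// %d = arith.trunci %dest : vector<4xi32> to vector<4xi8>
/// %i = vector.insert %val, %d [0] : i8 into vector<4xi8>
/// %r = arith.extsi %i : vector<4xi8> to vector<4xi32>
/// ```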
template <typename InsertionOp>
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
  using NarrowingPattern<InsertionOp>::NarrowingPattern;

  /// Derived classes must provide a function to create the matching insertion
  /// op based on the original op and new arguments.
  virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
                                        InsertionOp origInsert,
                                        Value narrowValue,
                                        Value narrowDest) const = 0;

  LogicalResult matchAndRewrite(InsertionOp op,
                                PatternRewriter &rewriter) const final {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
    if (failed(newInsert))
      return failure();
    ext->recreateAndReplace(rewriter, op, *newInsert);
    return success();
  }

  FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
                                            PatternRewriter &rewriter,
                                            ExtensionOp insValue) const {
    // Calculate the operand and result bitwidths. We can only apply narrowing
    // when the inserted source value and destination vector require fewer bits
    // than the result. Because the source and destination may have different
    // bitwidth requirements, we have to find a common narrow bitwidth that is
    // greater than or equal to the operand bitwidth requirements and still
    // narrower than the result.
    FailureOr<unsigned> origBitsRequired = calculateBitsRequired(op.getType());
    if (failed(origBitsRequired))
      return failure();

    // TODO: We could relax this check by disregarding bitwidth requirements of
    // elements that we know will be replaced by the insertion.
    FailureOr<unsigned> destBitsRequired =
        calculateBitsRequired(op.getDest(), insValue.getKind());
    if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
      return failure();

    FailureOr<unsigned> insertedBitsRequired =
        calculateBitsRequired(insValue.getIn(), insValue.getKind());
    if (failed(insertedBitsRequired) ||
        *insertedBitsRequired >= *origBitsRequired)
      return failure();

    // Find a narrower element type that satisfies the bitwidth requirements of
    // both the source and the destination values.
    unsigned newInsertionBits =
        std::max(*destBitsRequired, *insertedBitsRequired);
    FailureOr<Type> newVecTy =
        this->getNarrowType(newInsertionBits, op.getType());
    if (failed(newVecTy) || *newVecTy == op.getType())
      return failure();

    FailureOr<Type> newInsertedValueTy =
        this->getNarrowType(newInsertionBits, insValue.getType());
    if (failed(newInsertedValueTy))
      return failure();

    Location loc = op.getLoc();
    Value narrowValue = rewriter.createOrFold<arith::TruncIOp>(
        loc, *newInsertedValueTy, insValue.getResult());
    Value narrowDest =
        rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
    return createInsertionOp(rewriter, op, narrowValue, narrowDest);
  }
};

struct ExtensionOverInsert final
    : ExtensionOverInsertionPattern<vector::InsertOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
                                     vector::InsertOp origInsert,
                                     Value narrowValue,
                                     Value narrowDest) const override {
    return rewriter.create<vector::InsertOp>(origInsert.getLoc(), narrowValue,
                                             narrowDest,
                                             origInsert.getMixedPosition());
  }
};

struct ExtensionOverInsertElement final
    : ExtensionOverInsertionPattern<vector::InsertElementOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
                                            vector::InsertElementOp origInsert,
                                            Value narrowValue,
                                            Value narrowDest) const override {
    return rewriter.create<vector::InsertElementOp>(
        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
  }
};

struct ExtensionOverInsertStridedSlice final
    : ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
  using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;

  vector::InsertStridedSliceOp
  createInsertionOp(PatternRewriter &rewriter,
                    vector::InsertStridedSliceOp origInsert, Value narrowValue,
                    Value narrowDest) const override {
    return rewriter.create<vector::InsertStridedSliceOp>(
        origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
        origInsert.getStrides());
  }
};

struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getSource().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newCast =
        rewriter.create<vector::ShapeCastOp>(op.getLoc(), newTy, ext->getIn());
    ext->recreateAndReplace(rewriter, op, newCast);
    return success();
  }
};

struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getVector().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getResultVectorType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::TransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getPermutation());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

struct ExtensionOverFlatTranspose final
    : NarrowingPattern<vector::FlatTransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<ExtensionOp> ext =
        ExtensionOp::from(op.getMatrix().getDefiningOp());
    if (failed(ext))
      return failure();

    VectorType origTy = op.getType();
    VectorType newTy =
        origTy.cloneWith(origTy.getShape(), ext->getInElementType());
    Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
        op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
        op.getColumnsAttr());
    ext->recreateAndReplace(rewriter, op, newTranspose);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definitions
//===----------------------------------------------------------------------===//

struct ArithIntNarrowingPass final
    : impl::ArithIntNarrowingBase<ArithIntNarrowingPass> {
  using ArithIntNarrowingBase::ArithIntNarrowingBase;

  void runOnOperation() override {
    if (bitwidthsSupported.empty() ||
        llvm::is_contained(bitwidthsSupported, 0)) {
      // Invalid pass options.
      return signalPassFailure();
    }

    Operation *op = getOperation();
    MLIRContext *ctx = op->getContext();
    RewritePatternSet patterns(ctx);
    populateArithIntNarrowingPatterns(
        patterns, ArithIntNarrowingOptions{bitwidthsSupported});
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Public API
//===----------------------------------------------------------------------===//

void populateArithIntNarrowingPatterns(
    RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
  // Add commute patterns with a higher benefit. This is to expose more
  // optimization opportunities to narrowing patterns.
  patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
               ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
               ExtensionOverInsert, ExtensionOverInsertElement,
               ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
               ExtensionOverTranspose, ExtensionOverFlatTranspose>(
      patterns.getContext(), options, PatternBenefit(2));

  patterns.add<AddIPattern, SubIPattern, MulIPattern, DivSIPattern,
               DivUIPattern, MaxSIPattern, MaxUIPattern, MinSIPattern,
               MinUIPattern, SIToFPPattern, UIToFPPattern, IndexCastSIPattern,
               IndexCastUIPattern>(patterns.getContext(), options);
}

} // namespace mlir::arith