1058 lines
36 KiB
C++
1058 lines
36 KiB
C++
//===- TosaCanonicalizations.cpp - Canonicalization patterns & folders ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// \file
|
|
// TOSA canonicalization patterns and folders.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Quant/QuantOps.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
|
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
|
|
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
|
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
|
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/DialectImplementation.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/FoldUtils.h"
|
|
#include "mlir/Transforms/InliningUtils.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
#include "llvm/ADT/APFloat.h"
|
|
#include "llvm/ADT/APInt.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
|
|
#include <functional>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tosa;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operator Canonicalizers.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Canonicalizes a `tosa.concat` that has exactly one operand. The concat is
/// replaced by that operand directly when the types agree, and by a
/// `tensor.cast` of the operand to the concat's result type otherwise.
struct ConcatOptimization : public OpRewritePattern<tosa::ConcatOp> {
  using OpRewritePattern<tosa::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    auto operands = op.getInput1();
    if (operands.size() != 1)
      return failure();

    Value onlyOperand = operands.front();
    if (onlyOperand.getType() == op.getType())
      rewriter.replaceOp(op, onlyOperand);
    else
      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(),
                                                  onlyOperand);
    return success();
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `tosa.concat`.
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<ConcatOptimization>(context);
}
|
|
|
|
/// Rewrites `select(logical_not(p), a, b)` in place into `select(p, b, a)`.
/// Fails when the predicate is not produced by a `tosa.logical_not`.
LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) {
  if (auto notOp = op.getPred().getDefiningOp<tosa::LogicalNotOp>()) {
    // Drop the negation and swap the two selected values to compensate.
    rewriter.modifyOpInPlace(op, [&]() {
      op.getOperation()->setOperands(
          {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()});
    });
    return success();
  }
  return failure();
}
|
|
|
|
struct ConsolidateTransposeOptimization
|
|
: public OpRewritePattern<tosa::TransposeOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::TransposeOp transposeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Input is also TransposeOp - transpose(transpose(A)).
|
|
auto innerTranspose =
|
|
transposeOp.getInput1().getDefiningOp<tosa::TransposeOp>();
|
|
if (!innerTranspose)
|
|
return rewriter.notifyMatchFailure(transposeOp,
|
|
"input must be transpose operation");
|
|
|
|
SmallVector<int64_t> transposePerms, innerTransposePerms;
|
|
if (transposeOp.getConstantPerms(transposePerms).failed())
|
|
return rewriter.notifyMatchFailure(transposeOp,
|
|
"transpose perms must be constant");
|
|
if (innerTranspose.getConstantPerms(innerTransposePerms).failed())
|
|
return rewriter.notifyMatchFailure(
|
|
transposeOp, "inner transpose perms must be constant");
|
|
if (transposePerms.size() != innerTransposePerms.size())
|
|
return rewriter.notifyMatchFailure(
|
|
transposeOp,
|
|
"transpose and inner transpose perms sizes must be equal");
|
|
if (transposePerms.empty())
|
|
return rewriter.notifyMatchFailure(
|
|
transposeOp, "transpose perms sizes must be positive");
|
|
|
|
// Consolidate transposes into one transpose.
|
|
SmallVector<int32_t> perms(transposePerms.size());
|
|
for (int i = 0, s = transposePerms.size(); i < s; ++i)
|
|
perms[i] = innerTransposePerms[transposePerms[i]];
|
|
|
|
auto permsTy =
|
|
RankedTensorType::get(transposePerms.size(), rewriter.getI32Type());
|
|
auto permsAttr = DenseIntElementsAttr::get(permsTy, perms);
|
|
Value permsValue =
|
|
rewriter.create<arith::ConstantOp>(transposeOp.getLoc(), permsAttr);
|
|
|
|
rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
|
|
transposeOp, transposeOp.getResult().getType(),
|
|
innerTranspose.getInput1(), permsValue);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Determines the case when tosa.transpose is a tosa.reshape operation.
|
|
struct TransposeIsReshape : public OpRewritePattern<tosa::TransposeOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tosa::TransposeOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
DenseIntElementsAttr permAttr;
|
|
if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
|
|
return rewriter.notifyMatchFailure(op, "Non-constant permutation");
|
|
|
|
if (op.getInput1().getDefiningOp<tosa::TransposeOp>())
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Src is from transpose, can compose transposes");
|
|
|
|
Value result = op.getResult();
|
|
for (Operation *subop : result.getUsers()) {
|
|
if (dyn_cast_or_null<tosa::TransposeOp>(subop))
|
|
return rewriter.notifyMatchFailure(
|
|
op, "Dest is used by transpose, can compose transposes");
|
|
}
|
|
|
|
auto input = op.getInput1();
|
|
auto inputTy = llvm::cast<ShapedType>(input.getType());
|
|
if (!inputTy.hasRank())
|
|
return rewriter.notifyMatchFailure(op, "Unranked input.");
|
|
|
|
int64_t numDynDims = 0;
|
|
for (int i = 0; i < inputTy.getRank(); ++i)
|
|
if (inputTy.isDynamicDim(i))
|
|
numDynDims++;
|
|
|
|
if (numDynDims > 1)
|
|
return rewriter.notifyMatchFailure(op, "Has more than one dynamic dim.");
|
|
|
|
SmallVector<int64_t> permValues = llvm::to_vector<6>(
|
|
llvm::map_range(permAttr.getValues<APInt>(),
|
|
[](const APInt &val) { return val.getSExtValue(); }));
|
|
|
|
SmallVector<int64_t> nonZeroPerms;
|
|
nonZeroPerms.reserve(permValues.size());
|
|
for (auto idx : permValues) {
|
|
auto sz = inputTy.getDimSize(idx);
|
|
if (sz != 1)
|
|
nonZeroPerms.push_back(idx);
|
|
}
|
|
|
|
for (int i = 1, s = nonZeroPerms.size(); i < s; ++i)
|
|
if (nonZeroPerms[i - 1] > nonZeroPerms[i])
|
|
return rewriter.notifyMatchFailure(op,
|
|
"Transpose changes memory layout.");
|
|
|
|
SmallVector<int64_t> newShape;
|
|
newShape.reserve(inputTy.getRank());
|
|
for (int i = 0, s = inputTy.getRank(); i < s; ++i)
|
|
newShape.push_back(inputTy.getDimSize(permValues[i]));
|
|
|
|
rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
|
|
op, op.getType(), op.getInput1(),
|
|
rewriter.getDenseI64ArrayAttr(newShape));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Registers the canonicalization patterns for `tosa.transpose`.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<ConsolidateTransposeOptimization, TransposeIsReshape>(context);
}
|
|
|
|
/// Gives a `tosa.pad` without an explicit pad constant a materialized one:
/// 0.0 for float elements, 0 for integer elements, or the input zero point
/// when quantization info is attached to an integer pad.
struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp op,
                                PatternRewriter &rewriter) const override {
    // Nothing to do when the pad constant is already present.
    if (op.getPadConst())
      return failure();

    auto input = op.getInput1();
    auto padding = op.getPadding();
    Type elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();

    Attribute constantAttr;
    if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    } else if (llvm::isa<IntegerType>(elementTy)) {
      if (auto quantInfo = op.getQuantizationInfo())
        constantAttr =
            rewriter.getIntegerAttr(elementTy, quantInfo->getInputZp());
      else
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    }

    if (!constantAttr)
      return rewriter.notifyMatchFailure(
          op,
          "tosa.pad to linalg lowering encountered an unknown element type");

    // The pad value is a rank-0 tensor constant of the input element type.
    auto denseAttr = DenseElementsAttr::get(
        RankedTensorType::get({}, elementTy), constantAttr);
    auto constantVal = rewriter.create<tosa::ConstOp>(
        op.getLoc(), denseAttr.getType(), denseAttr);

    rewriter.replaceOpWithNewOp<tosa::PadOp>(
        op, op.getType(), ValueRange{input, padding, constantVal},
        op->getAttrs());
    return success();
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `tosa.pad`.
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<MaterializePadValue>(context);
}
|
|
|
|
/// Removes a `tosa.max_pool2d` whose input and output are both statically
/// shaped with extent 1 in dimensions 1 and 2; the op is replaced by its
/// input.
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();
    auto inputType = llvm::cast<ShapedType>(input.getType());
    auto outputType = llvm::cast<ShapedType>(op.getOutput().getType());

    if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
      return failure();

    // If the output and input shapes are 1x1, then this is a no op.
    auto hasUnitWindow = [](ShapedType type) {
      ArrayRef<int64_t> shape = type.getShape();
      return shape[1] == 1 && shape[2] == 1;
    };
    if (!hasUnitWindow(outputType) || !hasUnitWindow(inputType))
      return failure();

    rewriter.replaceOp(op, input);
    return success();
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `tosa.max_pool2d`.
void MaxPool2dOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<MaxPool2dIsNoOp>(context);
}
|
|
|
|
/// Removes a `tosa.clamp` whose bounds cannot restrict any value of the
/// element type:
///  * float elements: min is -infinity and max is +infinity;
///  * unsigned integer elements: the bounds cover the unsigned value range of
///    the element bit width;
///  * other integer elements: the bounds cover the signed value range of the
///    element bit width.
/// The pattern only applies to ranked, statically shaped inputs.
struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getInput();

    // The dyn_cast yields null for unranked tensors, so the type must be
    // checked before anything is queried from it.
    auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
    if (!inputType || !inputType.hasStaticShape())
      return failure();

    Type inputElementType = inputType.getElementType();

    if (llvm::isa<FloatType>(inputElementType)) {
      // Unlike integer types, floating point types can represent infinity.
      auto minClamp = op.getMinFp();
      auto maxClamp = op.getMaxFp();
      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();

      if (isMin && isMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (inputElementType.isUnsignedInteger()) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();
      int64_t intMax =
          APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getZExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    if (llvm::isa<IntegerType>(inputElementType)) {
      int64_t minClamp = op.getMinInt();
      int64_t maxClamp = op.getMaxInt();

      int64_t intMin =
          APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
              .getSExtValue();

      if (minClamp <= intMin && maxClamp >= intMax) {
        rewriter.replaceOp(op, input);
        return success();
      }
      return failure();
    }

    return failure();
  }
};
|
|
|
|
/// Merges `clamp(clamp(x))` into a single `tosa.clamp` over the intersection
/// of the two ranges.
///
/// The merge is only performed when the intersection is non-empty for the
/// bounds that match the result element type (floating-point bounds for float
/// elements, integer bounds otherwise). With disjoint ranges the intersected
/// bounds would have min > max, and a single clamp with such bounds does not
/// reproduce what the two chained clamps compute.
struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ClampOp op,
                                PatternRewriter &rewriter) const override {
    auto clampOp = op.getInput().getDefiningOp<tosa::ClampOp>();
    if (!clampOp)
      return failure();

    const APFloat minFp = std::max(op.getMinFp(), clampOp.getMinFp());
    const APFloat maxFp = std::min(op.getMaxFp(), clampOp.getMaxFp());

    // Force signed arithmetic: the comparison must treat negative bounds as
    // negative whatever integer type the attribute accessor returns.
    // NOTE(review): the accessor is believed to return an unsigned 64-bit
    // value for I64 attributes - confirm against the generated op class.
    const int64_t minInt =
        std::max(static_cast<int64_t>(op.getMinInt()),
                 static_cast<int64_t>(clampOp.getMinInt()));
    const int64_t maxInt =
        std::min(static_cast<int64_t>(op.getMaxInt()),
                 static_cast<int64_t>(clampOp.getMaxInt()));

    Type elementTy = llvm::cast<ShapedType>(op.getType()).getElementType();
    const bool emptyRange =
        llvm::isa<FloatType>(elementTy) ? (maxFp < minFp) : (maxInt < minInt);
    if (emptyRange)
      return rewriter.notifyMatchFailure(op, "clamp ranges do not overlap");

    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
        op, op.getType(), clampOp.getInput(),
        rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
        rewriter.getF32FloatAttr(minFp.convertToFloat()),
        rewriter.getF32FloatAttr(maxFp.convertToFloat()));
    return success();
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `tosa.clamp`.
void ClampOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ClampIsNoOp, ClampClampOptimization>(context);
}
|
|
|
|
/// Rewrites `slice(concat(a, b, ...))` into a slice of a single concat operand
/// when, along the concat axis, the sliced window lies entirely inside that
/// operand. Requires the concat result and every inspected operand to be
/// ranked with static shapes.
struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    auto concatOp = sliceOp.getInput().getDefiningOp<tosa::ConcatOp>();
    if (!concatOp)
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be concat operation");

    auto concatType = dyn_cast<RankedTensorType>(concatOp.getType());
    if (!concatType || !concatType.hasStaticShape())
      return rewriter.notifyMatchFailure(
          sliceOp, "slice input must be a static ranked tensor");

    const int32_t axis = concatOp.getAxis();

    // `start` is kept relative to the operand currently being inspected: each
    // operand that is skipped shifts the window left by its extent.
    llvm::SmallVector<int64_t> start(sliceOp.getStart());
    llvm::ArrayRef<int64_t> size = sliceOp.getSize();

    // Validate slice on the concatenated axis. Slicing along this
    // axis should span only one of the inputs to the concatenate
    // operation.
    for (Value operand : concatOp.getInput1()) {
      auto operandType = dyn_cast<RankedTensorType>(operand.getType());
      if (!operandType || !operandType.hasStaticShape())
        return rewriter.notifyMatchFailure(
            sliceOp, "concat input must be a static ranked tensor");

      const int64_t extent = operandType.getDimSize(axis);
      const bool fitsInOperand =
          start[axis] >= 0 && (start[axis] + size[axis]) <= extent;
      if (!fitsInOperand) {
        start[axis] -= extent;
        continue;
      }

      rewriter.replaceOpWithNewOp<tosa::SliceOp>(
          sliceOp, sliceOp.getType(), operand,
          rewriter.getDenseI64ArrayAttr(start),
          rewriter.getDenseI64ArrayAttr(size));
      return success();
    }

    return rewriter.notifyMatchFailure(
        sliceOp, "corresponding concat input not found for slice");
  }
};
|
|
|
|
/// Registers the canonicalization patterns for `tosa.slice`.
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<ConcatSliceOptimization>(context);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Operator Folders.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename IntFolder, typename FloatFolder>
|
|
DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
|
|
RankedTensorType returnTy) {
|
|
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
|
|
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
|
|
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
|
|
if (lETy != rETy)
|
|
return {};
|
|
|
|
if (llvm::isa<IntegerType>(lETy)) {
|
|
APInt l = lhs.getSplatValue<APInt>();
|
|
APInt r = rhs.getSplatValue<APInt>();
|
|
auto result = IntFolder()(l, r);
|
|
return DenseElementsAttr::get(returnTy, result);
|
|
}
|
|
|
|
if (llvm::isa<FloatType>(lETy)) {
|
|
APFloat l = lhs.getSplatValue<APFloat>();
|
|
APFloat r = rhs.getSplatValue<APFloat>();
|
|
auto result = FloatFolder()(l, r);
|
|
return DenseElementsAttr::get(returnTy, result);
|
|
}
|
|
}
|
|
|
|
return {};
|
|
}
|
|
|
|
/// Returns true when `val` is a non-null splat whose value is zero,
/// interpreted as a float or an integer according to `elemType`. Any other
/// element type yields false.
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
  const bool isFloat = llvm::isa<FloatType>(elemType);
  if (!isFloat && !llvm::isa<IntegerType>(elemType))
    return false;

  if (!val || !val.isSplat())
    return false;

  return isFloat ? val.getSplatValue<APFloat>().isZero()
                 : val.getSplatValue<APInt>().isZero();
}
|
|
|
|
/// Returns true when `val` is a non-null splat equal to the multiplicative
/// identity: exactly 1.0 for float element types, and `1 << shift` for
/// integer element types. Any other element type yields false.
static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) {
  if (llvm::isa<FloatType>(elemType)) {
    if (!val || !val.isSplat())
      return false;
    return val.getSplatValue<APFloat>().isExactlyValue(1.0);
  }

  if (llvm::isa<IntegerType>(elemType)) {
    const int64_t scaledOne = 1LL << shift;
    if (!val || !val.isSplat())
      return false;
    return val.getSplatValue<APInt>().getSExtValue() == scaledOne;
  }

  return false;
}
|
|
|
|
/// Folds `tosa.add`: `x + 0` and `0 + x` fold to `x` when `x` already has the
/// result type, and two splat constants fold to their sum.
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  Type resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Adding a zero splat is the identity on the other operand.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();
  if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr))
    return getInput2();

  if (lhsAttr && rhsAttr)
    return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                              resultTy);
  return {};
}
|
|
|
|
/// Folds integer `tosa.div` when both operands are ranked tensors of the same
/// type:
///  * a zero splat numerator folds to that constant;
///  * a splat divisor of one folds to the numerator;
///  * two splat constants fold to their signed quotient, unless the divisor
///    is zero, in which case the op is left for runtime.
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};
  if (lhsTy != rhsTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (lhsAttr && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        lhsAttr.getSplatValue<APInt>().isZero())
      return lhsAttr;
  }

  if (rhsAttr && rhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy) &&
        rhsAttr.getSplatValue<APInt>().isOne())
      return getInput1();
  }

  if (rhsAttr && lhsAttr && rhsAttr.isSplat() && lhsAttr.isSplat()) {
    if (llvm::isa<IntegerType>(resultETy)) {
      APInt l = lhsAttr.getSplatValue<APInt>();
      APInt r = rhsAttr.getSplatValue<APInt>();
      // Never evaluate a division by zero at compile time.
      if (r.isZero())
        return {};
      APInt result = l.sdiv(r);
      return DenseElementsAttr::get(resultTy, result);
    }
  }

  return {};
}
|
|
|
|
namespace {
|
|
DenseElementsAttr mulBinaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
|
|
RankedTensorType ty, int32_t shift) {
|
|
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
|
|
if (llvm::isa<IntegerType>(ty.getElementType())) {
|
|
APInt l = lhs.getSplatValue<APInt>();
|
|
APInt r = rhs.getSplatValue<APInt>();
|
|
|
|
if (shift == 0) {
|
|
return DenseElementsAttr::get(ty, l * r);
|
|
}
|
|
|
|
auto bitwidth = ty.getElementType().getIntOrFloatBitWidth();
|
|
l = l.sext(bitwidth * 2);
|
|
r = r.sext(bitwidth * 2);
|
|
auto result = l * r;
|
|
result.lshrInPlace(shift);
|
|
result = result.trunc(bitwidth);
|
|
return DenseElementsAttr::get(ty, result);
|
|
}
|
|
|
|
if (llvm::isa<FloatType>(ty.getElementType())) {
|
|
APFloat l = lhs.getSplatValue<APFloat>();
|
|
APFloat r = rhs.getSplatValue<APFloat>();
|
|
APFloat result = l * r;
|
|
return DenseElementsAttr::get(ty, result);
|
|
}
|
|
}
|
|
|
|
return {};
|
|
}
|
|
} // namespace
|
|
|
|
/// Folds `tosa.mul`: multiplying by a zero splat folds to a zero constant of
/// the result type, multiplying by the identity splat (1.0, or `1 << shift`
/// for integers) folds to the other operand, and two splat constants fold to
/// their product.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = getInput1();
  auto rhs = getInput2();

  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  Type resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // The shift only takes part in the identity check for integer results.
  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;

  // Constant on the left; the right operand already has the result type.
  if (rhsTy == resultTy) {
    if (isSplatZero(resultETy, lhsAttr))
      return lhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, lhsAttr, shift))
      return rhs;
  }

  // Constant on the right; the left operand already has the result type.
  if (lhsTy == resultTy) {
    if (isSplatZero(resultETy, rhsAttr))
      return rhsAttr.resizeSplat(resultTy);
    if (isSplatOne(resultETy, rhsAttr, shift))
      return lhs;
  }

  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
}
|
|
|
|
/// Folds `tosa.sub`: `x - 0` folds to `x` when `x` already has the result
/// type, and two splat constants fold to their difference.
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto rhsTy = llvm::dyn_cast<RankedTensorType>(getInput2().getType());
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  Type resultETy = resultTy.getElementType();
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  // Subtracting a zero splat is the identity on the left operand.
  if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr))
    return getInput1();

  if (lhsAttr && rhsAttr)
    return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(
        lhsAttr, rhsAttr, resultTy);
  return {};
}
|
|
|
|
namespace {
|
|
template <typename Cmp>
|
|
struct ComparisonFold {
|
|
ComparisonFold() = default;
|
|
APInt operator()(const APInt &l, const APInt &r) {
|
|
return APInt(1, Cmp()(l, r));
|
|
}
|
|
|
|
APInt operator()(const APFloat &l, const APFloat &r) {
|
|
return APInt(1, Cmp()(l, r));
|
|
}
|
|
};
|
|
|
|
struct APIntFoldGreater {
|
|
APIntFoldGreater() = default;
|
|
APInt operator()(const APInt &l, const APInt &r) {
|
|
return APInt(1, l.sgt(r));
|
|
}
|
|
};
|
|
|
|
struct APIntFoldGreaterEqual {
|
|
APIntFoldGreaterEqual() = default;
|
|
APInt operator()(const APInt &l, const APInt &r) {
|
|
return APInt(1, l.sge(r));
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Folds `tosa.greater` over two splat constants (signed comparison for
/// integers). Nothing is folded when the result type is not a ranked tensor.
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
  // The dyn_cast yields null for unranked results; no constant can be built
  // for those.
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!resultTy)
    return {};

  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
|
|
|
|
/// Folds `tosa.greater_equal` over two splat constants (signed comparison for
/// integers). Nothing is folded when the result type is not a ranked tensor.
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
  // The dyn_cast yields null for unranked results; no constant can be built
  // for those.
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!resultTy)
    return {};

  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());

  if (!lhsAttr || !rhsAttr)
    return {};

  return binaryFolder<APIntFoldGreaterEqual,
                      ComparisonFold<std::greater_equal<APFloat>>>(
      lhsAttr, rhsAttr, resultTy);
}
|
|
|
|
/// Folds `tosa.equal`: comparing an integer value with itself folds to a
/// `true` constant when the result type is static, and two splat constants
/// fold to the result of comparing them. Nothing is constant-folded when the
/// result type is not a ranked tensor.
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  auto lhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  auto rhsAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
  Value lhs = getInput1();
  Value rhs = getInput2();
  auto lhsTy = llvm::cast<ShapedType>(lhs.getType());

  // If we are comparing an integer value to itself it is always true. We can
  // not do this with float because NaN does not compare equal to itself.
  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
      resultTy.hasStaticShape() && lhs == rhs) {
    return DenseElementsAttr::get(resultTy, true);
  }

  // A null result type (unranked result) cannot carry a folded constant.
  if (!lhsAttr || !rhsAttr || !resultTy)
    return {};

  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
                                                              resultTy);
}
|
|
|
|
/// Folds `tosa.cast`: a cast to the same type folds to its input, and a cast
/// of a splat constant folds to a splat of the converted value for the
/// float/float, int/float, float/int and int/int element type combinations.
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  if (getInput().getType() == getType())
    return getInput();

  auto operand = llvm::dyn_cast_if_present<ElementsAttr>(adaptor.getInput());
  if (!operand)
    return {};

  auto inTy = llvm::cast<ShapedType>(getInput().getType());
  auto outTy = llvm::cast<ShapedType>(getType());
  Type inETy = inTy.getElementType();
  Type outETy = outTy.getElementType();

  // Only splat constants are folded.
  if (!operand.isSplat())
    return {};

  const bool fromFloat = llvm::isa<FloatType>(inETy);
  const bool fromInt = llvm::isa<IntegerType>(inETy);
  const bool toFloat = llvm::isa<FloatType>(outETy);
  const bool toInt = llvm::isa<IntegerType>(outETy);

  if (fromFloat && toFloat) {
    bool losesInfo;
    APFloat converted = operand.getSplatValue<APFloat>();
    auto &semantics = llvm::cast<FloatType>(outETy).getFloatSemantics();
    converted.convert(semantics, llvm::RoundingMode::NearestTiesToEven,
                      &losesInfo);
    return SplatElementsAttr::get(outTy, converted);
  }

  if (fromInt && toFloat) {
    const bool isUnsigned = llvm::cast<IntegerType>(inETy).isUnsignedInteger();
    APFloat converted(llvm::cast<FloatType>(outETy).getFloatSemantics());
    converted.convertFromAPInt(operand.getSplatValue<APInt>(), !isUnsigned,
                               llvm::RoundingMode::NearestTiesToEven);
    return SplatElementsAttr::get(outTy, converted);
  }

  if (fromFloat && toInt) {
    const bool isUnsigned =
        llvm::cast<IntegerType>(outETy).isUnsignedInteger();
    auto converted = APSInt(
        llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), isUnsigned);
    APFloat source = operand.getSplatValue<APFloat>();
    bool isExact;
    source.convertToInteger(converted, llvm::RoundingMode::TowardZero,
                            &isExact);
    return SplatElementsAttr::get(outTy, converted);
  }

  if (fromInt && toInt) {
    const bool isUnsignedSource =
        llvm::cast<IntegerType>(inETy).isUnsignedInteger();
    const auto sourceWidth = inETy.getIntOrFloatBitWidth();
    const auto targetWidth = outETy.getIntOrFloatBitWidth();
    APInt converted = operand.getSplatValue<APInt>();

    // Narrowing truncates; otherwise extend according to source signedness.
    if (sourceWidth > targetWidth)
      converted = converted.trunc(targetWidth);
    else if (isUnsignedSource)
      converted = converted.zext(targetWidth);
    else
      converted = converted.sext(targetWidth);

    return SplatElementsAttr::get(outTy, converted);
  }

  return {};
}
|
|
|
|
/// Folds `tosa.const` to the attribute holding its value.
OpFoldResult ConstOp::fold(FoldAdaptor adaptor) {
  return getValueAttr();
}
|
|
|
|
#define REDUCE_FOLDER(OP) \
|
|
OpFoldResult OP::fold(FoldAdaptor adaptor) { \
|
|
ShapedType inputTy = llvm::cast<ShapedType>(getInput().getType()); \
|
|
if (!inputTy.hasRank()) \
|
|
return {}; \
|
|
if (inputTy != getType()) \
|
|
return {}; \
|
|
if (inputTy.getRank() == 0 || inputTy.getDimSize(getAxis()) == 1) \
|
|
return getInput(); \
|
|
return {}; \
|
|
}
|
|
|
|
REDUCE_FOLDER(ReduceAllOp)
|
|
REDUCE_FOLDER(ReduceAnyOp)
|
|
REDUCE_FOLDER(ReduceMaxOp)
|
|
REDUCE_FOLDER(ReduceMinOp)
|
|
REDUCE_FOLDER(ReduceProdOp)
|
|
REDUCE_FOLDER(ReduceSumOp)
|
|
#undef REDUCE_FOLDER
|
|
|
|
/// Folds `tosa.reshape`: a reshape to the same type folds to its input, a
/// reshape of a reshape is updated in place to read the inner input, and a
/// reshape of a constant folds to a reshaped constant.
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  auto sourceTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  auto targetTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !targetTy)
    return {};

  if (sourceTy == targetTy)
    return getInput1();

  // reshape(reshape(x)) -> reshape(x)
  if (auto producer = getInput1().getDefiningOp<tosa::ReshapeOp>()) {
    getInput1Mutable().assign(producer.getInput1());
    return getResult();
  }

  // reshape(const(x)) -> const(reshape-attr(x))
  auto constant =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  if (!constant)
    return {};

  // Constants must have static shape.
  if (!targetTy.hasStaticShape())
    return {};

  // Okay to duplicate splat constants.
  if (constant.isSplat())
    return SplatElementsAttr::get(targetTy,
                                  constant.getSplatValue<Attribute>());

  // Don't duplicate other constants.
  if (!getInput1().hasOneUse())
    return {};

  auto reshapedTy =
      llvm::cast<ShapedType>(constant.getType()).clone(getNewShape());
  return constant.reshape(reshapedTy);
}
|
|
|
|
/// Folds `tosa.pad` to its input when the padding operand is a constant splat
/// of zeros and the input already has the result type.
OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
  // A fold may only return a value whose type equals the result type.
  if (!adaptor.getPadding() || getInput1().getType() != getType())
    return {};

  // The constant padding is not guaranteed to be a DenseElementsAttr, so use
  // a checked cast rather than asserting.
  auto densePad = llvm::dyn_cast<DenseElementsAttr>(adaptor.getPadding());
  if (!densePad || !densePad.isSplat())
    return {};

  // If the pad is all zeros we can fold this operation away.
  if (densePad.getSplatValue<APInt>().isZero())
    return getInput1();

  return {};
}
|
|
|
|
// Fold away cases where a tosa.resize operation returns a copy
|
|
// of the input image.
|
|
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
|
|
ArrayRef<int64_t> offset = getOffset();
|
|
ArrayRef<int64_t> border = getBorder();
|
|
ArrayRef<int64_t> scale = getScale();
|
|
|
|
// Check unit scaling.
|
|
if (scale[0] != scale[1] || scale[2] != scale[3]) {
|
|
return {};
|
|
}
|
|
|
|
// There should be no offset.
|
|
if (offset[0] != 0 || offset[1] != 0) {
|
|
return {};
|
|
}
|
|
|
|
// There should be no border.
|
|
if (border[0] != 0 || border[1] != 0) {
|
|
return {};
|
|
}
|
|
|
|
auto input = getInput();
|
|
auto inputTy = llvm::cast<RankedTensorType>(input.getType());
|
|
auto resultTy = llvm::cast<RankedTensorType>(getType());
|
|
if (inputTy != resultTy)
|
|
return {};
|
|
|
|
return input;
|
|
}
|
|
|
|
/// Folds `tosa.reverse`: a splat constant input folds to that constant, and a
/// ranked input that is rank-0 or has extent 1 along the reversed axis folds
/// to the input itself.
OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
  auto operand = getInput();
  auto operandTy = llvm::cast<ShapedType>(operand.getType());
  auto axis = getAxis();

  // Reversing a splat leaves it unchanged.
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput()))
    return splat;

  if (!operandTy.hasRank())
    return {};

  // If the dim-length is 1, tosa.reverse is a no-op.
  if (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1)
    return operand;

  return {};
}
|
|
|
|
/// Folds `tosa.slice`: a slice with identical static input and output types
/// folds to its input, a slice of a splat constant folds to a splat of the
/// output type, and a single-element slice of a constant folds to a splat of
/// the element found at the start indices.
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
  auto sourceTy = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy)
    return {};

  if (sourceTy == resultTy && sourceTy.hasStaticShape())
    return getInput();

  if (!adaptor.getInput())
    return {};

  // Cannot create an ElementsAttr from non-int/float/index types
  const bool foldableElements =
      sourceTy.getElementType().isIntOrIndexOrFloat() &&
      resultTy.getElementType().isIntOrIndexOrFloat();
  if (!foldableElements)
    return {};

  auto constant = llvm::cast<ElementsAttr>(adaptor.getInput());
  if (constant.isSplat() && resultTy.hasStaticShape())
    return SplatElementsAttr::get(resultTy,
                                  constant.getSplatValue<Attribute>());

  const bool singleElement = sourceTy.hasStaticShape() &&
                             resultTy.hasStaticShape() &&
                             resultTy.getNumElements() == 1;
  if (singleElement) {
    // The start offsets address the one element that is kept.
    llvm::SmallVector<uint64_t> position(getStart());
    auto element = constant.getValues<Attribute>()[position];
    return SplatElementsAttr::get(resultTy, element);
  }

  return {};
}
|
|
|
|
/// Folds `tosa.select`: identical branches fold to that value, and a splat
/// constant predicate folds to the branch it selects.
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
  if (getOnTrue() == getOnFalse())
    return getOnTrue();

  auto predicate =
      llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getPred());
  if (!predicate || !predicate.isSplat())
    return {};

  if (predicate.getSplatValue<APInt>().getBoolValue())
    return getOnTrue();
  return getOnFalse();
}
|
|
|
|
/// Folds `tosa.tile` to its input when every multiple is 1 and the input type
/// equals the result type.
OpFoldResult TileOp::fold(FoldAdaptor adaptor) {
  const bool isIdentityTiling = llvm::all_of(
      getMultiples(), [](int64_t multiple) { return multiple == 1; });
  if (!isIdentityTiling || getInput1().getType() != getType())
    return {};
  return getInput1();
}
|
|
|
|
/// Folds `tosa.transpose`: a splat constant input folds to the same splat
/// reshaped to the result type, and a transpose with a constant identity
/// permutation that does not change the type folds to its input.
OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
  auto inputTy = llvm::cast<ShapedType>(getInput1().getType());
  auto resultTy = llvm::cast<ShapedType>(getType());

  // Transposing splat values just means reshaping.
  auto constant =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
  if (constant && constant.isSplat() && resultTy.hasStaticShape() &&
      inputTy.getElementType() == resultTy.getElementType())
    return constant.reshape(resultTy);

  // Only a transpose that keeps the input type can fold to its input.
  if (getInput1().getType() != getType())
    return {};

  SmallVector<int64_t> perms;
  if (failed(getConstantPerms(perms)))
    return {};

  // Anything other than the identity permutation is a real transpose.
  for (int64_t idx = 0, end = perms.size(); idx < end; ++idx) {
    if (perms[idx] != idx)
      return {};
  }

  return getInput1();
}
|
|
|
|
/// Folds `log(exp(x))` to `x`.
OpFoldResult tosa::LogOp::fold(FoldAdaptor adaptor) {
  // Element-wise log(exp(x)) = x
  auto expOp = getInput1().getDefiningOp<tosa::ExpOp>();
  if (!expOp)
    return {};
  return expOp.getInput1();
}
|
|
|
|
/// Folds `exp(log(x))` to `x`.
OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
  // Element-wise exp(log(x)) = x
  auto logOp = getInput1().getDefiningOp<tosa::LogOp>();
  if (!logOp)
    return {};
  return logOp.getInput1();
}
|
|
|
|
/// Folds `negate(negate(x))` to `x`.
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
  // Element-wise negate(negate(x)) = x
  auto innerNegate = getInput1().getDefiningOp<tosa::NegateOp>();
  if (!innerNegate)
    return {};
  return innerNegate.getInput1();
}
|
|
|
|
/// Folds `abs(abs(x))` to the inner `abs(x)`.
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
  auto input = getInput1();
  // Element-wise abs(abs(x)) = abs(x)
  if (!input.getDefiningOp<tosa::AbsOp>())
    return {};
  return input;
}
|
|
|
|
/// Folds consecutive concats on the same axis into a single op by replacing,
/// in place, every operand produced by a same-axis `tosa.concat` with that
/// producer's operands. Returns nothing when no operand qualifies.
OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
  // Collect the flattened operand list. Conservatively assume that we double
  // the number of operands when folding.
  SmallVector<Value, 8> flattened;
  flattened.reserve(2 * getNumOperands());

  bool inlinedAny = false;
  for (Value operand : getOperands()) {
    auto producer = operand.getDefiningOp<ConcatOp>();

    // Operands that are not concats, or concats along a different axis, are
    // kept as they are.
    if (!producer || getAxis() != producer.getAxis()) {
      flattened.emplace_back(operand);
      continue;
    }

    // Replace the original operand with all incoming operands
    inlinedAny = true;
    llvm::append_range(flattened, producer->getOperands());
  }

  if (!inlinedAny)
    return {};

  getOperation()->setOperands(flattened);
  return getResult();
}
|
|
|
|
/// Folds `tosa.reciprocal` of a float splat constant into a constant of the
/// result type holding the value computed by `calcOneElement`.
OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
  auto inputAttr =
      llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());

  // Fold splat inputs only.
  if (!inputAttr || !inputAttr.isSplat())
    return {};

  auto shapeType = llvm::cast<ShapedType>(getType());
  if (!llvm::isa<FloatType>(inputAttr.getElementType()))
    return {};

  auto floatVal = inputAttr.getSplatValue<APFloat>();
  return DenseElementsAttr::get(shapeType,
                                ReciprocalOp::calcOneElement(floatVal));
}
|