//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor
/// to be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p TargetType.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  SmallVector<TargetValType> transformedValues;
  // We already know the number of values we will insert; reserve space for
  // all of them to avoid dynamic resizing.
  transformedValues.reserve(toTransform.getNumElements());
  for (auto val : toTransform.getValues<SrcValType>()) {
    auto transformedVal = toApply(val);
    transformedValues.push_back(transformedVal);
  }

  // Make sure that the output tensor has the expected output type.
  auto inType = toTransform.getType();
  auto outTy = inType.cloneWith({}, targetType);

  return DenseElementsAttr::get(outTy, transformedValues);
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
    const DenseElementsAttr &toTransform,
    const std::function<APFloat(const APFloat &)> &toApply,
    FloatType targetType);

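// Example use (mirroring TosaFoldConstantReciprocal below): folding the
// reciprocal of every element of a constant float tensor:
//   auto folded = applyElementWise<APFloat, APFloat, FloatType>(
//       inputValues, &ReciprocalOp::calcOneElement,
//       cast<FloatType>(inputValues.getElementType()));
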
/// Function that checks whether the element type of \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  if (isa<FloatType>(toCheck.getType().getElementType())) {
    return success();
  }
  return rewriter.notifyMatchFailure(location,
                                     "Unexpected input tensor type: the "
                                     "TOSA spec only allows floats");
}

/// Function that checks if \p toCheck is a dense TOSA constant tensor.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // Check whether the tensor is constant and dense.
  // TODO: We currently ensure the tensor is dense by using the correct type
  // for the bind_value, however we do not actually need this value. It would
  // be nicer to only have a check here.
  DenseElementsAttr tmp;
  if (!matchPattern(toCheck, m_Constant(&tmp))) {
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");
  }

  // Make sure it actually is a TOSA constant (the match allows for other
  // constants as well).
  if (isa<ConstOp>(toCheck.getDefiningOp())) {
    return success();
  }

  return rewriter.notifyMatchFailure(location,
                                     "Operations can only be folded if they "
                                     "operate on a TOSA constant");
}

/// Function that checks if \p toCheck is a dense TOSA constant float tensor.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter);
  if (failed(floatCheck)) {
    return floatCheck;
  }
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}

/// Heuristic to decide when to replace a unary operation on a constant with
/// the folded value.
/// Folding operations on constants can increase memory usage whenever the
/// input cannot be replaced but a new constant is inserted. Hence, this
/// currently only suggests folding when the memory impact is negligible.
/// Takes the \p unaryOp and the constant input \p values.
/// \returns Whether folding should be applied.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  auto inputOp = unaryOp->getOperand(0);

  // If the input is a splat, we don't care about the number of users: the
  // folded result is again a splat, so its memory impact is negligible.
  if (isa<SplatElementsAttr>(values)) {
    return true;
  }

  // If this is the only use of the tensor it should be replaced, as no
  // additional memory is required.
  return inputOp.hasOneUse();
}

template <typename BaseType>
DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});

  auto attrValues = attr.getValues<BaseType>();
  auto inputShape = inputType.getShape();

  // The inverted permutation map and strides of the output are used to
  // compute the contribution of a given dimension to the destination linear
  // index in an order-independent way.
  auto outputStrides = computeStrides(outputType.getShape());
  auto invertedPermValues = invertPermutationVector(permValues);

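  // For example, transposing a 2x3 tensor with permValues = [1, 0] produces
  // a 3x2 output with outputStrides = [2, 1] and invertedPermValues = [1, 0];
  // the source element at position (i, j) then lands at destination linear
  // index 2 * j + 1 * i, i.e. at output position (j, i).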
  auto initialValue = *std::begin(attrValues);
  SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);

  for (const auto &it : llvm::enumerate(attrValues)) {
    auto srcLinearIndex = it.index();

    uint64_t dstLinearIndex = 0;
    for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
      // Compute the index into the current dimension of the source vector.
      auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
      srcLinearIndex /= inputShape[dim];

      // Add the contribution of the current dimension to the output using
      // the permutation map.
      dstLinearIndex +=
          outputStrides[invertedPermValues[dim]] * sourceIndexForDim;
    }

    outputValues[dstLinearIndex] = it.value();
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<BaseType>(outputValues));
}

// A type-specialized transposition of an ElementsAttr.
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of
// Attribute objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  auto baseType = inputType.getElementType();

  // Handle possible integer types.
  if (auto intType = dyn_cast<IntegerType>(baseType)) {
    switch (intType.getWidth()) {
    case 1:
      return transposeType<bool>(attr, inputType, outputType, permValues);
    case 8:
      return transposeType<int8_t>(attr, inputType, outputType, permValues);
    case 16:
      return transposeType<int16_t>(attr, inputType, outputType, permValues);
    case 32:
      return transposeType<int32_t>(attr, inputType, outputType, permValues);
    case 64:
      return transposeType<int64_t>(attr, inputType, outputType, permValues);
    default:
      return transposeType<APInt>(attr, inputType, outputType, permValues);
    }
  }

  // Handle possible float types; f32 can use its raw native representation,
  // all other float types fall back to APFloat.
  if (baseType.isF32()) {
    return transposeType<float>(attr, inputType, outputType, permValues);
  }

  return transposeType<APFloat>(attr, inputType, outputType, permValues);
}

struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto outputType = cast<ShapedType>(op.getType());
    // TOSA also supports quantized element types, which are not handled by
    // this folder.
    if (!outputType.getElementType().isIntOrIndexOrFloat())
      return failure();

    ElementsAttr inputValues;
    if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
      return failure();
    // Make sure the input is a constant that has a single user.
    if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();
    auto permValues = llvm::to_vector<6>(llvm::map_range(
        // TOSA allows both 32- and 64-bit integer tensors here.
        permAttr.getValues<APInt>(),
        [](const APInt &val) { return val.getSExtValue(); }));

    auto inputType = cast<ShapedType>(op.getInput1().getType());

    auto resultAttr = transpose(inputValues, inputType, outputType, permValues);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

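// Illustrative example of what this pattern does: given a constant input and
// a constant permutation
//   %perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>}
//   %in = "tosa.const"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>}
//   %out = tosa.transpose %in, %perms : (tensor<2x3xi32>, tensor<2xi32>) -> tensor<3x2xi32>
// it replaces %out with a single constant
//   dense<[[1, 4], [2, 5], [3, 6]]> : tensor<3x2xi32>
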
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {

  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto inputTensor = recip.getInput1();

    // Check that we can apply folding.
    auto preCondCheck =
        notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter);
    if (failed(preCondCheck)) {
      return preCondCheck;
    }

    // Extract the tensor values.
    DenseElementsAttr inputValues;
    matchPattern(inputTensor, m_Constant(&inputValues));

    // Check whether this should be folded.
    if (!constantUnaryOpShouldBeFolded(recip, inputValues)) {
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");
    }

    // Create a new tensor with the updated values.
    auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement,
        cast<FloatType>(inputValues.getElementType()));

    // Replace the use of the reciprocal with the transformed tensor.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
    return success();
  }
};

/// Compute the per-axis position of the element located at linear index
/// \p index in a tensor of shape \p tensorShape.
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  int64_t remaining = index;
  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
    position[i] = remaining % tensorShape[i];
    remaining /= tensorShape[i];
  }
  return position;
}

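// For example, with tensorShape = [2, 3], the linear index 5 maps to the
// position [1, 2] (row 1, column 2).
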
/// Compute the linear index of the element located at the per-axis position
/// \p position in a tensor of shape \p tensorShape.
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t index = 0;
  int64_t multiplierTmp = 1;
  for (int64_t i = position.size() - 1; i >= 0; --i) {
    index += position[i] * multiplierTmp;
    multiplierTmp *= tensorShape[i];
  }
  return index;
}

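// This is the inverse of getPositionFromIndex: with tensorShape = [2, 3],
// the position [1, 2] maps back to the linear index 1 * 3 + 2 = 5.
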
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {

  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  // Calculate the position corresponding to the reduction index.
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  // Start from the first position along the reduction axis.
  position[reductionAxis] = 0;
  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];

  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {

    int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
                                     oldShape.end(), static_cast<int64_t>(1),
                                     std::multiplies<int64_t>());
    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}

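// For example, reducing a tensor of shape [2, 3] along axis 0 gives stride 3,
// so for each output index r in [0, 3) the elements at linear indices r and
// r + 3 of the old tensor are combined via OperationType::calcOneElement.
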
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {

  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();

    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be a const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());

    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    auto reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input is currently only supported for integer types");

    auto oldShape = shapedOldElementsValues.getShape();
    auto newShape = resultType.getShape();

    auto newNumOfElements =
        std::accumulate(newShape.begin(), newShape.end(),
                        static_cast<int64_t>(1), std::multiplies<int64_t>());
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {

      // Reduce all the elements along this reduction axis.
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }

  const bool aggressiveReduceConstant;
};

} // namespace

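// A minimal usage sketch (illustrative, not part of this file): the populate
// functions below are intended to be called while assembling a pattern set,
// e.g.
//   RewritePatternSet patterns(ctx);
//   mlir::tosa::populateTosaConstantReduction(
//       ctx, patterns, /*aggressiveReduceConstant=*/false);
// after which the patterns can be applied with a greedy pattern driver.
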
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(
      ctx, aggressiveReduceConstant);
  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}

void mlir::tosa::populateTosaFoldConstantTransposePatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantTranspose>(ctx);
}

void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TosaFoldConstantReciprocal>(ctx);
}