bolt/deps/llvm-18.1.8/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
2025-02-14 19:21:04 +01:00

426 lines
16 KiB
C++

//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//
#include <functional>
#include <numeric>
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"
using namespace mlir;
using namespace mlir::tosa;
namespace {
/// Maps every element of the constant tensor \p toTransform through
/// \p toApply and packs the results into a new constant.
///
/// Elements are read from \p toTransform as \p SrcValType; each call to
/// \p toApply yields one \p TargetValType.
///
/// \returns A DenseElementsAttr holding one result per input element. Its type
/// is the type of \p toTransform cloned with \p targetType as element type.
template <class SrcValType, class TargetValType, class TargetType>
DenseElementsAttr applyElementWise(
    const DenseElementsAttr &toTransform,
    const std::function<TargetValType(const SrcValType &)> &toApply,
    TargetType targetType) {
  // Exactly one result is produced per input element, so the storage can be
  // sized once up front instead of growing while we append.
  SmallVector<TargetValType> results;
  results.reserve(toTransform.getNumElements());

  for (const auto &element : toTransform.getValues<SrcValType>())
    results.push_back(toApply(element));

  // Keep everything about the input type except the element type, which is
  // replaced by the requested target type.
  auto resultType = toTransform.getType().cloneWith({}, targetType);
  return DenseElementsAttr::get(resultType, results);
}
/// Explicit instantiation of applyElementWise for the case where both the
/// source and the produced element values are APFloat and the target element
/// type is a FloatType.
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
const DenseElementsAttr &toTransform,
const std::function<APFloat(const APFloat &)> &toApply,
FloatType targetType);
/// Verifies that the element type of \p toCheck is a float type.
///
/// \returns success() when it is; otherwise the result of reporting a match
/// failure for \p location through \p rewriter.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
                               PatternRewriter &rewriter) {
  const bool hasFloatElements =
      isa<FloatType>(toCheck.getType().getElementType());
  if (!hasFloatElements)
    return rewriter.notifyMatchFailure(location,
                                       "Unexpected input tensor type: the "
                                       "TOSA spec only allows floats");
  return success();
}
/// Verifies that \p toCheck is produced by a tosa::ConstOp holding a dense
/// constant.
///
/// \returns success() when it is; otherwise the result of reporting a match
/// failure for \p location through \p rewriter.
LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue<TensorType> toCheck,
                                                TosaOp location,
                                                PatternRewriter &rewriter) {
  // TODO The bound value is only used to restrict the match to dense
  // attributes; the value itself is never read. A pure check (without binding)
  // would be nicer here.
  DenseElementsAttr boundValue;
  const bool isDenseConstant = matchPattern(toCheck, m_Constant(&boundValue));
  if (!isDenseConstant)
    return rewriter.notifyMatchFailure(location,
                                       "Non-const or non-dense input tensor");

  // m_Constant also accepts constants that are not TOSA constants, so the
  // defining operation has to be inspected explicitly.
  if (!isa<ConstOp>(toCheck.getDefiningOp()))
    return rewriter.notifyMatchFailure(location,
                                       "The reciprocal can only be folded if "
                                       "it operates on a TOSA constant");
  return success();
}
/// Verifies that \p toCheck is a dense TOSA constant with a float element
/// type.
///
/// The float check runs first; the constant check is only performed when the
/// float check succeeded. Failures are reported for \p location through
/// \p rewriter by the respective helper.
LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue<TensorType> toCheck,
                                                 TosaOp location,
                                                 PatternRewriter &rewriter) {
  if (failed(notifyIfNotFloat(toCheck, location, rewriter)))
    return failure();
  return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter);
}
/// Heuristic deciding whether a unary operation applied to a constant should
/// be replaced by its folded value.
///
/// Folding inserts a new constant. If the original constant has to stay alive
/// for other users, that increases memory usage, so folding is only suggested
/// when the memory impact is negligible.
///
/// \param unaryOp The operation to fold; must have exactly one operand.
/// \param values  The constant values of that operand.
/// \returns true if \p values is a splat, or if the operand has a single use.
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) {
  assert(unaryOp->getNumOperands() == 1);
  Value operand = unaryOp->getOperand(0);
  // For a splat the number of users is irrelevant. Otherwise the fold is only
  // worthwhile when this is the sole use, because then the old constant can be
  // dropped and no additional memory is needed.
  return isa<SplatElementsAttr>(values) || operand.hasOneUse();
}
/// Transposes the values of \p attr, read as \p BaseType, into a new dense
/// constant of type \p outputType.
///
/// \param attr       Source values, iterated in linear order.
/// \param inputType  Shaped type describing the layout of \p attr.
/// \param outputType Shaped type of the produced constant.
/// \param permValues Permutation; its inverse selects the output stride that
///                   each input dimension contributes to.
/// \returns The transposed constant; an empty constant of \p outputType when
/// \p inputType has no elements.
template <typename BaseType>
DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  const int64_t numElements = inputType.getNumElements();
  if (numElements == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});

  auto sourceValues = attr.getValues<BaseType>();
  auto srcShape = inputType.getShape();

  // With the output strides and the inverted permutation, the contribution of
  // every source dimension to the destination linear index can be computed
  // independently of the order in which dimensions are visited.
  auto dstStrides = computeStrides(outputType.getShape());
  auto inversePerm = invertPermutationVector(permValues);

  // Every slot is overwritten below; the first source value merely serves as
  // a fill value to construct the buffer.
  SmallVector<BaseType> transposed(numElements, *sourceValues.begin());

  size_t srcLinearIndex = 0;
  for (auto value : sourceValues) {
    auto remaining = srcLinearIndex;
    uint64_t dstLinearIndex = 0;
    // Peel the source index apart from the innermost dimension outwards.
    for (size_t dim = srcShape.size(); dim-- > 0;) {
      auto indexInDim = remaining % srcShape[dim];
      remaining /= srcShape[dim];
      dstLinearIndex += dstStrides[inversePerm[dim]] * indexInDim;
    }
    transposed[dstLinearIndex] = value;
    ++srcLinearIndex;
  }

  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<BaseType>(transposed));
}
// Transposes an ElementsAttr by dispatching on the element type of
// \p inputType to a transposeType instantiation.
// Where possible a raw C++ value type is used so that the data can be handled
// in its underlying representation instead of allocating a large number of
// Attribute objects. Element types without such a mapping go through APInt
// (integers of other widths) or APFloat (everything that is not an integer
// and not f32).
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
                            ShapedType outputType,
                            llvm::ArrayRef<int64_t> permValues) {
  auto elementType = inputType.getElementType();

  auto intType = dyn_cast<IntegerType>(elementType);
  if (!intType) {
    // Not an integer: treat as floating point.
    if (elementType.isF32())
      return transposeType<float>(attr, inputType, outputType, permValues);
    return transposeType<APFloat>(attr, inputType, outputType, permValues);
  }

  // Integer: pick the value type that matches the bit width.
  const auto bitWidth = intType.getWidth();
  if (bitWidth == 1)
    return transposeType<bool>(attr, inputType, outputType, permValues);
  if (bitWidth == 8)
    return transposeType<int8_t>(attr, inputType, outputType, permValues);
  if (bitWidth == 16)
    return transposeType<int16_t>(attr, inputType, outputType, permValues);
  if (bitWidth == 32)
    return transposeType<int32_t>(attr, inputType, outputType, permValues);
  if (bitWidth == 64)
    return transposeType<int64_t>(attr, inputType, outputType, permValues);
  return transposeType<APInt>(attr, inputType, outputType, permValues);
}
/// Rewrite pattern that replaces a tosa.transpose of a constant by a
/// tosa.const holding the already transposed values.
///
/// The pattern only applies when the result element type is int, index or
/// float, the input is a constant whose defining op has a single user, and the
/// permutation operand is a constant integer tensor.
struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = cast<ShapedType>(op.getType());
    // TOSA supports quantized types; those are not folded here.
    if (!resultType.getElementType().isIntOrIndexOrFloat())
      return failure();

    auto input = op.getInput1();
    ElementsAttr inputValues;
    if (!matchPattern(input, m_Constant(&inputValues)))
      return failure();

    // Only fold when this transpose is the sole user of the constant.
    if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()))
      return failure();

    DenseIntElementsAttr permAttr;
    if (!matchPattern(op.getPerms(), m_Constant(&permAttr)))
      return failure();

    // TOSA allows both 32- and 64-bit integer tensors for the permutation, so
    // read the entries as APInt and sign-extend them to int64_t.
    SmallVector<int64_t, 6> permValues;
    for (const APInt &perm : permAttr.getValues<APInt>())
      permValues.push_back(perm.getSExtValue());

    auto inputType = cast<ShapedType>(input.getType());
    auto transposed = transpose(inputValues, inputType, resultType, permValues);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, transposed);
    return success();
  }
};
/// Rewrite pattern that replaces a tosa.reciprocal of a constant float tensor
/// by a tosa.const holding the element-wise reciprocal values.
///
/// The fold is only performed when constantUnaryOpShouldBeFolded agrees, i.e.
/// when the input is a splat or has a single use.
struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ReciprocalOp recip,
                                PatternRewriter &rewriter) const override {
    auto input = recip.getInput1();

    // The input has to be a dense TOSA constant with float elements; the
    // helper reports the reason when it is not.
    if (failed(notifyIfNotConstantFloatTosaTensor(input, recip, rewriter)))
      return failure();

    // The check above already established that this match succeeds; it is
    // repeated only to get hold of the values.
    DenseElementsAttr inputValues;
    matchPattern(input, m_Constant(&inputValues));

    if (!constantUnaryOpShouldBeFolded(recip, inputValues))
      return rewriter.notifyMatchFailure(
          recip, "Currently, reciprocals will only be folded if the input "
                 "tensor has a single user");

    // Compute the reciprocal of every element, keeping the element type.
    auto elementType = cast<FloatType>(inputValues.getElementType());
    auto folded = applyElementWise<APFloat, APFloat, FloatType>(
        inputValues, &ReciprocalOp::calcOneElement, elementType);

    // Swap the reciprocal for a constant carrying the folded values.
    rewriter.replaceOpWithNewOp<ConstOp>(recip, folded.getType(), folded);
    return success();
  }
};
/// Converts the linear \p index of an element into its per-axis position in a
/// tensor of shape \p tensorShape, with the last axis varying fastest.
///
/// \returns One coordinate per axis of \p tensorShape.
llvm::SmallVector<int64_t>
getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
  llvm::SmallVector<int64_t> coords(tensorShape.size(), 0);
  // Walk from the innermost axis outwards, splitting off one coordinate per
  // step.
  for (size_t axis = tensorShape.size(); axis-- > 0;) {
    const int64_t extent = tensorShape[axis];
    coords[axis] = index % extent;
    index /= extent;
  }
  return coords;
}
/// Converts the per-axis \p position of an element into its linear index in a
/// tensor of shape \p tensorShape, with the last axis varying fastest.
///
/// \returns The linear index of the element.
int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
                             llvm::ArrayRef<int64_t> tensorShape) {
  int64_t linearIndex = 0;
  // Number of elements spanned by one step along the current axis.
  int64_t stride = 1;
  for (size_t axis = position.size(); axis-- > 0;) {
    linearIndex += position[axis] * stride;
    stride *= tensorShape[axis];
  }
  return linearIndex;
}
/// Computes a single element of the result of reducing a constant tensor
/// along one axis.
///
/// \tparam OperationType Reduce operation; its static
///         calcOneElement(APInt, APInt) combines two values.
/// \param oldTensorAttr  Values of the tensor being reduced, read as APInt.
/// \param oldShape       Shape of the tensor being reduced.
/// \param reductionAxis  Axis that is reduced; must be a valid axis of
///                       \p oldShape with a non-zero extent.
/// \param reductionIndex Linear index of the requested element in the reduced
///                       tensor (the shape of \p oldShape with
///                       \p reductionAxis set to 1).
/// \returns The combination of all elements along \p reductionAxis that map to
/// \p reductionIndex.
template <typename OperationType>
llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
                                  llvm::ArrayRef<int64_t> oldShape,
                                  int64_t reductionAxis,
                                  int64_t reductionIndex) {
  llvm::SmallVector<int64_t> newShape(oldShape);
  newShape[reductionAxis] = 1;
  /// Let's calculate the position of the index
  llvm::SmallVector<int64_t> position =
      getPositionFromIndex(reductionIndex, newShape);
  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
  /// Starting from the first position along the reduction axis
  position[reductionAxis] = 0;
  const int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);

  // Distance in the linear index between two neighbours along the reduction
  // axis: the product of all extents that follow it. This does not depend on
  // the loop variable, so compute it once, and accumulate in 64 bits so that
  // large extents are not truncated to int.
  const int64_t stride =
      std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(),
                      int64_t{1}, std::multiplies<int64_t>());

  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
       ++reductionAxisVal) {
    const int64_t index = indexAtOldTensor + stride * reductionAxisVal;
    reducedValue =
        OperationType::calcOneElement(reducedValue, oldTensor[index]);
  }
  return reducedValue;
}
/// Rewrite pattern that folds a TOSA reduce operation applied to a constant
/// integer tensor into a tosa.const holding the reduced values.
///
/// \tparam OperationType The reduce operation to fold. It must provide
///         getInput(), getOutput(), getAxis() and a static calcOneElement for
///         APInt values.
///
/// The fold is skipped (with a match-failure note) when the input is not a
/// tosa.const, when the input has more than one use and aggressive folding is
/// disabled, when the result shape is not static, when the element type is not
/// an integer type, or when the reduction axis is out of range or empty.
template <typename OperationType>
struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
  /// \param context                  Context the pattern is created in.
  /// \param aggressiveReduceConstant When true, fold even if the constant
  ///        input has other users.
  ReduceConstantOptimization(MLIRContext *context,
                             bool aggressiveReduceConstant)
      : OpRewritePattern<OperationType>(context),
        aggressiveReduceConstant(aggressiveReduceConstant) {}

  using OpRewritePattern<OperationType>::OpRewritePattern;

  LogicalResult matchAndRewrite(OperationType op,
                                PatternRewriter &rewriter) const override {
    Value inputOp = op.getInput();
    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
    if (!constOp)
      return rewriter.notifyMatchFailure(
          op, "reduce input must be const operation");

    if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant)
      return rewriter.notifyMatchFailure(
          op, "input operation has more than one user");

    auto resultType = cast<ShapedType>(op.getOutput().getType());
    if (!resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "result type shape is not static");

    const int64_t reductionAxis = op.getAxis();
    const auto denseElementsAttr = constOp.getValue();
    const auto shapedOldElementsValues =
        cast<ShapedType>(denseElementsAttr.getType());

    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
      return rewriter.notifyMatchFailure(
          op, "reduce input currently supported with integer type");

    auto oldShape = shapedOldElementsValues.getShape();

    // The shape is indexed with the axis below, so reject axes outside of it.
    if (reductionAxis < 0 ||
        reductionAxis >= static_cast<int64_t>(oldShape.size()))
      return rewriter.notifyMatchFailure(
          op, "reduction axis is out of range for the input");

    // Each result element is seeded with the first input element along the
    // axis. An empty axis has no such element, so there is nothing to fold.
    if (oldShape[reductionAxis] == 0)
      return rewriter.notifyMatchFailure(
          op, "cannot fold a reduction over an empty axis");

    // The shape is static (checked above), so the type can report its element
    // count directly as a 64-bit value.
    const int64_t newNumOfElements = resultType.getNumElements();
    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);

    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
         ++reductionIndex) {
      /// Let's reduce all the elements along this reduction axis
      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
    }

    auto rankedTensorType = cast<RankedTensorType>(resultType);
    auto denseAttr =
        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
    return success();
  }

  /// Whether to fold even when the constant input has more than one use.
  const bool aggressiveReduceConstant;
};
} // namespace
/// Registers the constant-folding pattern for every supported TOSA reduce
/// operation (all, any, max, min, prod, sum) in \p patterns. Each pattern is
/// constructed with \p ctx and \p aggressiveReduceConstant.
void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
                                               RewritePatternSet &patterns,
                                               bool aggressiveReduceConstant) {
  // All patterns take the same constructor arguments, so they can be added
  // with a single call; they are inserted in the order listed.
  patterns.add<ReduceConstantOptimization<ReduceAllOp>,
               ReduceConstantOptimization<ReduceAnyOp>,
               ReduceConstantOptimization<ReduceMaxOp>,
               ReduceConstantOptimization<ReduceMinOp>,
               ReduceConstantOptimization<ReduceProdOp>,
               ReduceConstantOptimization<ReduceSumOp>>(
      ctx, aggressiveReduceConstant);
}
/// Adds the TosaFoldConstantTranspose pattern, constructed with \p ctx, to
/// \p patterns.
void mlir::tosa::populateTosaFoldConstantTransposePatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantTranspose>(ctx);
}
/// Adds the TosaFoldConstantReciprocal pattern, constructed with \p ctx, to
/// \p patterns.
void mlir::tosa::populateTosaFoldConstantReciprocalPatterns(
MLIRContext *ctx, RewritePatternSet &patterns) {
patterns.add<TosaFoldConstantReciprocal>(ctx);
}