//===- TosaFolders.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold TOSA operations
//
//===----------------------------------------------------------------------===//

#include <functional>
#include <numeric>

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/FloatingPointMode.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Apply the given transformation \p toApply to every element of the tensor to
/// be transformed \p toTransform.
///
/// Elements of \p toTransform are extracted as \p SrcValType.
///
/// \returns A tensor with the same size as \p toTransform, containing
/// \p TargetValType values of type \p TargetType.
template DenseElementsAttr applyElementWise( const DenseElementsAttr &toTransform, const std::function &toApply, TargetType targetType) { SmallVector transformedValues; // We already know the amount of values we will insert, reserve space for // all of them to avoid dynamic resizing transformedValues.reserve(toTransform.getNumElements()); for (auto val : toTransform.getValues()) { auto transformedVal = toApply(val); transformedValues.push_back(transformedVal); } // Make sure that the output tensor has the expected output type auto inShape = toTransform.getType(); auto outTy = inShape.cloneWith({}, targetType); return DenseElementsAttr::get(outTy, transformedValues); } template DenseElementsAttr applyElementWise( const DenseElementsAttr &toTransform, const std::function &toApply, FloatType targetType); /// Function that checks if the type contained in \p toCheck is float. LogicalResult notifyIfNotFloat(TypedValue toCheck, TosaOp location, PatternRewriter &rewriter) { if (isa(toCheck.getType().getElementType())) { return success(); } return rewriter.notifyMatchFailure(location, "Unexpected input tensor type: the " "TOSA spec only allows floats"); } /// Function that checks if \p toCheck is a dense TOSA constant tensor. LogicalResult notifyIfNoTosaDenseConstantTensor(TypedValue toCheck, TosaOp location, PatternRewriter &rewriter) { // Check whether the tensor is constant and dense // TODO We currently ensure the tensor is dense by using the correct type for // the bind_value, however we do not actually need this value. It would be // nicer to only have a check here. 
DenseElementsAttr tmp; if (!matchPattern(toCheck, m_Constant(&tmp))) { return rewriter.notifyMatchFailure(location, "Non-const or non-dense input tensor"); } // Make sure it actually is a TOSA constant (the match allows for other // constants as well) if (isa(toCheck.getDefiningOp())) { return success(); } return rewriter.notifyMatchFailure(location, "The reciprocal can only be folded if " "it operates on a TOSA constant"); } /// Function that checks if \p toCheck is a dense TOSA constant float tensor. LogicalResult notifyIfNotConstantFloatTosaTensor(TypedValue toCheck, TosaOp location, PatternRewriter &rewriter) { auto floatCheck = notifyIfNotFloat(toCheck, location, rewriter); if (failed(floatCheck)) { return floatCheck; } return notifyIfNoTosaDenseConstantTensor(toCheck, location, rewriter); } /// Heuristic to decide when to replace a unary operation on a constant with the /// folded value. /// Folding operations on constants can lead to an increased memory usage /// whenever the input cannot be replaced but a new constant is inserted. Hence, /// this will currently only suggest folding when the memory impact is /// negligible. /// Takes the \p unaryOp and the constant input \p values. /// \returns Whether folding should be applied. 
bool constantUnaryOpShouldBeFolded(TosaOp unaryOp, DenseElementsAttr values) { assert(unaryOp->getNumOperands() == 1); auto inputOp = unaryOp->getOperand(0); // If the input is a splat, we don't care for the number of users if (isa(values)) { return true; } // If this is the only use of the tensor it should be replaced as no // additional memory is required return inputOp.hasOneUse(); } template DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType, ShapedType outputType, llvm::ArrayRef permValues) { if (inputType.getNumElements() == 0) return DenseElementsAttr::get(outputType, llvm::ArrayRef{}); auto attrValues = attr.getValues(); auto inputShape = inputType.getShape(); // The inverted permutation map and strides of the output are used to compute // the contribution of a given dimension to the destination linear index in // an order-independent way. auto outputStrides = computeStrides(outputType.getShape()); auto invertedPermValues = invertPermutationVector(permValues); auto initialValue = *std::begin(attrValues); SmallVector outputValues(inputType.getNumElements(), initialValue); for (const auto &it : llvm::enumerate(attrValues)) { auto srcLinearIndex = it.index(); uint64_t dstLinearIndex = 0; for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) { // Compute the index into the current dimension of the source vector. auto sourceIndexForDim = srcLinearIndex % inputShape[dim]; srcLinearIndex /= inputShape[dim]; // Add the contribution of the current dimension to the output using the // permutation map. dstLinearIndex += outputStrides[invertedPermValues[dim]] * sourceIndexForDim; } outputValues[dstLinearIndex] = it.value(); } return DenseElementsAttr::get(outputType, llvm::ArrayRef(outputValues)); } // A type specialized transposition of an ElementsAttr. // This implementation tries to operate on the underlying data in its raw // representation when possible to avoid allocating a large number of Attribute // objects. 
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType, ShapedType outputType, llvm::ArrayRef permValues) { auto baseType = inputType.getElementType(); // Handle possible integer types if (auto intType = dyn_cast(baseType)) { switch (intType.getWidth()) { case 1: return transposeType(attr, inputType, outputType, permValues); case 8: return transposeType(attr, inputType, outputType, permValues); case 16: return transposeType(attr, inputType, outputType, permValues); case 32: return transposeType(attr, inputType, outputType, permValues); case 64: return transposeType(attr, inputType, outputType, permValues); default: return transposeType(attr, inputType, outputType, permValues); } } // Handle possible float types if (baseType.isF32()) { return transposeType(attr, inputType, outputType, permValues); } return transposeType(attr, inputType, outputType, permValues); } struct TosaFoldConstantTranspose : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::TransposeOp op, PatternRewriter &rewriter) const override { auto outputType = cast(op.getType()); // TOSA supports quantized types. if (!outputType.getElementType().isIntOrIndexOrFloat()) return failure(); ElementsAttr inputValues; if (!matchPattern(op.getInput1(), m_Constant(&inputValues))) return failure(); // Make sure the input is a constant that has a single user. if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers())) return failure(); DenseIntElementsAttr permAttr; if (!matchPattern(op.getPerms(), m_Constant(&permAttr))) return failure(); auto permValues = llvm::to_vector<6>(llvm::map_range( // TOSA allows both 32- and 64-bit integer tensors here. 
permAttr.getValues(), [](const APInt &val) { return val.getSExtValue(); })); auto inputType = cast(op.getInput1().getType()); auto resultAttr = transpose(inputValues, inputType, outputType, permValues); rewriter.replaceOpWithNewOp(op, outputType, resultAttr); return success(); } }; struct TosaFoldConstantReciprocal : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ReciprocalOp recip, PatternRewriter &rewriter) const override { auto inputTensor = recip.getInput1(); // Check that we can apply folding auto preCondCheck = notifyIfNotConstantFloatTosaTensor(inputTensor, recip, rewriter); if (failed(preCondCheck)) { return preCondCheck; } // Extract the tensor values DenseElementsAttr inputValues; matchPattern(inputTensor, m_Constant(&inputValues)); // Check whether this should be folded. if (!constantUnaryOpShouldBeFolded(recip, inputValues)) { return rewriter.notifyMatchFailure( recip, "Currently, reciprocals will only be folded if the input " "tensor has a single user"); } // Create a new tensor with the updated values auto newTensor = applyElementWise( inputValues, &ReciprocalOp::calcOneElement, cast(inputValues.getElementType())); // Replace the use of the reciprocal with the transformed tensor rewriter.replaceOpWithNewOp(recip, newTensor.getType(), newTensor); return success(); } }; /// Getting the axes position of the element which is located /// in the tensor at the counter index llvm::SmallVector getPositionFromIndex(int64_t index, llvm::ArrayRef tensorShape) { int64_t remaining = index; llvm::SmallVector position(tensorShape.size(), 0); for (int64_t i = tensorShape.size() - 1; i >= 0; --i) { position[i] = remaining % tensorShape[i]; remaining /= tensorShape[i]; } return position; } /// Getting the index of the element which is located at the /// axes position in the tensor int64_t getIndexFromPosition(llvm::ArrayRef position, llvm::ArrayRef tensorShape) { int64_t index = 0; int64_t multiplierTmp = 1; for 
(int64_t i = position.size() - 1; i >= 0; --i) { index += position[i] * multiplierTmp; multiplierTmp *= tensorShape[i]; } return index; } template llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr, llvm::ArrayRef oldShape, int64_t reductionAxis, int64_t reductionIndex) { llvm::SmallVector newShape(oldShape); newShape[reductionAxis] = 1; /// Let's calculate the position of the index llvm::SmallVector position = getPositionFromIndex(reductionIndex, newShape); auto oldTensor = oldTensorAttr.getValues(); /// Starting from the first positon along the reduction axis position[reductionAxis] = 0; int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape); llvm::APInt reducedValue = oldTensor[indexAtOldTensor]; for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis]; ++reductionAxisVal) { int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1, oldShape.end(), 1, std::multiplies()); int64_t index = indexAtOldTensor + stride * reductionAxisVal; reducedValue = OperationType::calcOneElement(reducedValue, oldTensor[index]); } return reducedValue; } template struct ReduceConstantOptimization : public OpRewritePattern { ReduceConstantOptimization(MLIRContext *context, bool aggressiveReduceConstant) : OpRewritePattern(context), aggressiveReduceConstant(aggressiveReduceConstant) {} using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OperationType op, PatternRewriter &rewriter) const override { Value inputOp = op.getInput(); auto constOp = inputOp.getDefiningOp(); if (!constOp) return rewriter.notifyMatchFailure( op, "reduce input must be const operation"); if (!inputOp.hasOneUse() && !this->aggressiveReduceConstant) return rewriter.notifyMatchFailure( op, "input operation has more than one user"); auto resultType = cast(op.getOutput().getType()); if (!resultType.hasStaticShape()) return rewriter.notifyMatchFailure(op, "result type shape is not static"); auto reductionAxis = op.getAxis(); const 
auto denseElementsAttr = constOp.getValue(); const auto shapedOldElementsValues = denseElementsAttr.getType().cast(); if (!llvm::isa(shapedOldElementsValues.getElementType())) return rewriter.notifyMatchFailure( op, "reduce input currently supported with integer type"); auto oldShape = shapedOldElementsValues.getShape(); auto newShape = resultType.getShape(); auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1, std::multiplies()); llvm::SmallVector newReducedTensor(newNumOfElements); for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements; ++reductionIndex) { /// Let's reduce all the elements along this reduction axis newReducedTensor[reductionIndex] = calculateReducedValue( denseElementsAttr, oldShape, reductionAxis, reductionIndex); } auto rankedTensorType = cast(resultType); auto denseAttr = mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor); rewriter.replaceOpWithNewOp(op, rankedTensorType, denseAttr); return success(); } const bool aggressiveReduceConstant; }; } // namespace void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx, RewritePatternSet &patterns, bool aggressiveReduceConstant) { patterns.add>( ctx, aggressiveReduceConstant); patterns.add>( ctx, aggressiveReduceConstant); patterns.add>( ctx, aggressiveReduceConstant); patterns.add>( ctx, aggressiveReduceConstant); patterns.add>( ctx, aggressiveReduceConstant); patterns.add>( ctx, aggressiveReduceConstant); } void mlir::tosa::populateTosaFoldConstantTransposePatterns( MLIRContext *ctx, RewritePatternSet &patterns) { patterns.add(ctx); } void mlir::tosa::populateTosaFoldConstantReciprocalPatterns( MLIRContext *ctx, RewritePatternSet &patterns) { patterns.add(ctx); }