//===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // These rewriters lower from the Tosa to the Linalg dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include using namespace mlir; using namespace mlir::tosa; template static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter) { auto castedN = static_cast( cast(op->getAttr(attrName)).getValue().getSExtValue()); return rewriter.create( op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef resultTypes, PatternRewriter &rewriter) { Location loc = op->getLoc(); auto elementTy = cast(op->getOperand(0).getType()).getElementType(); // tosa::AbsOp if 
(isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); if (isa(op) && isa(elementTy)) { auto zero = rewriter.create( loc, rewriter.getZeroAttr(elementTy)); auto cmp = rewriter.create(loc, arith::CmpIPredicate::sgt, args[0], zero); auto neg = rewriter.create(loc, zero, args[0]); return rewriter.create(loc, cmp, args[0], neg); } // tosa::AddOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::SubOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::MulOp if (isa(op) && isa(elementTy)) { if (dyn_cast(op).getShift() != 0) { (void)rewriter.notifyMatchFailure(op, "Cannot have shift value for float"); return nullptr; } return rewriter.create(loc, resultTypes, args); } // tosa::DivOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ReciprocalOp if (isa(op) && isa(elementTy)) { auto one = rewriter.create(loc, FloatAttr::get(elementTy, 1)); return rewriter.create(loc, resultTypes, one, args[0]); } if (isa(op) && isa(elementTy)) { Value a = args[0]; Value b = args[1]; auto shift = cast(op->getAttr("shift")).getValue().getSExtValue(); if (shift > 0) { auto shiftConst = rewriter.create(loc, shift, /*bitwidth=*/8); if (!a.getType().isInteger(32)) a = rewriter.create(loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) b = rewriter.create(loc, rewriter.getI32Type(), b); auto result = rewriter.create( loc, rewriter.getI32Type(), a, b, shiftConst, rewriter.getBoolAttr(false)); if (elementTy.isInteger(32)) return result; return rewriter.create(loc, elementTy, result); } int aWidth = a.getType().getIntOrFloatBitWidth(); int bWidth = b.getType().getIntOrFloatBitWidth(); int cWidth = resultTypes[0].getIntOrFloatBitWidth(); if (aWidth < cWidth) a = rewriter.create(loc, resultTypes[0], a); 
if (bWidth < cWidth) b = rewriter.create(loc, resultTypes[0], b); return rewriter.create(loc, resultTypes, a, b); } // tosa::NegateOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); if (isa(op) && isa(elementTy) && !cast(op).getQuantizationInfo()) { auto constant = rewriter.create(loc, IntegerAttr::get(elementTy, 0)); return rewriter.create(loc, resultTypes, constant, args[0]); } if (isa(op) && isa(elementTy) && cast(op).getQuantizationInfo()) { auto quantizationInfo = cast(op).getQuantizationInfo(); int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); int64_t inZp = quantizationInfo.value().getInputZp(); int64_t outZp = quantizationInfo.value().getOutputZp(); // Compute the maximum value that can occur in the intermediate buffer. int64_t zpAdd = inZp + outZp; int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() + std::abs(zpAdd) + 1; // Convert that maximum value into the maximum bitwidth needed to represent // it. We assume 48-bit numbers may be supported further in the pipeline. int intermediateBitWidth = 64; if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) { intermediateBitWidth = 16; } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) { intermediateBitWidth = 32; } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) { intermediateBitWidth = 48; } Type intermediateType = rewriter.getIntegerType(intermediateBitWidth); Value zpAddValue = rewriter.create( loc, rewriter.getIntegerAttr(intermediateType, zpAdd)); // The negation can be applied by doing: // outputValue = inZp + outZp - inputValue auto ext = rewriter.create(loc, intermediateType, args[0]); auto sub = rewriter.create(loc, zpAddValue, ext); // Clamp to the negation range. 
Value min = rewriter.create( loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(), intermediateType); Value max = rewriter.create( loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(), intermediateType); auto clamp = clampIntHelper(loc, sub, min, max, rewriter); // Truncate to the final value. return rewriter.create(loc, elementTy, clamp); } // tosa::BitwiseAndOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::BitwiseOrOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::BitwiseNotOp if (isa(op) && isa(elementTy)) { auto allOnesAttr = rewriter.getIntegerAttr( elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth())); auto allOnes = rewriter.create(loc, allOnesAttr); return rewriter.create(loc, resultTypes, args[0], allOnes); } // tosa::BitwiseXOrOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::LogicalLeftShiftOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::LogicalRightShiftOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ArithmeticRightShiftOp if (isa(op) && isa(elementTy)) { auto result = rewriter.create(loc, resultTypes, args); auto round = cast(op->getAttr("round")).getValue(); if (!round) { return result; } Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1); auto one = rewriter.create(loc, IntegerAttr::get(elementTy, 1)); auto zero = rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto i1one = rewriter.create(loc, IntegerAttr::get(i1Ty, 1)); // Checking that input2 != 0 auto shiftValueGreaterThanZero = rewriter.create( loc, arith::CmpIPredicate::sgt, args[1], zero); // Checking for the last bit of input1 to be 1 auto subtract = rewriter.create(loc, resultTypes, args[1], one); auto shifted = rewriter.create(loc, resultTypes, args[0], subtract) ->getResults(); auto truncated = rewriter.create(loc, i1Ty, shifted, 
std::nullopt); auto isInputOdd = rewriter.create(loc, i1Ty, truncated, i1one); auto shouldRound = rewriter.create( loc, i1Ty, shiftValueGreaterThanZero, isInputOdd); auto extended = rewriter.create(loc, resultTypes, shouldRound); return rewriter.create(loc, resultTypes, result, extended); } // tosa::ClzOp if (isa(op) && isa(elementTy)) { return rewriter.create(loc, elementTy, args[0]); } // tosa::LogicalAnd if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, resultTypes, args); // tosa::LogicalNot if (isa(op) && elementTy.isInteger(1)) { auto one = rewriter.create( loc, rewriter.getIntegerAttr(elementTy, 1)); return rewriter.create(loc, resultTypes, args[0], one); } // tosa::LogicalOr if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, resultTypes, args); // tosa::LogicalXor if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, resultTypes, args); // tosa::PowOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::RsqrtOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::LogOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ExpOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::TanhOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ErfOp if (isa(op) && llvm::isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::GreaterOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, arith::CmpFPredicate::OGT, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) return rewriter.create(loc, arith::CmpIPredicate::sgt, args[0], args[1]); // tosa::GreaterEqualOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, arith::CmpFPredicate::OGE, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) return rewriter.create(loc, arith::CmpIPredicate::sge, args[0], args[1]); // tosa::EqualOp if (isa(op) 
&& isa(elementTy)) return rewriter.create(loc, arith::CmpFPredicate::OEQ, args[0], args[1]); if (isa(op) && elementTy.isSignlessInteger()) return rewriter.create(loc, arith::CmpIPredicate::eq, args[0], args[1]); // tosa::SelectOp if (isa(op)) { elementTy = cast(op->getOperand(1).getType()).getElementType(); if (isa(elementTy) || isa(elementTy)) return rewriter.create(loc, args[0], args[1], args[2]); } // tosa::MaximumOp if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::MinimumOp if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } if (isa(op) && elementTy.isSignlessInteger()) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } // tosa::CeilOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::FloorOp if (isa(op) && isa(elementTy)) return rewriter.create(loc, resultTypes, args); // tosa::ClampOp if (isa(op) && isa(elementTy)) { bool losesInfo = false; APFloat minApf = cast(op->getAttr("min_fp")).getValue(); APFloat maxApf = cast(op->getAttr("max_fp")).getValue(); minApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); maxApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); auto min = rewriter.create( loc, elementTy, rewriter.getFloatAttr(elementTy, minApf)); auto max = rewriter.create( loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf)); return clampFloatHelper(loc, args[0], min, max, rewriter); } if (isa(op) && isa(elementTy)) { auto intTy = cast(elementTy); int32_t min = static_cast( cast(op->getAttr("min_int")).getValue().getSExtValue()); int32_t max = static_cast( 
cast(op->getAttr("max_int")).getValue().getSExtValue()); if (intTy.isUnsignedInteger()) { min = std::max(min, 0); max = std::min( max, APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue()); } else { min = std::max( min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) .getSExtValue()); max = std::min( max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) .getSExtValue()); } auto minVal = rewriter.create( loc, min, intTy.getIntOrFloatBitWidth()); auto maxVal = rewriter.create( loc, max, intTy.getIntOrFloatBitWidth()); return clampIntHelper(loc, args[0], minVal, maxVal, rewriter); } // tosa::SigmoidOp if (isa(op) && isa(elementTy)) { auto one = rewriter.create(loc, FloatAttr::get(elementTy, 1)); auto negate = rewriter.create(loc, resultTypes, args[0]); auto exp = rewriter.create(loc, resultTypes, negate); auto added = rewriter.create(loc, resultTypes, exp, one); return rewriter.create(loc, resultTypes, one, added); } // tosa::CastOp if (isa(op)) { Type srcTy = elementTy; Type dstTy = resultTypes.front(); bool bitExtend = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); if (srcTy == dstTy) return args.front(); if (isa(srcTy) && isa(dstTy) && bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); if (isa(srcTy) && isa(dstTy) && !bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); // 1-bit integers need to be treated as signless. if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy)) return rewriter.create(loc, resultTypes, args, std::nullopt); if (srcTy.isInteger(1) && isa(dstTy) && bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); // Unsigned integers need an unrealized cast so that they can be passed // to UIToFP. 
if (srcTy.isUnsignedInteger() && isa(dstTy)) { auto unrealizedCast = rewriter .create( loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); return rewriter.create(loc, resultTypes[0], unrealizedCast); } // All other si-to-fp conversions should be handled by SIToFP. if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy)) return rewriter.create(loc, resultTypes, args, std::nullopt); // Casting to boolean, floats need to only be checked as not-equal to zero. if (isa(srcTy) && dstTy.isInteger(1)) { Value zero = rewriter.create( loc, rewriter.getFloatAttr(srcTy, 0.0)); return rewriter.create(loc, arith::CmpFPredicate::UNE, args.front(), zero); } if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { auto intMin = rewriter.create( loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); auto intMax = rewriter.create( loc, rewriter.getFloatAttr( getElementTypeOrSelf(srcTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); auto rounded = rewriter.create(loc, args[0]); auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter); return rewriter.create(loc, dstTy, clamped); } // Casting to boolean, integers need to only be checked as not-equal to // zero. 
if (isa(srcTy) && dstTy.isInteger(1)) { Value zero = rewriter.create( loc, 0, srcTy.getIntOrFloatBitWidth()); return rewriter.create(loc, arith::CmpIPredicate::ne, args.front(), zero); } if (isa(srcTy) && isa(dstTy) && bitExtend) return rewriter.create(loc, resultTypes, args, std::nullopt); if (isa(srcTy) && isa(dstTy) && !bitExtend) { return rewriter.create(loc, dstTy, args[0]); } } (void)rewriter.notifyMatchFailure( op, "unhandled op for linalg body calculation for elementwise op"); return nullptr; } static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, int64_t rank) { // No need to expand if we are already at the desired rank auto shapedType = dyn_cast(tensor.getType()); assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type"); int64_t numExtraDims = rank - shapedType.getRank(); assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank"); if (!numExtraDims) return tensor; // Compute reassociation indices SmallVector> reassociationIndices( shapedType.getRank()); int64_t index = 0; for (index = 0; index <= numExtraDims; index++) reassociationIndices[0].push_back(index); for (size_t position = 1; position < reassociationIndices.size(); position++) reassociationIndices[position].push_back(index++); // Compute result type SmallVector resultShape; for (index = 0; index < numExtraDims; index++) resultShape.push_back(1); for (auto size : shapedType.getShape()) resultShape.push_back(size); auto resultType = RankedTensorType::get(resultShape, shapedType.getElementType()); // Emit 'tensor.expand_shape' op return rewriter.create(loc, resultType, tensor, reassociationIndices); } static SmallVector expandInputRanks(PatternRewriter &rewriter, Location loc, Operation *operation) { auto rank = operation->getResultTypes().front().cast().getRank(); return llvm::map_to_vector(operation->getOperands(), [&](Value operand) { return expandRank(rewriter, loc, operand, rank); }); } using IndexPool = DenseMap; // Emit an 
'arith.constant' op for the given index if it has not been created // yet, or return an existing constant. This will prevent an excessive creation // of redundant constants, easing readability of emitted code for unit tests. static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index) { auto [it, inserted] = indexPool.try_emplace(index); if (inserted) it->second = rewriter.create(loc, rewriter.getIndexAttr(index)); return it->second; } static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index) { auto indexValue = createIndex(rewriter, loc, indexPool, index); return rewriter.create(loc, tensor, indexValue).getResult(); } static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index) { auto shapedType = dyn_cast(tensor.getType()); assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type"); assert(index >= 0 && index < shapedType.getRank() && "index out of bounds"); if (shapedType.isDynamicDim(index)) return getTensorDim(rewriter, loc, indexPool, tensor, index); return rewriter.getIndexAttr(shapedType.getDimSize(index)); } static bool operandsAndResultsRanked(Operation *operation) { auto isRanked = [](Value value) { return isa(value.getType()); }; return llvm::all_of(operation->getOperands(), isRanked) && llvm::all_of(operation->getResults(), isRanked); } // Compute the runtime dimension size for dimension 'dim' of the output by // inspecting input 'operands', all of which are expected to have the same rank. // This function returns a pair {targetSize, masterOperand}. // // The runtime size of the output dimension is returned either as a statically // computed attribute or as a runtime SSA value. // // If the target size was inferred directly from one dominating operand, that // operand is returned in 'masterOperand'. 
If the target size is inferred from // multiple operands, 'masterOperand' is set to nullptr. static std::pair computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim) { // If any input operand contains a static size greater than 1 for this // dimension, that is the target size. An occurrence of an additional static // dimension greater than 1 with a different value is undefined behavior. for (auto operand : operands) { auto size = operand.getType().cast().getDimSize(dim); if (!ShapedType::isDynamic(size) && size > 1) return {rewriter.getIndexAttr(size), operand}; } // Filter operands with dynamic dimension auto operandsWithDynamicDim = llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) { return operand.getType().cast().isDynamicDim(dim); })); // If no operand has a dynamic dimension, it means all sizes were 1 if (operandsWithDynamicDim.empty()) return {rewriter.getIndexAttr(1), operands.front()}; // Emit code that computes the runtime size for this dimension. If there is // only one operand with a dynamic dimension, it is considered the master // operand that determines the runtime size of the output dimension. auto targetSize = getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim); if (operandsWithDynamicDim.size() == 1) return {targetSize, operandsWithDynamicDim[0]}; // Calculate maximum size among all dynamic dimensions for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) { auto nextSize = getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim); targetSize = rewriter.create(loc, targetSize, nextSize); } return {targetSize, nullptr}; } // Compute the runtime output size for all dimensions. This function returns // a pair {targetShape, masterOperands}. 
static std::pair, SmallVector> computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands) { assert(!operands.empty()); auto rank = operands.front().getType().cast().getRank(); SmallVector targetShape; SmallVector masterOperands; for (auto dim : llvm::seq(0, rank)) { auto [targetSize, masterOperand] = computeTargetSize(rewriter, loc, indexPool, operands, dim); targetShape.push_back(targetSize); masterOperands.push_back(masterOperand); } return {targetShape, masterOperands}; } static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand) { // Nothing to do if this is a static dimension auto rankedTensorType = operand.getType().cast(); if (!rankedTensorType.isDynamicDim(dim)) return operand; // If the target size for this dimension was directly inferred by only taking // this operand into account, there is no need to broadcast. This is an // optimization that will prevent redundant control flow, and constitutes the // main motivation for tracking "master operands". if (operand == masterOperand) return operand; // Affine maps for 'linalg.generic' op auto rank = rankedTensorType.getRank(); SmallVector affineExprs; for (auto index : llvm::seq(0, rank)) { auto affineExpr = index == dim ? 
rewriter.getAffineConstantExpr(0) : rewriter.getAffineDimExpr(index); affineExprs.push_back(affineExpr); } auto broadcastAffineMap = AffineMap::get(rank, 0, affineExprs, rewriter.getContext()); auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank); SmallVector affineMaps = {broadcastAffineMap, identityAffineMap}; // Check if broadcast is necessary auto one = createIndex(rewriter, loc, indexPool, 1); auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim); auto broadcastNecessary = rewriter.create( loc, arith::CmpIPredicate::eq, runtimeSize, one); // Emit 'then' region of 'scf.if' auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) { // Emit 'tensor.empty' op SmallVector outputTensorShape; for (auto index : llvm::seq(0, rank)) { auto size = index == dim ? targetSize : getOrFoldTensorDim(rewriter, loc, indexPool, operand, index); outputTensorShape.push_back(size); } Value outputTensor = opBuilder.create( loc, outputTensorShape, rankedTensorType.getElementType()); // Emit 'linalg.generic' op auto resultTensor = opBuilder .create( loc, outputTensor.getType(), operand, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { // Emit 'linalg.yield' op opBuilder.create(loc, blockArgs.front()); }) .getResult(0); // Cast to original operand type if necessary auto castResultTensor = rewriter.createOrFold( loc, operand.getType(), resultTensor); // Emit 'scf.yield' op opBuilder.create(loc, castResultTensor); }; // Emit 'else' region of 'scf.if' auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) { opBuilder.create(loc, operand); }; // Emit 'scf.if' op auto ifOp = rewriter.create(loc, broadcastNecessary, emitThenRegion, emitElseRegion); return ifOp.getResult(0); } static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef targetShape, ArrayRef masterOperands) { size_t rank = 
operand.getType().cast().getRank(); assert(targetShape.size() == rank); assert(masterOperands.size() == rank); for (auto index : llvm::seq(0, rank)) operand = broadcastDynamicDimension(rewriter, loc, indexPool, operand, index, targetShape[index], masterOperands[index]); return operand; } static SmallVector broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, ArrayRef targetShape, ArrayRef masterOperands) { // No need to broadcast for unary operations if (operands.size() == 1) return operands; // Broadcast dynamic dimensions operand by operand return llvm::map_to_vector(operands, [&](Value operand) { return broadcastDynamicDimensions(rewriter, loc, indexPool, operand, targetShape, masterOperands); }); } static LogicalResult emitElementwiseComputation(PatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef targetShape) { // Generate output tensor auto resultType = operation->getResultTypes().front().cast(); Value outputTensor = rewriter.create( loc, targetShape, resultType.getElementType()); // Create affine maps. Input affine maps broadcast static dimensions of size // 1. The output affine map is an identity map. // auto rank = resultType.getRank(); auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) { auto shape = cast(operand.getType()).getShape(); SmallVector affineExprs; for (auto it : llvm::enumerate(shape)) { auto affineExpr = it.value() == 1 ? 
rewriter.getAffineConstantExpr(0) : rewriter.getAffineDimExpr(it.index()); affineExprs.push_back(affineExpr); } return AffineMap::get(rank, 0, affineExprs, rewriter.getContext()); }); affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Emit 'linalg.generic' op bool encounteredError = false; auto linalgOp = rewriter.create( loc, outputTensor.getType(), operands, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { Value opResult = createLinalgBodyCalculationForElementwiseOp( operation, blockArgs.take_front(operation->getNumOperands()), {resultType.getElementType()}, rewriter); if (!opResult) { encounteredError = true; return; } opBuilder.create(loc, opResult); }); if (encounteredError) return rewriter.notifyMatchFailure( operation, "unable to create linalg.generic body for elementwise op"); // Cast 'linalg.generic' result into original result type if needed auto castResult = rewriter.createOrFold( loc, resultType, linalgOp->getResult(0)); rewriter.replaceOp(operation, castResult); return success(); } static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter) { // Collect op properties assert(operation->getNumResults() == 1 && "elementwise op expects 1 result"); assert(operation->getNumOperands() >= 1 && "elementwise op expects at least 1 operand"); if (!operandsAndResultsRanked(operation)) return rewriter.notifyMatchFailure(operation, "Unranked tensors not supported"); // Lower operation IndexPool indexPool; auto loc = operation->getLoc(); auto expandedOperands = expandInputRanks(rewriter, loc, operation); auto [targetShape, masterOperands] = computeTargetShape(rewriter, loc, indexPool, expandedOperands); auto broadcastOperands = broadcastDynamicDimensions( rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands); return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands, targetShape); } // Returns the 
// constant initial value for a given reduction operation. The
// attribute type varies depending on the element type required.
//
// NOTE(review): the op class named in each `isa` check below is not visible in
// this copy of the file. Judging only from the returned values, the checks
// come in (float, integer) pairs of reduction identities:
//   - 0, presumably a sum;
//   - 1, presumably a product;
//   - the largest positive value, presumably a min;
//   - the most negative value, presumably a max;
//   - i1 all-ones / i1 zero, presumably logical all / any;
//   - a final most-negative pair.
// Confirm against the actual op list.
//
// Returns a null TypedAttr (`return {}`) when no case matches.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  // Additive identity (float, then integer).
  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  // Multiplicative identity (float, then integer).
  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  // Largest finite positive float (Negative = false) / signed integer max.
  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast(elementTy).getFloatSemantics(), false));

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  // Most negative finite float (Negative = true) / signed integer min.
  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast(elementTy).getFloatSemantics(), true));

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  // i1 reductions: all-ones (true) and zero (false) start values.
  if (isa(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  // Another most-negative pair, for a different op class than the pair above.
  if (isa(op) && isa(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast(elementTy).getFloatSemantics(), true));

  if (isa(op) && isa(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  // No identity known for this op/type combination.
  return {};
}

// Creates the body calculation for a reduction. The operations vary depending
// on the input type.
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter) { Location loc = op->getLoc(); if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args); } if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::slt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && isa(elementTy)) { return rewriter.create(loc, args[0], args[1]); } if (isa(op) && isa(elementTy)) { auto predicate = rewriter.create( loc, arith::CmpIPredicate::sgt, args[0], args[1]); return rewriter.create(loc, predicate, args[0], args[1]); } if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, args); if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, args); return {}; } // Performs the match and rewrite for reduction operations. This includes // declaring a correctly sized initial value, and the linalg.generic operation // that reduces across the specified axis. static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter) { auto loc = op->getLoc(); auto inputTy = cast(op->getOperand(0).getType()); auto resultTy = cast(op->getResult(0).getType()); auto elementTy = resultTy.getElementType(); Value input = op->getOperand(0); SmallVector reduceShape; SmallVector dynDims; for (unsigned i = 0; i < inputTy.getRank(); i++) { if (axis != i) { reduceShape.push_back(inputTy.getDimSize(i)); if (inputTy.isDynamicDim(i)) dynDims.push_back(rewriter.create(loc, input, i)); } } // First fill the output buffer with the init value. 
auto emptyTensor = rewriter .create(loc, reduceShape, resultTy.getElementType(), dynDims) .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) return rewriter.notifyMatchFailure( op, "No initial value found for reduction operation"); auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) .result(); bool didEncounterError = false; auto linalgOp = rewriter.create( loc, input, filledTensor, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { auto result = createLinalgBodyCalculationForReduceOp( op, blockArgs, elementTy, rewriter); if (result) didEncounterError = true; nestedBuilder.create(loc, result); }); if (!didEncounterError) return rewriter.notifyMatchFailure( op, "unable to create linalg.generic body for reduce op"); SmallVector reassociationMap; uint64_t expandInputRank = cast(linalgOp.getResults()[0].getType()).getRank(); reassociationMap.resize(expandInputRank); for (uint64_t i = 0; i < expandInputRank; i++) { int32_t dimToPush = i > axis ? i + 1 : i; reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush)); } if (expandInputRank != 0) { int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1; reassociationMap[expandedDim].push_back( rewriter.getAffineDimExpr(expandedDim + 1)); } // Lower directly to `tensor::ExpandShapeOp` instead of `tosa::ReshapeOp`, // since here we know which dimension to expand, and `tosa::ReshapeOp` would // not have access to such information. This matters when handling dynamically // sized tensors. 
rewriter.replaceOpWithNewOp( op, resultTy, linalgOp.getResults()[0], reassociationMap); return success(); } namespace { template class PointwiseConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { return elementwiseMatchAndRewriteHelper(op, rewriter); } }; class RescaleConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::RescaleOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); auto input = op.getInput(); auto inputTy = cast(op.getInput().getType()); auto outputTy = cast(op.getOutput().getType()); unsigned rank = inputTy.getRank(); // This is an illegal configuration. terminate and log an error if (op.getDoubleRound() && !op.getScale32()) return rewriter.notifyMatchFailure( op, "tosa.rescale requires scale32 for double_round to be true"); SmallVector dynDims; for (int i = 0; i < outputTy.getRank(); i++) { if (outputTy.isDynamicDim(i)) { dynDims.push_back(rewriter.create(loc, input, i)); } } // The shift and multiplier values. SmallVector multiplierValues(op.getMultiplier()); SmallVector shiftValues(op.getShift()); // If we shift by more than the bitwidth, this just sets to 0. for (int i = 0, s = multiplierValues.size(); i < s; i++) { if (shiftValues[i] > 63) { shiftValues[i] = 0; multiplierValues[i] = 0; } } // Double round only occurs if shift is greater than 31, check that this // is ever true. bool doubleRound = op.getDoubleRound() && llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); SmallVector indexingMaps = { rewriter.getMultiDimIdentityMap(rank)}; SmallVector genericInputs = {input}; // If we are rescaling per-channel then we need to store the multiplier // values in a buffer. 
Value multiplierConstant; int64_t multiplierArg = 0; if (multiplierValues.size() == 1) { multiplierConstant = rewriter.create( loc, rewriter.getI32IntegerAttr(multiplierValues.front())); } else { SmallVector multiplierExprs{ rewriter.getAffineDimExpr(rank - 1)}; auto multiplierType = RankedTensorType::get({static_cast(multiplierValues.size())}, rewriter.getI32Type()); genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(multiplierType, multiplierValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, multiplierExprs, rewriter.getContext())); multiplierArg = indexingMaps.size() - 1; } // If we are rescaling per-channel then we need to store the shift // values in a buffer. Value shiftConstant; int64_t shiftArg = 0; if (shiftValues.size() == 1) { shiftConstant = rewriter.create( loc, rewriter.getI8IntegerAttr(shiftValues.front())); } else { SmallVector shiftExprs = { rewriter.getAffineDimExpr(rank - 1)}; auto shiftType = RankedTensorType::get({static_cast(shiftValues.size())}, rewriter.getIntegerType(8)); genericInputs.push_back(rewriter.create( loc, DenseIntElementsAttr::get(shiftType, shiftValues))); indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, shiftExprs, rewriter.getContext())); shiftArg = indexingMaps.size() - 1; } // Indexing maps for output values. indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Construct the indexing maps needed for linalg.generic ops. Value emptyTensor = rewriter.create( loc, outputTy.getShape(), outputTy.getElementType(), ArrayRef({dynDims})); auto linalgOp = rewriter.create( loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps, getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value value = blockArgs[0]; Type valueTy = value.getType(); // For now we do all of our math in 64-bit. 
This is not optimal but // should be correct for now, consider computing correct bit depth // later. int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32; auto inputZp = createConstFromIntAttribute( op, "input_zp", nestedBuilder.getIntegerType(inBitwidth), nestedBuilder); auto outputZp = createConstFromIntAttribute( op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder); Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg]; if (valueTy.getIntOrFloatBitWidth() < 32) { if (valueTy.isUnsignedInteger()) { value = nestedBuilder .create( nestedLoc, nestedBuilder.getIntegerType( valueTy.getIntOrFloatBitWidth()), value) .getResult(0); value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } else { value = nestedBuilder.create( nestedLoc, nestedBuilder.getI32Type(), value); } } value = nestedBuilder.create(nestedLoc, value, inputZp); value = nestedBuilder.create( loc, nestedBuilder.getI32Type(), value, multiplier, shift, nestedBuilder.getBoolAttr(doubleRound)); // Move to the new zero-point. value = nestedBuilder.create(nestedLoc, value, outputZp); // Saturate to the output size. IntegerType outIntType = cast(blockArgs.back().getType()); unsigned outBitWidth = outIntType.getWidth(); int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue(); int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue(); // Unsigned integers have a difference output value. 
if (outIntType.isUnsignedInteger()) { intMin = 0; intMax = APInt::getMaxValue(outBitWidth).getZExtValue(); } auto intMinVal = nestedBuilder.create( loc, nestedBuilder.getI32IntegerAttr(intMin)); auto intMaxVal = nestedBuilder.create( loc, nestedBuilder.getI32IntegerAttr(intMax)); value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal, nestedBuilder); if (outIntType.getWidth() < 32) { value = nestedBuilder.create( nestedLoc, rewriter.getIntegerType(outIntType.getWidth()), value); if (outIntType.isUnsignedInteger()) { value = nestedBuilder .create(nestedLoc, outIntType, value) .getResult(0); } } nestedBuilder.create(loc, value); }); rewriter.replaceOp(op, linalgOp->getResults()); return success(); } }; // Handle the resize case where the input is a 1x1 image. This case // can entirely avoiding having extract operations which target much // more difficult to optimize away. class ResizeUnaryConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); ImplicitLocOpBuilder builder(loc, rewriter); auto input = op.getInput(); auto inputTy = cast(input.getType()); auto resultTy = cast(op.getType()); const bool isBilinear = op.getMode() == "BILINEAR"; auto inputH = inputTy.getDimSize(1); auto inputW = inputTy.getDimSize(2); auto outputH = resultTy.getDimSize(1); auto outputW = resultTy.getDimSize(2); if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1) return rewriter.notifyMatchFailure( op, "tosa.resize is not a pure 1x1->1x1 image operation"); // TODO(suderman): These string values should be declared the TOSA dialect. 
if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); if (inputTy == resultTy) { rewriter.replaceOp(op, input); return success(); } ArrayRef scale = op.getScale(); // Collapse the unit width and height away. SmallVector reassociationMap(2); reassociationMap[0].push_back(builder.getAffineDimExpr(0)); reassociationMap[1].push_back(builder.getAffineDimExpr(1)); reassociationMap[1].push_back(builder.getAffineDimExpr(2)); reassociationMap[1].push_back(builder.getAffineDimExpr(3)); auto collapseTy = RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)}, inputTy.getElementType()); Value collapse = builder.create(collapseTy, input, reassociationMap); // Get any dynamic shapes that appear in the input format. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) outputDynSize.push_back(builder.create(input, 0)); if (inputTy.isDynamicDim(3)) outputDynSize.push_back(builder.create(input, 3)); // Generate the elementwise operation for casting scaling the input value. auto genericTy = collapseTy.clone(resultTy.getElementType()); Value empty = builder.create( genericTy.getShape(), resultTy.getElementType(), outputDynSize); auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank()); SmallVector iterators(genericTy.getRank(), utils::IteratorType::parallel); auto generic = builder.create( genericTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef{genericMap, genericMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; // This is the quantized case. 
if (inputTy.getElementType() != resultTy.getElementType()) { value = b.create(loc, resultTy.getElementType(), value); if (isBilinear && scale[0] != 0) { Value scaleY = b.create( loc, b.getI32IntegerAttr(scale[0])); value = b.create(loc, value, scaleY); } if (isBilinear && scale[2] != 0) { Value scaleX = b.create( loc, b.getI32IntegerAttr(scale[2])); value = b.create(loc, value, scaleX); } } b.create(loc, value); }); rewriter.replaceOpWithNewOp( op, resultTy, generic.getResults()[0], reassociationMap); return success(); } }; // TOSA resize with width or height of 1 may be broadcasted to a wider // dimension. This is done by materializing a new tosa.resize without // the broadcasting behavior, and an explicit broadcast afterwards. class MaterializeResizeBroadcast : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); ImplicitLocOpBuilder builder(loc, rewriter); auto input = op.getInput(); auto inputTy = dyn_cast(input.getType()); auto resultTy = dyn_cast(op.getType()); if (!inputTy || !resultTy) return rewriter.notifyMatchFailure(op, "requires ranked input/output types"); auto batch = inputTy.getDimSize(0); auto channels = inputTy.getDimSize(3); auto inputH = inputTy.getDimSize(1); auto inputW = inputTy.getDimSize(2); auto outputH = resultTy.getDimSize(1); auto outputW = resultTy.getDimSize(2); if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1)) return rewriter.notifyMatchFailure( op, "tosa.resize has no broadcasting behavior"); // For any dimension that is broadcastable we generate a width of 1 // on the output. llvm::SmallVector resizeShape; resizeShape.push_back(batch); resizeShape.push_back(inputH == 1 ? 1 : outputH); resizeShape.push_back(inputW == 1 ? 
1 : outputW); resizeShape.push_back(channels); auto resizeTy = resultTy.clone(resizeShape); auto resize = builder.create(resizeTy, input, op->getAttrs()); // Collapse an unit result dims. SmallVector reassociationMap(2); reassociationMap[0].push_back(builder.getAffineDimExpr(0)); reassociationMap.back().push_back(builder.getAffineDimExpr(1)); if (inputH != 1) reassociationMap.push_back({}); reassociationMap.back().push_back(builder.getAffineDimExpr(2)); if (inputW != 1) reassociationMap.push_back({}); reassociationMap.back().push_back(builder.getAffineDimExpr(3)); llvm::SmallVector collapseShape{batch}; if (inputH != 1) collapseShape.push_back(outputH); if (inputW != 1) collapseShape.push_back(outputW); collapseShape.push_back(channels); auto collapseTy = resultTy.clone(collapseShape); Value collapse = builder.create(collapseTy, resize, reassociationMap); // Broadcast the collapsed shape to the output result. llvm::SmallVector outputDynSize; if (inputTy.isDynamicDim(0)) outputDynSize.push_back(builder.create(input, 0)); if (inputTy.isDynamicDim(3)) outputDynSize.push_back(builder.create(input, 3)); SmallVector iterators(resultTy.getRank(), utils::IteratorType::parallel); Value empty = builder.create( resultTy.getShape(), resultTy.getElementType(), outputDynSize); SmallVector inputExprs{rewriter.getAffineDimExpr(0)}; if (inputH != 1) inputExprs.push_back(rewriter.getAffineDimExpr(1)); if (inputW != 1) inputExprs.push_back(rewriter.getAffineDimExpr(2)); inputExprs.push_back(rewriter.getAffineDimExpr(3)); auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, rewriter.getContext()); auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank()); rewriter.replaceOpWithNewOp( op, resultTy, ValueRange{collapse}, ValueRange{empty}, ArrayRef{inputMap, outputMap}, iterators, [=](OpBuilder &b, Location loc, ValueRange args) { Value value = args[0]; b.create(loc, value); }); return success(); } }; class GenericResizeConverter : public 
OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ResizeOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); ImplicitLocOpBuilder b(loc, rewriter); auto input = op.getInput(); auto inputTy = cast(input.getType()); auto resultTy = cast(op.getType()); auto resultETy = resultTy.getElementType(); bool floatingPointMode = resultETy.isF16() || resultETy.isF32(); auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type(); auto imageH = inputTy.getShape()[1]; auto imageW = inputTy.getShape()[2]; auto dynamicDimsOr = checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()}); if (!dynamicDimsOr.has_value()) return rewriter.notifyMatchFailure( op, "unable to get dynamic dimensions of tosa.resize"); if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto emptyTensor = b.create(resultTy.getShape(), resultETy, *dynamicDimsOr); auto genericOp = b.create( resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); Value resize = genericOp.getResult(0); { OpBuilder::InsertionGuard regionGuard(b); b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({resultETy}), loc); Value batch = b.create(0); Value y = b.create(1); Value x = b.create(2); Value channel = b.create(3); Value zeroI32 = b.create(b.getZeroAttr(b.getI32Type())); Value zeroFp = b.create(b.getZeroAttr(floatTy)); Value hMax = b.create(b.getI32IntegerAttr(imageH - 1)); Value wMax = b.create(b.getI32IntegerAttr(imageW - 1)); Value inY = b.create(b.getI32Type(), y); Value inX = b.create(b.getI32Type(), x); ArrayRef offset = op.getOffset(); ArrayRef border = op.getBorder(); ArrayRef scale = op.getScale(); Value yScaleN, yScaleD, xScaleN, xScaleD; yScaleN = 
b.create(b.getI32IntegerAttr(scale[0])); yScaleD = b.create(b.getI32IntegerAttr(scale[1])); xScaleN = b.create(b.getI32IntegerAttr(scale[2])); xScaleD = b.create(b.getI32IntegerAttr(scale[3])); Value yOffset, xOffset, yBorder, xBorder; yOffset = b.create(b.getI32IntegerAttr(offset[0])); xOffset = b.create(b.getI32IntegerAttr(offset[1])); yBorder = b.create(b.getI32IntegerAttr(border[0])); xBorder = b.create(b.getI32IntegerAttr(border[1])); // Compute the ix and dx values for both the X and Y dimensions. auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in, Value scaleN, Value scaleD, Value offset, int size, ImplicitLocOpBuilder &b) { if (size == 1) { index = zeroI32; delta = zeroFp; return; } // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x / scale_n - ix Value val = b.create(floatTy, in); scaleN = b.create(floatTy, scaleN); scaleD = b.create(floatTy, scaleD); offset = b.create(floatTy, offset); val = b.create(val, scaleD); val = b.create(val, offset); val = b.create(val, scaleN); index = b.create(val); delta = b.create(val, index); index = b.create(b.getI32Type(), index); }; // Compute the ix and dx values for the X and Y dimensions - int case. 
auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in, Value scaleN, Value scaleD, Value offset, int size, ImplicitLocOpBuilder &b) { if (size == 1) { index = zeroI32; delta = zeroI32; return; } // x = x * scale_d + offset; // ix = floor(x / scale_n) // dx = x - ix * scale_n; Value val = b.create(in, scaleD); val = b.create(val, offset); index = b.create(val, scaleN); delta = b.create(index, scaleN); delta = b.create(val, delta); }; Value ix, iy, dx, dy; if (floatingPointMode) { getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } else { getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b); getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } if (op.getMode() == "NEAREST_NEIGHBOR") { auto one = b.create(b.getI32IntegerAttr(1)); auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, Value max, int size, ImplicitLocOpBuilder &b) -> Value { if (size == 1) { return b.create(0); } Value pred; if (floatingPointMode) { auto h = b.create(b.getFloatAttr(floatTy, 0.5f)); pred = b.create(arith::CmpFPredicate::OGE, dval, h); } else { Value dvalDouble = b.create(dval, one); pred = b.create(arith::CmpIPredicate::sge, dvalDouble, scale); } auto offset = b.create(pred, one, zeroI32); val = b.create(val, offset); val = clampIntHelper(loc, val, zeroI32, max, b); return b.create(b.getIndexType(), val); }; iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b); ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b); Value result = b.create( input, ValueRange{batch, iy, ix, channel}); b.create(result); } else { // The mode here must be BILINEAR. 
assert(op.getMode() == "BILINEAR"); auto oneVal = b.create(b.getI32IntegerAttr(1)); auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in, Value max, ImplicitLocOpBuilder &b) { val0 = in; val1 = b.create(val0, oneVal); val0 = clampIntHelper(loc, val0, zeroI32, max, b); val1 = clampIntHelper(loc, val1, zeroI32, max, b); val0 = b.create(b.getIndexType(), val0); val1 = b.create(b.getIndexType(), val1); }; // Linalg equivalent to the section below: // int16_t iy0 = apply_max(iy, 0); // int16_t iy1 = apply_min(iy + 1, IH - 1); // int16_t ix0 = apply_max(ix, 0); // int16_t ix1 = apply_min(ix + 1, IW - 1); Value x0, x1, y0, y1; getClampedIdxs(y0, y1, imageH, iy, hMax, b); getClampedIdxs(x0, x1, imageW, ix, wMax, b); Value y0x0 = b.create( input, ValueRange{batch, y0, x0, channel}); Value y0x1 = b.create( input, ValueRange{batch, y0, x1, channel}); Value y1x0 = b.create( input, ValueRange{batch, y1, x0, channel}); Value y1x1 = b.create( input, ValueRange{batch, y1, x1, channel}); if (floatingPointMode) { auto oneVal = b.create(b.getFloatAttr(floatTy, 1.0f)); auto interpolate = [&](Value val0, Value val1, Value delta, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return val0; Value oneMinusDelta = b.create(oneVal, delta); Value mul0 = b.create(val0, oneMinusDelta); Value mul1 = b.create(val1, delta); return b.create(mul0, mul1); }; // Linalg equivalent to the section below: // topAcc = v00 * (unit_x - dx); // topAcc += v01 * dx; Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b); // Linalg equivalent to the section below: // bottomAcc = v10 * (unit_x - dx); // bottomAcc += v11 * dx; Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b); // Linalg equivalent to the section below: // result = topAcc * (unit_y - dy) + bottomAcc * dy Value result = interpolate(topAcc, bottomAcc, dy, imageH, b); b.create(result); } else { // Perform in quantized space. 
y0x0 = b.create(resultETy, y0x0); y0x1 = b.create(resultETy, y0x1); y1x0 = b.create(resultETy, y1x0); y1x1 = b.create(resultETy, y1x1); const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) { dx = b.create(resultETy, dx); dy = b.create(resultETy, dy); } Value yScaleNExt = yScaleN; Value xScaleNExt = xScaleN; const int64_t scaleBitwidth = xScaleN.getType().getIntOrFloatBitWidth(); if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) { yScaleNExt = b.create(resultETy, yScaleN); xScaleNExt = b.create(resultETy, xScaleN); } auto interpolate = [](Value val0, Value val1, Value weight1, Value scale, int inputSize, ImplicitLocOpBuilder &b) -> Value { if (inputSize == 1) return b.create(val0, scale); Value weight0 = b.create(scale, weight1); Value mul0 = b.create(val0, weight0); Value mul1 = b.create(val1, weight1); return b.create(mul0, mul1); }; Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b); Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b); Value result = interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b); b.create(result); } } } rewriter.replaceOp(op, resize); return success(); } }; // At the codegen level any identity operations should be removed. Any cases // where identity is load-bearing (e.g. cross device computation) should be // handled before lowering to codegen. 
// Replaces the op with its own operands (forwarding each operand to the
// corresponding result).
template class IdentityNConverter : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

// Lowers a TOSA reduction op along its `axis` attribute via
// reduceMatchAndRewriteHelper.
template class ReduceConverter : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};

// Lowers tosa.reverse to a linalg.generic with no inputs. Each output element
// extracts input[..., size(axis) - 1 - i_axis, ...], and every other
// coordinate is taken unchanged from the loop indices.
class ReverseConverter : public OpRewritePattern {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput();
    auto inputTy = cast(input.getType());
    auto resultTy = cast(op.getType());
    auto axis = op.getAxis();

    SmallVector dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create(loc, input, i));
      }
    }

    // Runtime size of the reversed axis, used to mirror the index.
    Value axisDimSize = rewriter.create(loc, input, axis);

    // Create the output tensor (it is fully overwritten, so no fill is needed).
    auto emptyTensor = rewriter
                           .create(loc, inputTy.getShape(),
                                   inputTy.getElementType(),
                                   ArrayRef({dynDims}))
                           .getResult();
    SmallVector affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp(
        op, resultTy, ArrayRef({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create(nestedLoc, i).getResult();
            // Mirror the index along the reversed axis: (size - 1) - i.
            if (i == axis) {
              auto one = rewriter.create(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create(nestedLoc, axisDimSize, one);
              index = rewriter.create(nestedLoc, sizeMinusOne,
                                      index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create(
              nestedLoc, input, indices);
          nestedBuilder.create(op.getLoc(),
                               extract.getResult());
        });
    return success();
  }
};

// This converter translates a tile operation into a linalg.generic broadcast
// followed by a reshape. The generic produces a tensor of shape
// [m0, d0, m1, d1, ...] (multiple, then input size, per input dimension) that
// reads the input only at the odd (input-size) dimensions, so each input
// slice is repeated `multiple` times. The result is then reshaped to the
// op's result shape.
struct TileConverter : public OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    ArrayRef multiples = op.getMultiples();

    // Broadcast the newly added dimensions to their appropriate multiple.
    // A multiple of -1 is treated as a dynamic extent.
    SmallVector genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    // NOTE(review): at most one size is pushed per input dimension, and it is
    // always the input's own size. If multiples[i] == -1, that size is used
    // for the dynamic multiple dim. If the input dim is also dynamic,
    // genericShape has two dynamic dims but only one size is supplied to the
    // tensor.empty below. Confirm whether this combination can reach here.
    SmallVector dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create(op.getLoc(), *args.begin());
        });

    // Collapse the interleaved [multiple, size] pairs into the result shape.
    rewriter.replaceOpWithNewOp(
        op, resultTy, genericOp.getResult(0),
        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};

// Tosa argmax lowering represents the ArgMax op as a linalg.generic op with a
// reduction iterator along the arg-max axis, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and has the integer result type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the value returned by createInitialValueForReduceOp for the input element
// type. After being populated by the generic op, this buffer is discarded, as
// only the index is requested.
//
// The generic op updates both the maximum value and the index if the current
// value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp, PatternRewriter &rewriter) const final { auto loc = argmaxOp.getLoc(); Value input = argmaxOp.getInput(); auto inputTy = cast(input.getType()); auto resultTy = cast(argmaxOp.getOutput().getType()); auto inElementTy = inputTy.getElementType(); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy); if (!isa(outElementTy)) return rewriter.notifyMatchFailure( argmaxOp, "tosa.arg_max to linalg.* requires integer-like result type"); SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { dynDims.push_back(rewriter.create(loc, input, i)); } } // First fill the output buffer for the index. auto emptyTensorIdx = rewriter .create(loc, resultTy.getShape(), outElementTy, dynDims) .getResult(); auto fillValueIdx = rewriter.create( loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = rewriter .create(loc, ValueRange{fillValueIdx}, ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. auto emptyTensorMax = rewriter .create(loc, resultTy.getShape(), inElementTy, dynDims) .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); if (!fillValueMaxAttr) return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = rewriter.create(loc, fillValueMaxAttr); auto filledTensorMax = rewriter .create(loc, ValueRange{fillValueMax}, ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along // the rest. 
SmallVector iteratorTypes; iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel); iteratorTypes[axis] = utils::IteratorType::reduction; SmallVector srcExprs; SmallVector dstExprs; for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) { srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); if (axis != i) dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}); auto linalgOp = rewriter.create( loc, ArrayRef({resultTy, resultMaxTy}), input, ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { auto newValue = blockArgs[0]; auto oldIndex = blockArgs[1]; auto oldValue = blockArgs[2]; Value newIndex = rewriter.create( nestedLoc, oldIndex.getType(), rewriter.create(loc, axis)); Value predicate; if (isa(inElementTy)) { predicate = rewriter.create( nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); } else if (isa(inElementTy)) { predicate = rewriter.create( nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { didEncounterError = true; return; } auto resultMax = rewriter.create( nestedLoc, predicate, newValue, oldValue); auto resultIndex = rewriter.create( nestedLoc, predicate, newIndex, oldIndex); nestedBuilder.create( nestedLoc, ValueRange({resultIndex, resultMax})); }); if (didEncounterError) return rewriter.notifyMatchFailure( argmaxOp, "unsupported tosa.argmax element type"); rewriter.replaceOp(argmaxOp, linalgOp.getResult(0)); return success(); } }; class GatherConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto input = adaptor.getOperands()[0]; auto indices = adaptor.getOperands()[1]; auto valuesTy = 
dyn_cast_or_null(op.getValues().getType()); auto resultTy = cast(op.getType()); if (!valuesTy) return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); auto dynamicDims = inferDynamicDimsForGather( rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices()); auto resultElementTy = resultTy.getElementType(); auto loc = op.getLoc(); auto emptyTensor = rewriter .create(loc, resultTy.getShape(), resultElementTy, dynamicDims) .getResult(); SmallVector affineMaps = { AffineMap::get( /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = rewriter.create( loc, ArrayRef({resultTy}), ValueRange{indices}, ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &b, Location loc, ValueRange args) { auto indexValue = args[0]; auto index0 = rewriter.create(loc, 0); Value index1 = rewriter.create( loc, rewriter.getIndexType(), indexValue); auto index2 = rewriter.create(loc, 2); Value extract = rewriter.create( loc, input, ValueRange{index0, index1, index2}); rewriter.create(loc, extract); }); rewriter.replaceOp(op, genericOp.getResult(0)); return success(); } static llvm::SmallVector inferDynamicDimsForGather(OpBuilder &builder, Location loc, Value values, Value indices) { llvm::SmallVector results; auto addDynamicDimension = [&](Value source, int64_t dim) { auto sz = tensor::getMixedSize(builder, loc, source, dim); if (auto dimValue = llvm::dyn_cast_if_present(sz)) results.push_back(dimValue); }; addDynamicDimension(values, 0); addDynamicDimension(indices, 1); addDynamicDimension(values, 2); return results; } }; // Lowerings the TableOp to a series of gathers and numerica operations. This // includes interpolation between the high/low values. For the I8 varient, this // simplifies to a single gather operation. 
class TableConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::TableOp op, PatternRewriter &rewriter) const final { auto loc = op.getLoc(); Value input = op.getInput(); Value table = op.getTable(); auto inputTy = cast(input.getType()); auto tableTy = cast(table.getType()); auto resultTy = cast(op.getType()); auto inputElementTy = inputTy.getElementType(); auto tableElementTy = tableTy.getElementType(); auto resultElementTy = resultTy.getElementType(); SmallVector dynDims; for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { dynDims.push_back( rewriter.create(loc, op.getOperand(0), i)); } } auto emptyTensor = rewriter .create(loc, resultTy.getShape(), resultElementTy, dynDims) .getResult(); SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto genericOp = rewriter.create( loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank())); rewriter.replaceOp(op, genericOp.getResult(0)); { OpBuilder::InsertionGuard regionGuard(rewriter); Block *block = rewriter.createBlock( &genericOp.getRegion(), genericOp.getRegion().end(), TypeRange({inputElementTy, resultElementTy}), {loc, loc}); auto inputValue = block->getArgument(0); rewriter.setInsertionPointToStart(block); if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && resultElementTy.isInteger(8)) { Value index = rewriter.create( loc, rewriter.getIndexType(), inputValue); Value offset = rewriter.create(loc, 128); index = rewriter.create(loc, rewriter.getIndexType(), index, offset); Value extract = rewriter.create(loc, table, ValueRange{index}); rewriter.create(loc, extract); return success(); } if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && resultElementTy.isInteger(32)) { Value extend = rewriter.create( loc, rewriter.getI32Type(), inputValue); auto offset = 
rewriter.create( loc, rewriter.getI32IntegerAttr(32768)); auto seven = rewriter.create( loc, rewriter.getI32IntegerAttr(7)); auto one = rewriter.create( loc, rewriter.getI32IntegerAttr(1)); auto b1111111 = rewriter.create( loc, rewriter.getI32IntegerAttr(127)); // Compute the index and fractional part from the input value: // value = value + 32768 // index = value >> 7; // fraction = 0x01111111 & value auto extendAdd = rewriter.create(loc, extend, offset); Value index = rewriter.create(loc, extendAdd, seven); Value fraction = rewriter.create(loc, extendAdd, b1111111); // Extract the base and next values from the table. // base = (int32_t) table[index]; // next = (int32_t) table[index + 1]; Value indexPlusOne = rewriter.create(loc, index, one); index = rewriter.create( loc, rewriter.getIndexType(), index); indexPlusOne = rewriter.create( loc, rewriter.getIndexType(), indexPlusOne); Value base = rewriter.create(loc, table, ValueRange{index}); Value next = rewriter.create( loc, table, ValueRange{indexPlusOne}); base = rewriter.create(loc, rewriter.getI32Type(), base); next = rewriter.create(loc, rewriter.getI32Type(), next); // Use the fractional part to interpolate between the input values: // result = (base << 7) + (next - base) * fraction Value baseScaled = rewriter.create(loc, base, seven); Value diff = rewriter.create(loc, next, base); Value diffScaled = rewriter.create(loc, diff, fraction); Value result = rewriter.create(loc, baseScaled, diffScaled); rewriter.create(loc, result); return success(); } } return rewriter.notifyMatchFailure( op, "unable to create body for tosa.table op"); } }; struct RFFT2dConverter final : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; static bool isRankedTensor(Type type) { return isa(type); } static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc, OpFoldResult ofr) { auto one = builder.create(loc, 1); auto two = builder.create(loc, 2); auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr); 
auto divBy2 = builder.createOrFold(loc, value, two); auto plusOne = builder.createOrFold(loc, divBy2, one); return getAsOpFoldResult(plusOne); } static RankedTensorType computeOutputShape(OpBuilder &builder, Location loc, Value input, llvm::SmallVectorImpl &dynamicSizes) { // Get [N, H, W] auto dims = tensor::getMixedSizes(builder, loc, input); // Set W = (W / 2) + 1 to account for the half-sized W dimension of the // output tensors. dims[2] = halfPlusOne(builder, loc, dims[2]); llvm::SmallVector staticSizes; dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes); auto elementType = input.getType().cast().getElementType(); return RankedTensorType::get(staticSizes, elementType); } static Value createZeroTensor(PatternRewriter &rewriter, Location loc, RankedTensorType type, llvm::ArrayRef dynamicSizes) { auto emptyTensor = rewriter.create(loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, ValueRange{emptyTensor}) .result(); return filledTensor; } static Value castIndexToFloat(OpBuilder &builder, Location loc, FloatType type, Value value) { auto integerVal = builder.create( loc, type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type() : builder.getI32Type(), value); return builder.create(loc, type, integerVal); } static Value createLinalgIndex(OpBuilder &builder, Location loc, FloatType type, int64_t index) { auto indexVal = builder.create(loc, index); return castIndexToFloat(builder, loc, type, indexVal); } template static llvm::SmallVector affineDimsExpr(OpBuilder &builder, Args... 
args) { return {builder.getAffineDimExpr(args)...}; } LogicalResult matchAndRewrite(RFFT2dOp rfft2d, PatternRewriter &rewriter) const override { if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) || !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) { return rewriter.notifyMatchFailure(rfft2d, "only supports ranked tensors"); } auto loc = rfft2d.getLoc(); auto input = rfft2d.getInput(); auto elementType = input.getType().cast().getElementType().cast(); // Compute the output type and set of dynamic sizes llvm::SmallVector dynamicSizes; auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes); // Iterator types for the linalg.generic implementation llvm::SmallVector iteratorTypes = { utils::IteratorType::parallel, utils::IteratorType::parallel, utils::IteratorType::parallel, utils::IteratorType::reduction, utils::IteratorType::reduction}; // Inputs/outputs to the linalg.generic implementation llvm::SmallVector genericOpInputs = {input}; llvm::SmallVector genericOpOutputs = { createZeroTensor(rewriter, loc, outputType, dynamicSizes), createZeroTensor(rewriter, loc, outputType, dynamicSizes)}; // Indexing maps for input and output tensors auto indexingMaps = AffineMap::inferFromExprList(llvm::ArrayRef{ affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 1, 2), affineDimsExpr(rewriter, 0, 1, 2)}); // Width and height dimensions of the original input. 
auto dimH = rewriter.createOrFold(loc, input, 1); auto dimW = rewriter.createOrFold(loc, input, 2); // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586); auto twoPi = rewriter.create(loc, twoPiAttr); auto constH = castIndexToFloat(rewriter, loc, elementType, dimH); auto constW = castIndexToFloat(rewriter, loc, elementType, dimW); auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) { Value valReal = args[0]; Value sumReal = args[1]; Value sumImag = args[2]; // Indices for angle computation auto oy = createLinalgIndex(builder, loc, elementType, 1); auto ox = createLinalgIndex(builder, loc, elementType, 2); auto iy = createLinalgIndex(builder, loc, elementType, 3); auto ix = createLinalgIndex(builder, loc, elementType, 4); // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W) auto iyXoy = builder.create(loc, iy, oy); auto ixXox = builder.create(loc, ix, ox); auto yComponent = builder.create(loc, iyXoy, constH); auto xComponent = builder.create(loc, ixXox, constW); auto sumXY = builder.create(loc, yComponent, xComponent); auto angle = builder.create(loc, twoPi, sumXY); // realComponent = valReal * cos(angle) // imagComponent = valReal * sin(angle) auto cosAngle = builder.create(loc, angle); auto sinAngle = builder.create(loc, angle); auto realComponent = builder.create(loc, valReal, cosAngle); auto imagComponent = builder.create(loc, valReal, sinAngle); // outReal = sumReal + realComponent // outImag = sumImag - imagComponent auto outReal = builder.create(loc, sumReal, realComponent); auto outImag = builder.create(loc, sumImag, imagComponent); builder.create(loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp( rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs, indexingMaps, iteratorTypes, buildBody); return success(); } }; struct FFT2dConverter final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FFT2dOp fft2d, 
PatternRewriter &rewriter) const override { if (!llvm::all_of(fft2d->getOperandTypes(), RFFT2dConverter::isRankedTensor) || !llvm::all_of(fft2d->getResultTypes(), RFFT2dConverter::isRankedTensor)) { return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors"); } Location loc = fft2d.getLoc(); Value input_real = fft2d.getInputReal(); Value input_imag = fft2d.getInputImag(); BoolAttr inverse = fft2d.getInverseAttr(); auto real_el_ty = cast( cast(input_real.getType()).getElementType()); [[maybe_unused]] auto imag_el_ty = cast( cast(input_imag.getType()).getElementType()); assert(real_el_ty == imag_el_ty); // Compute the output type and set of dynamic sizes SmallVector dynamicSizes; // Get [N, H, W] auto dims = tensor::getMixedSizes(rewriter, loc, input_real); SmallVector staticSizes; dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes); auto outputType = RankedTensorType::get(staticSizes, real_el_ty); // Iterator types for the linalg.generic implementation SmallVector iteratorTypes = { utils::IteratorType::parallel, utils::IteratorType::parallel, utils::IteratorType::parallel, utils::IteratorType::reduction, utils::IteratorType::reduction}; // Inputs/outputs to the linalg.generic implementation SmallVector genericOpInputs = {input_real, input_imag}; SmallVector genericOpOutputs = { RFFT2dConverter::createZeroTensor(rewriter, loc, outputType, dynamicSizes), RFFT2dConverter::createZeroTensor(rewriter, loc, outputType, dynamicSizes)}; // Indexing maps for input and output tensors auto indexingMaps = AffineMap::inferFromExprList( ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4), RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4), RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2), RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)}); // Width and height dimensions of the original input. 
auto dimH = rewriter.createOrFold(loc, input_real, 1); auto dimW = rewriter.createOrFold(loc, input_real, 2); // Constants and dimension sizes auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586); auto twoPi = rewriter.create(loc, twoPiAttr); Value constH = RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH); Value constW = RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW); auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) { Value valReal = args[0]; Value valImag = args[1]; Value sumReal = args[2]; Value sumImag = args[3]; // Indices for angle computation Value oy = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1); Value ox = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2); Value iy = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3); Value ix = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4); // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W); auto iyXoy = builder.create(loc, iy, oy); auto ixXox = builder.create(loc, ix, ox); auto yComponent = builder.create(loc, iyXoy, constH); auto xComponent = builder.create(loc, ixXox, constW); auto sumXY = builder.create(loc, yComponent, xComponent); auto angle = builder.create(loc, twoPi, sumXY); if (inverse.getValue()) { angle = builder.create( loc, angle, rewriter.create( loc, rewriter.getFloatAttr(real_el_ty, -1.0))); } // realComponent = val_real * cos(a) + val_imag * sin(a); // imagComponent = -val_real * sin(a) + val_imag * cos(a); auto cosAngle = builder.create(loc, angle); auto sinAngle = builder.create(loc, angle); auto rcos = builder.create(loc, valReal, cosAngle); auto rsin = builder.create(loc, valImag, sinAngle); auto realComponent = builder.create(loc, rcos, rsin); auto icos = builder.create(loc, valImag, cosAngle); auto isin = builder.create(loc, valReal, sinAngle); auto imagComponent = builder.create(loc, icos, isin); // outReal = sumReal + realComponent // 
outImag = sumImag - imagComponent auto outReal = builder.create(loc, sumReal, realComponent); auto outImag = builder.create(loc, sumImag, imagComponent); builder.create(loc, ValueRange{outReal, outImag}); }; rewriter.replaceOpWithNewOp( fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs, indexingMaps, iteratorTypes, buildBody); return success(); } }; } // namespace void mlir::tosa::populateTosaToLinalgConversionPatterns( RewritePatternSet *patterns) { // We have multiple resize coverters to handle degenerate cases. patterns->add(patterns->getContext(), /*benefit=*/100); patterns->add(patterns->getContext(), /*benefit=*/200); patterns->add(patterns->getContext(), /*benefit=*/300); patterns->add< // clang-format off PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, PointwiseConverter, IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, ReduceConverter, ReduceConverter, ReduceConverter, ArgMaxConverter, GatherConverter, RescaleConverter, ReverseConverter, RFFT2dConverter, FFT2dConverter, TableConverter, TileConverter>(patterns->getContext()); // clang-format on }