//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace tosa;

static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                  ArrayRef<int64_t> rhsShape,
                                  SmallVector<int64_t> &intermediateShape,
                                  bool isDynamic) {
  if (isDynamic) {
    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
    intermediateShape = {ShapedType::kDynamic};
    return true;
  }

  if (lhsShape.empty() || rhsShape.empty()) {
    intermediateShape = {};
    return true;
  }

  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // If the iterators didn't reach the end and their leftover dimensions are
  // not equal to 1, an intermediate shape was not found.
  while (currLhsDim < lhsShape.size()) {
    if (lhsShape[currLhsDim++] != 1) {
      return false;
    }
  }

  while (currRhsDim < rhsShape.size()) {
    if (rhsShape[currRhsDim++] != 1) {
      return false;
    }
  }

  return true;
}

static bool createReassociationMapsForCollapse(
    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> dstShape,
    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {

  // If the shape is dynamic, create a map for collapsing into one dimension.
  if (isDynamic) {
    SmallVector<AffineExpr, 2> exprs;
    for (int i = 0, s = srcShape.size(); i < s; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    reassociationMap = {exprs};
    return true;
  }

  if (dstShape.empty()) {
    reassociationMap = {};
    return true;
  }

  reassociationMap.resize(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // Guard the read below: for incompatible shapes the source dims can run
      // out before srcSize reaches dstSize; the size check at the end of the
      // function then reports the mismatch.
      if (currSrcDim < srcShape.size())
        srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent unit
      // dims in expandedShape as belonging to the current group.
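      // For example (a worked case of the rule above): collapsing 6x1x1 into
      // 6x1, the next dst dim is 1, so only d0 joins the first group and the
      // unit dims form the second one: [[d0], [d1, d2]]. Collapsing 6x1x1
      // into 6 instead absorbs the trailing unit dims here, giving
      // [[d0, d1, d2]].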
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              rewriter.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If both iterators didn't reach the end, we have leftover dimensions,
  // which implies that we have a mismatch in shape.
  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
}

namespace {

Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
                     ShapedType resultTy, Value operand) {
  ShapedType operandTy = cast<ShapedType>(operand.getType());
  if (resultTy == operandTy)
    return operand;

  bool isDynamic = !operandTy.hasStaticShape();

  if (isDynamic && resultTy.getRank() != 1) {
    (void)rewriter.notifyMatchFailure(
        loc, "Cannot collapse dynamic dims to more than one dimension");
    return {};
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
                                          resultTy.getShape(),
                                          reassociationMap, isDynamic)) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
    return {};
  }

  SmallVector<int64_t> intermediateShape;
  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
                             intermediateShape, isDynamic)) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Cannot collapse into given shape");
    return {};
  }
  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
                                                  reassociationMap);
}

Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
                   ShapedType resultTy, Value operand) {
  ShapedType operandTy = cast<ShapedType>(operand.getType());
  if (resultTy == operandTy)
    return operand;

  bool isDynamic = !operandTy.hasStaticShape();

  if (isDynamic && operandTy.getRank() != 1) {
    (void)rewriter.notifyMatchFailure(
        loc, "Cannot expand dynamic dims from more than one dimension");
    return {};
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
                                          operandTy.getShape(),
                                          reassociationMap, isDynamic)) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Attempting to expand into an incompatible shape");
    return {};
  }

  SmallVector<int64_t> intermediateShape;
  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
                             intermediateShape, isDynamic) ||
      intermediateShape != operandTy.getShape()) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Cannot expand into given shape");
    return {};
  }
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
                                                reassociationMap);
}

class ReshapeConverterCollapseExpand
    : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
    ShapedType resultTy = cast<ShapedType>(reshape.getType());
    bool isDynamic = !operandTy.hasStaticShape();

    SmallVector<int64_t> intermediateShape;
    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
                               intermediateShape, isDynamic)) {
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot identify an intermediate shape between "
                   "the given two shapes");
    }
    auto intermediateTy = RankedTensorType::get(
        intermediateShape, reshape.getType().getElementType());

    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
                                    adaptor.getInput1());
    if (!collapse)
      return failure();

    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
    if (!expand)
      return failure();

    rewriter.replaceOp(reshape, expand);
    return success();
  }
};
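// A sketch of the collapse/expand split performed by the pattern above (IR
// shown for illustration only): reshaping tensor<2x3xf32> into
// tensor<3x2xf32> goes through the rank-1 intermediate tensor<6xf32>:
//
//   %0 = tensor.collapse_shape %arg0 [[0, 1]]
//       : tensor<2x3xf32> into tensor<6xf32>
//   %1 = tensor.expand_shape %0 [[0, 1]]
//       : tensor<6xf32> into tensor<3x2xf32>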

class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();
    SmallVector<int64_t> strides, sizes;
    ArrayRef<int64_t> starts = sliceOp.getStart();
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
    for (const auto &i : llvm::enumerate(sliceOp.getSize())) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (!ShapedType::isDynamic(sizes.back()))
        continue;

      auto dim = rewriter.create<tensor::DimOp>(loc, input, index);
      auto offset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(starts[index]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dim, offset));
    }

    auto newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Set up the default constant attribute.
    Value padConstant;

    if (padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      TypedAttr constantAttr;
      if (isa<FloatType>(elementTy)) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    Value lowIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), lowVal);
      highVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
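// A sketch of the pad lowering above (illustrative IR; %l0/%h0 etc. are the
// values extracted from the padding tensor and index_cast to index): for
// %arg0 : tensor<1x2xf32> with rank-2 padding, the pattern emits
//
//   %padded = tensor.pad %arg0 low[%l0, %l1] high[%h0, %h1] {
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32
//   } : tensor<1x2xf32> to tensor<4x9xf32>
//
// where %cst is the pad constant resolved above, and the static result type
// comes from the tosa.pad result (here assuming low = [1, 3], high = [2, 4]).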

struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension.
    // The axisOffsets will have one entry per operand plus one, where the
    // last value holds the total size of the tensor along the 'axis'
    // dimension.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based on the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
      patterns->getContext());
  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
}
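
// Typical driver usage (a sketch; the actual target setup lives in the
// -tosa-to-tensor pass, not in this file):
//
//   RewritePatternSet patterns(ctx);
//   ConversionTarget target(*ctx);
//   target.addIllegalOp<tosa::ReshapeOp, tosa::SliceOp, tosa::PadOp,
//                       tosa::ConcatOp>();
//   target.addLegalDialect<arith::ArithDialect, tensor::TensorDialect>();
//   mlir::tosa::populateTosaToTensorConversionPatterns(&patterns);
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();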