bolt/deps/llvm-18.1.8/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
2025-02-14 19:21:04 +01:00

423 lines
15 KiB
C++

//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
using namespace tosa;
/// Computes a shape that both `lhsShape` and `rhsShape` can be collapsed into
/// by merging runs of adjacent dimensions. Each entry of the result is a size
/// at which the running products of the two shapes agree; e.g. lhs = [2, 3, 4]
/// and rhs = [6, 2, 2] give [6, 4]. Leftover trailing dimensions on either
/// side are accepted only if they are all 1.
///
/// \param lhsShape          First static shape.
/// \param rhsShape          Second static shape.
/// \param intermediateShape Output; replaced with the computed shape.
/// \param isDynamic         If true the input shapes are not inspected and the
///                          result is a single dynamic dimension.
/// \returns true if such a shape exists, false otherwise.
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                  ArrayRef<int64_t> rhsShape,
                                  SmallVector<int64_t> &intermediateShape,
                                  bool isDynamic) {
  if (isDynamic) {
    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
    intermediateShape = {ShapedType::kDynamic};
    return true;
  }

  if (lhsShape.empty() || rhsShape.empty()) {
    intermediateShape = {};
    return true;
  }

  // The loop below only appends, so start from an empty result.
  intermediateShape.clear();

  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    // Multiply the smaller running product by its next dimension until the
    // two products agree or one side runs out of dimensions.
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    // The products never agreed: this group of dimensions has no common size,
    // so the shapes cannot be related by collapsing. Report failure rather
    // than returning a shape with a group silently dropped.
    if (lhsSize != rhsSize)
      return false;
    intermediateShape.push_back(lhsSize);
    currRhsDim++;
    currLhsDim++;
  }

  // If the iterators didn't reach the end and their leftover dimensions are not
  // equal to 1 an intermediate shape was not found.
  while (currLhsDim < lhsShape.size()) {
    if (lhsShape[currLhsDim++] != 1) {
      return false;
    }
  }
  while (currRhsDim < rhsShape.size()) {
    if (rhsShape[currRhsDim++] != 1) {
      return false;
    }
  }
  return true;
}
/// Builds the reassociation map that collapses `srcShape` into `dstShape`:
/// entry `i` lists the source dimensions (as affine dim exprs) merged to form
/// destination dimension `i`.
///
/// When `isDynamic` is set, every source dimension is placed in a single group
/// regardless of `dstShape`. A rank-0 `dstShape` yields an empty map. After a
/// group is completed, following unit source dimensions are folded into it
/// unless the next destination dimension is itself 1.
///
/// \returns true if every source and destination dimension was consumed;
///          false if the shapes cannot be related this way. On failure
///          `reassociationMap` may be partially filled.
static bool createReassociationMapsForCollapse(
    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> dstShape,
    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
  // If the shape is dynamic, create a map for collapsing into one dimension.
  if (isDynamic) {
    SmallVector<AffineExpr, 2> exprs;
    for (int i = 0, s = srcShape.size(); i < s; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    reassociationMap = {exprs};
    return true;
  }

  if (dstShape.empty()) {
    reassociationMap = {};
    return true;
  }

  reassociationMap.resize(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    // Fold source dimensions into the current group until their product
    // reaches the destination size.
    while (srcSize < dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // Running out of source dimensions before reaching dstSize means the
      // shapes are incompatible. Stop here instead of reading past the end of
      // srcShape.
      if (currSrcDim == srcShape.size())
        return false;
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              rewriter.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If both iterators didn't reach the end, we have leftover dimensions which
  // implies that we have a mismatch in shape.
  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
}
namespace {
/// Emits a tensor.collapse_shape that turns `operand` into `resultTy`, or
/// returns `operand` itself when it already has that type. On failure a
/// match-failure message is recorded at `loc` and a null Value is returned.
Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
                     ShapedType resultTy, Value operand) {
  auto srcTy = cast<ShapedType>(operand.getType());
  if (srcTy == resultTy)
    return operand;

  const bool dynamicSrc = !srcTy.hasStaticShape();
  // A source with dynamic dims may only be flattened to rank 1.
  if (dynamicSrc && resultTy.getRank() != 1) {
    (void)rewriter.notifyMatchFailure(
        loc, "Cannot collapse dynamic dims to more than one dimension");
    return {};
  }

  SmallVector<ReassociationExprs, 4> groups;
  const bool mapped = createReassociationMapsForCollapse(
      rewriter, srcTy.getShape(), resultTy.getShape(), groups, dynamicSrc);
  if (!mapped) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
    return {};
  }

  // Only the existence of a common intermediate shape is checked here; the
  // computed shape itself is not used.
  SmallVector<int64_t> commonShape;
  const bool hasCommon = findIntermediateShape(
      srcTy.getShape(), resultTy.getShape(), commonShape, dynamicSrc);
  if (!hasCommon) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Cannot collapse into given shape");
    return {};
  }

  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
                                                  groups);
}
/// Emits a tensor.expand_shape that turns `operand` into `resultTy`, or
/// returns `operand` itself when it already has that type. The reassociation
/// is built as the map that would collapse `resultTy` back into the operand's
/// shape. On failure a match-failure message is recorded at `loc` and a null
/// Value is returned.
Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
                   ShapedType resultTy, Value operand) {
  auto srcTy = cast<ShapedType>(operand.getType());
  if (srcTy == resultTy)
    return operand;

  const bool dynamicSrc = !srcTy.hasStaticShape();
  // A source with dynamic dims must be rank 1 to be expanded.
  if (dynamicSrc && srcTy.getRank() != 1) {
    (void)rewriter.notifyMatchFailure(
        loc, "Cannot expand dynamic dims from more than one dimension");
    return {};
  }

  // Note the argument order: expanding X into Y reuses the map that collapses
  // Y into X.
  SmallVector<ReassociationExprs, 4> groups;
  const bool mapped = createReassociationMapsForCollapse(
      rewriter, resultTy.getShape(), srcTy.getShape(), groups, dynamicSrc);
  if (!mapped) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Attempting to expand into an incompatible shape");
    return {};
  }

  // The operand is required to already be the common intermediate shape of
  // itself and the result.
  SmallVector<int64_t> commonShape;
  const bool hasCommon = findIntermediateShape(
      srcTy.getShape(), resultTy.getShape(), commonShape, dynamicSrc);
  if (!hasCommon || commonShape != srcTy.getShape()) {
    (void)rewriter.notifyMatchFailure(
        loc, "tosa.reshape Cannot expand into given shape");
    return {};
  }

  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
                                                groups);
}
/// Lowers tosa.reshape in two steps: a tensor.collapse_shape into a common
/// intermediate shape, then a tensor.expand_shape into the result shape.
/// Either step degenerates to a no-op when the types already match.
class ReshapeConverterCollapseExpand
    : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value input = adaptor.getInput1();
    Location loc = reshape.getLoc();
    auto srcTy = cast<ShapedType>(input.getType());
    auto dstTy = cast<ShapedType>(reshape.getType());

    SmallVector<int64_t> commonShape;
    const bool dynamicSrc = !srcTy.hasStaticShape();
    if (!findIntermediateShape(dstTy.getShape(), srcTy.getShape(), commonShape,
                               dynamicSrc))
      return rewriter.notifyMatchFailure(
          reshape, "tosa.reshape Cannot identify an intermediate shape between "
                   "the given two shapes");

    auto commonTy = RankedTensorType::get(commonShape, dstTy.getElementType());

    // Step 1: collapse the input into the intermediate shape.
    Value collapsed = createCollapse(rewriter, loc, commonTy, input);
    if (!collapsed)
      return failure();

    // Step 2: expand the intermediate value into the requested result shape.
    Value expanded = createExpand(rewriter, loc, dstTy, collapsed);
    if (!expanded)
      return failure();

    rewriter.replaceOp(reshape, expanded);
    return success();
  }
};
/// Lowers tosa.slice to a unit-stride tensor.extract_slice. A requested size
/// of -1 is treated as dynamic and materialized as `dim(input, d) - start[d]`.
/// Unranked results are rejected.
class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput();
    ArrayRef<int64_t> starts = sliceOp.getStart();
    ArrayRef<int64_t> requested = sliceOp.getSize();

    SmallVector<int64_t> strides;
    strides.resize(resultType.getRank(), 1);

    SmallVector<int64_t> sizes;
    SmallVector<Value> dynSizes;
    for (size_t d = 0, e = requested.size(); d < e; ++d) {
      const int64_t extent =
          requested[d] == -1 ? ShapedType::kDynamic : requested[d];
      sizes.push_back(extent);
      if (!ShapedType::isDynamic(extent))
        continue;
      // Dynamic extent: take everything from the start offset to the end.
      Value dimSize = rewriter.create<tensor::DimOp>(loc, input, d);
      Value startVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexAttr(starts[d]));
      dynSizes.push_back(rewriter.create<arith::SubIOp>(loc, dimSize, startVal));
    }

    auto extract = rewriter.create<tensor::ExtractSliceOp>(
        loc, sliceOp.getType(), input, ValueRange({}), dynSizes,
        ValueRange({}), rewriter.getDenseI64ArrayAttr(starts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));
    rewriter.replaceOp(sliceOp, extract.getResult());
    return success();
  }
};
/// Lowers tosa.pad to tensor.pad. Row `d` of the `padding` operand supplies
/// the low (column 0) and high (column 1) amounts for dimension `d`. The pad
/// value is taken from `pad_const` when present. Otherwise it is 0.0 for float
/// elements, and for integer elements either the quantization input zero
/// point (if quantization info is attached) or 0. Other element types are
/// rejected.
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    Location loc = padOp.getLoc();
    Value input = padOp.getInput1();
    Value padding = padOp.getPadding();

    auto inputTy = cast<ShapedType>(input.getType());
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Determine the scalar used to fill the padded region.
    Value padConstant;
    if (Value explicitPad = padOp.getPadConst()) {
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, explicitPad, ValueRange({}));
    } else {
      TypedAttr fillAttr;
      if (isa<FloatType>(elementTy)) {
        fillAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (isa<IntegerType>(elementTy)) {
        // Quantized integer inputs pad with their zero point.
        auto quantInfo = padOp.getQuantizationInfo();
        int64_t fill = quantInfo ? quantInfo->getInputZp() : 0;
        fillAttr = rewriter.getIntegerAttr(elementTy, fill);
      }
      if (fillAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, fillAttr);
    }
    if (!padConstant)
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");

    // Column selectors into the padding tensor.
    Value lowCol =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highCol =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;
    lowValues.reserve(rank);
    highValues.reserve(rank);

    Type indexTy = rewriter.getIndexType();
    for (int64_t d = 0; d < rank; ++d) {
      Value row = rewriter.createOrFold<arith::ConstantIndexOp>(loc, d);
      Value low = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({row, lowCol}));
      Value high = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({row, highCol}));
      low = rewriter.createOrFold<arith::IndexCastOp>(loc, indexTy, low);
      high = rewriter.createOrFold<arith::IndexCastOp>(loc, indexTy, high);
      lowValues.push_back(low);
      highValues.push_back(high);
    }

    auto padded = rewriter.create<tensor::PadOp>(
        loc, padOp.getType(), input, lowValues, highValues, padConstant);
    rewriter.replaceOp(padOp, padded.getResult());
    return success();
  }
};
/// Lowers tosa.concat to a tensor.empty of the result type followed by one
/// tensor.insert_slice per operand. Each operand is inserted at the running
/// offset along the concatenation axis. Non-axis sizes of the empty tensor are
/// taken from the first operand.
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The lowering needs a ranked result type. Reject unranked results before
    // creating any IR instead of dereferencing a null type below.
    auto resultType = dyn_cast<RankedTensorType>(op.getType());
    if (!resultType)
      return rewriter.notifyMatchFailure(op, "expected a ranked result type");

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension.
    // axisOffsets[i] is the offset at which operand i is inserted. The vector
    // holds one entry per operand plus a final entry with the total size of
    // the result along 'axis'.
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);
    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based off of the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the operation
    // during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result = rewriter.create<tensor::EmptyOp>(
        loc, resultType.getShape(), resultType.getElementType(), dynDims);
    // zip stops at the shorter range, so the trailing total in axisOffsets is
    // not used as an insertion offset.
    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      // Each operand is inserted in full, shifted along 'axis'.
      SmallVector<OpFoldResult> argSizes =
          tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, argSizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace
/// Adds the TOSA-to-Tensor lowering patterns (slice, pad, concat and reshape)
/// to `patterns`.
void mlir::tosa::populateTosaToTensorConversionPatterns(
    RewritePatternSet *patterns) {
  MLIRContext *ctx = patterns->getContext();
  patterns->add<SliceConverter, PadConverter, ConcatConverter,
                ReshapeConverterCollapseExpand>(ctx);
}