1952 lines
72 KiB
C++
1952 lines
72 KiB
C++
|
//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// \file
|
||
|
// This file implements the TOSA Specification:
|
||
|
// https://developer.mlplatform.org/w/tosa/
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||
|
#include "mlir/Dialect/Quant/QuantOps.h"
|
||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||
|
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
|
||
|
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
|
||
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
||
|
#include "mlir/IR/BuiltinTypes.h"
|
||
|
#include "mlir/IR/DialectImplementation.h"
|
||
|
#include "mlir/IR/Matchers.h"
|
||
|
#include "mlir/IR/PatternMatch.h"
|
||
|
#include "mlir/IR/TypeUtilities.h"
|
||
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||
|
#include "mlir/Transforms/InliningUtils.h"
|
||
|
#include "llvm/ADT/APFloat.h"
|
||
|
#include "llvm/ADT/DenseMap.h"
|
||
|
#include "llvm/ADT/TypeSwitch.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::tosa;
|
||
|
|
||
|
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Tosa dialect interface includes.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
|
||
|
|
||
|
namespace {
|
||
|
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Dialect Function Inliner Interface.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
struct TosaInlinerInterface : public DialectInlinerInterface {
|
||
|
using DialectInlinerInterface::DialectInlinerInterface;
|
||
|
|
||
|
//===--------------------------------------------------------------------===//
|
||
|
// Analysis Hooks.
|
||
|
//===--------------------------------------------------------------------===//
|
||
|
|
||
|
/// All operations can be inlined by default.
|
||
|
bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
|
||
|
IRMapping &map) const final {
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
/// All regions with If and While parent operators can be inlined.
|
||
|
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||
|
IRMapping &map) const final {
|
||
|
return (isa<tosa::IfOp>(dest->getParentOp()) ||
|
||
|
isa<tosa::WhileOp>(dest->getParentOp()));
|
||
|
}
|
||
|
};
|
||
|
|
||
|
/// This class implements the bytecode interface for the Tosa dialect.
|
||
|
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
|
||
|
TosaDialectBytecodeInterface(Dialect *dialect)
|
||
|
: BytecodeDialectInterface(dialect) {}
|
||
|
|
||
|
//===--------------------------------------------------------------------===//
|
||
|
// Attributes
|
||
|
|
||
|
Attribute readAttribute(DialectBytecodeReader &reader) const override {
|
||
|
return ::readAttribute(getContext(), reader);
|
||
|
}
|
||
|
|
||
|
LogicalResult writeAttribute(Attribute attr,
|
||
|
DialectBytecodeWriter &writer) const override {
|
||
|
return ::writeAttribute(attr, writer);
|
||
|
}
|
||
|
|
||
|
//===--------------------------------------------------------------------===//
|
||
|
// Types
|
||
|
|
||
|
Type readType(DialectBytecodeReader &reader) const override {
|
||
|
return ::readType(getContext(), reader);
|
||
|
}
|
||
|
|
||
|
LogicalResult writeType(Type type,
|
||
|
DialectBytecodeWriter &writer) const override {
|
||
|
return ::writeType(type, writer);
|
||
|
}
|
||
|
|
||
|
void writeVersion(DialectBytecodeWriter &writer) const final {
|
||
|
// TODO: Populate.
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<DialectVersion>
|
||
|
readVersion(DialectBytecodeReader &reader) const final {
|
||
|
// TODO: Populate
|
||
|
reader.emitError("Dialect does not support versioning");
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
LogicalResult upgradeFromVersion(Operation *topLevelOp,
|
||
|
const DialectVersion &version_) const final {
|
||
|
return success();
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // namespace
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TOSA control flow support.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Returns the while loop body.
|
||
|
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Tosa dialect initialization.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Registers the ODS-generated operations and attributes of the Tosa dialect,
/// followed by its bytecode and inliner interfaces.
void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  // Same registration order as a single variadic addInterfaces<> call.
  addInterfaces<TosaDialectBytecodeInterface>();
  addInterfaces<TosaInlinerInterface>();
}
|
||
|
|
||
|
/// Materializes a folded constant as a tosa::ConstOp.
///
/// Only ElementsAttr payloads can be represented by a TOSA constant, unlike
/// constant ops in other dialects that accept arbitrary attributes. Any other
/// attribute kind yields nullptr, which signals that materialization failed.
Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  auto elements = llvm::dyn_cast<ElementsAttr>(value);
  if (!elements)
    return nullptr;
  return builder.create<tosa::ConstOp>(loc, type, elements);
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// Parsers and printers
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Parses either `= <attribute>` or `: <type>`.
///
/// In the `=` form, \p attr receives the parsed attribute. If that attribute is
/// a TypedAttr, \p typeAttr is set from its type; otherwise \p typeAttr is left
/// untouched. In the `:` form only \p typeAttr is set. An "expected attribute"
/// or "expected type" error is emitted when the corresponding piece fails to
/// parse.
ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                        Attribute &attr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(attr)))
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    // Use the free-function cast; the member Attribute::dyn_cast<> is
    // deprecated in MLIR.
    if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr))
      typeAttr = TypeAttr::get(typedAttr.getType());
    return success();
  }

  Type type;
  if (failed(parser.parseColonType(type)))
    return parser.emitError(parser.getCurrentLocation()) << "expected type";
  typeAttr = TypeAttr::get(type);
  return success();
}
|
||
|
|
||
|
/// Prints the counterpart of parseTypeOrAttr.
///
/// `: <type>` is printed whenever \p attr is absent, is not a TypedAttr, or
/// carries a type different from \p type. After that, `= <attr>` is printed if
/// \p attr is present, separated by a space from the type when both appear.
void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                                 Attribute attr) {
  bool needsSpace = false;
  // Free-function cast: the member Attribute::dyn_cast_or_null<> is deprecated.
  auto typedAttr = llvm::dyn_cast_or_null<TypedAttr>(attr);
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
    p << ": ";
    p.printAttribute(type);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (attr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(attr);
  }
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TOSA Operator Verifiers.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Returns true if \p shapedType is ranked and at least one of its static
/// dimensions has size 0. Unranked types and dynamic dimensions never count.
static bool hasZeroDimension(ShapedType shapedType) {
  if (!shapedType.hasRank())
    return false;

  ArrayRef<int64_t> dims = shapedType.getShape();
  return std::any_of(dims.begin(), dims.end(), [](int64_t dim) {
    return !ShapedType::isDynamic(dim) && dim == 0;
  });
}
|
||
|
|
||
|
/// Shared verifier for convolution-like ops exposing getInput(), getWeight()
/// and getQuantizationInfo().
///
/// It checks the following:
///  - input and weight are ranked tensors;
///  - the input has no zero-sized dimension;
///  - input and weight agree on being float vs. non-float ("quantized");
///  - quantization_info is present exactly when the operands are non-float.
template <typename T> static LogicalResult verifyConvOp(T op) {
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  if (!inputType)
    return op.emitOpError("expect a ranked tensor for input, got ")
           << op.getInput();

  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
  if (!weightType)
    return op.emitOpError("expect a ranked tensor for weight, got ")
           << op.getWeight();

  if (hasZeroDimension(inputType))
    return op.emitOpError() << "tensor has a dimension with size zero. Each "
                               "dimension of a tensor must have size >= 1";

  Type inputElemTy = inputType.getElementType();
  Type weightElemTy = weightType.getElementType();

  // Any non-float element type is treated as quantized here.
  const bool inputIsQuant = !llvm::isa<FloatType>(inputElemTy);
  const bool weightIsQuant = !llvm::isa<FloatType>(weightElemTy);
  if (inputIsQuant != weightIsQuant)
    return op.emitOpError(
               "expect both input and weight to be float or not together, got ")
           << inputElemTy << " and " << weightElemTy;

  // quantization_info must accompany quantized operands and nothing else.
  const bool hasQuantInfo = static_cast<bool>(op.getQuantizationInfo());
  if (inputIsQuant != hasQuantInfo)
    return op.emitOpError(
        "quantizationattr is required for quantized type, and not "
        "allowed for float type");

  return success();
}
|
||
|
|
||
|
/// Verifies tosa.argmax.
///
/// The result element type must be an integer or index type; its bit width is
/// not checked here. For ranked inputs, the axis must lie in [0, rank).
LogicalResult tosa::ArgMaxOp::verify() {
  Type resultElemTy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultElemTy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  auto inputTy = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (!inputTy.hasRank())
    return success();

  const bool axisInRange = axis >= 0 && axis < inputTy.getRank();
  if (!axisInRange)
    return emitOpError("specified axis is outside the rank of the tensor");
  return success();
}
|
||
|
|
||
|
/// Verifies tosa.avg_pool2d.
///
/// It rejects zero-sized input dimensions and enforces the accumulator type
/// allowed for each input element type. It also requires input and result to
/// share one of f32/f16/bf16/i8/i16. Uniform-quantized element types are
/// compared through their storage type.
LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (hasZeroDimension(inputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  auto toStorageType = [](Type elemTy) -> Type {
    if (auto quantTy =
            llvm::dyn_cast<mlir::quant::UniformQuantizedType>(elemTy))
      return quantTy.getStorageType();
    return elemTy;
  };
  Type inputETy = toStorageType(inputType.getElementType());
  Type resultETy =
      toStorageType(llvm::cast<ShapedType>(getType()).getElementType());

  // Accumulator constraints, keyed on the input element type.
  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");
  if (inputETy.isF16() && !accType.isF16() && !accType.isF32())
    return emitOpError("accumulator type for f16 tensor is not f16/f32");
  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");
  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  const bool matchingKinds =
      (inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16));
  if (!matchingKinds)
    return emitOpError("input/output element types are incompatible.");
  return success();
}
|
||
|
|
||
|
/// Verifies tosa.clamp.
///
/// Input and output must have the same element type, with quantized types
/// compared by storage type. For float inputs, the min_fp and max_fp
/// attributes must share a single type. That type must be either the input
/// element type or a float type strictly wider than it.
LogicalResult tosa::ClampOp::verify() {
  // Reduce every quantized type, including per-axis quantization and not only
  // UniformQuantizedType, to its storage type so that the checks below see a
  // plain integer type.
  auto toStorageType = [](Type elemTy) -> Type {
    if (auto quantTy = llvm::dyn_cast<mlir::quant::QuantizedType>(elemTy))
      return quantTy.getStorageType();
    return elemTy;
  };

  mlir::Type inputETy = toStorageType(
      llvm::cast<ShapedType>(getInput().getType()).getElementType());
  mlir::Type outputETy = toStorageType(
      llvm::cast<ShapedType>(getOutput().getType()).getElementType());

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  // Only float inputs are checked against the *_fp attributes. Testing for
  // FloatType directly, instead of querying getIntOrFloatBitWidth() up front,
  // avoids asserting on element types that are neither integer nor float.
  if (!llvm::isa<FloatType>(inputETy))
    return success();

  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  const bool sameAttrType = maxFpType == minFpType;
  const bool attrTypeAllowed =
      maxFpType == inputETy ||
      maxFpType.getIntOrFloatBitWidth() > inputETy.getIntOrFloatBitWidth();
  if (!sameAttrType || !attrTypeAllowed)
    return emitOpError("min/max attributes types are incompatible with "
                       "input/output element types.");

  return success();
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TOSA Operator Quantization Builders.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// This builder is called on all convolution operators except TransposeConv,
|
||
|
/// which has specialized output shape semantics. The builder also defines the
|
||
|
/// bitwidth of the output given the bit width of the input & weight content.
|
||
|
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||
|
Type outputType, Value input, Value weight,
|
||
|
Value bias, DenseI64ArrayAttr pad,
|
||
|
DenseI64ArrayAttr stride,
|
||
|
DenseI64ArrayAttr dilation) {
|
||
|
|
||
|
result.addOperands({input, weight, bias});
|
||
|
result.addAttribute("pad", pad);
|
||
|
result.addAttribute("stride", stride);
|
||
|
result.addAttribute("dilation", dilation);
|
||
|
|
||
|
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
|
||
|
if (quantAttr) {
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.addTypes(
|
||
|
buildConvOpResultTypeInfo(builder, outputType, input, weight));
|
||
|
} else {
|
||
|
result.addTypes(outputType);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
|
||
|
static void buildTransConvOpWithQuantInfo(
|
||
|
OpBuilder &builder, OperationState &result, Type outputType, Value input,
|
||
|
Value weight, Value bias, DenseI64ArrayAttr outpad,
|
||
|
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
|
||
|
result.addOperands({input, weight, bias});
|
||
|
result.addAttribute("out_pad", outpad);
|
||
|
result.addAttribute("stride", stride);
|
||
|
result.addAttribute("out_shape", outputShape);
|
||
|
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
|
||
|
|
||
|
if (quantAttr) {
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.addTypes(
|
||
|
buildConvOpResultTypeInfo(builder, outputType, input, weight));
|
||
|
} else {
|
||
|
result.addTypes(outputType);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// The tosa.fully_connected op has its own builder as it does not have
|
||
|
/// strides/dilation/padding.
|
||
|
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||
|
Type outputType, Value input, Value weight,
|
||
|
Value bias) {
|
||
|
|
||
|
result.addOperands({input, weight, bias});
|
||
|
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
|
||
|
if (quantAttr) {
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.addTypes(
|
||
|
buildConvOpResultTypeInfo(builder, outputType, input, weight));
|
||
|
} else {
|
||
|
result.addTypes(outputType);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// The tosa.matmul op is also intended to be generated where a fully_connected
|
||
|
/// op must be constructed where the weight is not a constant. In this case,
|
||
|
/// the fully_connected op must be expressed using matmul.
|
||
|
/// TODO: Add link to the leglization document explaining this.
|
||
|
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
|
||
|
OperationState &result, Type outputType,
|
||
|
Value a, Value b) {
|
||
|
result.addOperands({a, b});
|
||
|
auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
|
||
|
|
||
|
if (quantAttr) {
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
|
||
|
auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
|
||
|
assert(inputType && "Input must be a shaped tensor type!");
|
||
|
|
||
|
auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
|
||
|
inputType.getElementType());
|
||
|
assert(inputQType && "Tensor must have quantized datatype!");
|
||
|
|
||
|
unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
|
||
|
|
||
|
auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
|
||
|
assert(outputShapedType && "Output must be a shaped type");
|
||
|
|
||
|
IntegerType accElementType;
|
||
|
if (inputBits == 16)
|
||
|
accElementType = builder.getIntegerType(48);
|
||
|
else
|
||
|
accElementType = builder.getI32Type();
|
||
|
auto accType = outputShapedType.clone(accElementType);
|
||
|
result.addTypes(accType);
|
||
|
} else {
|
||
|
result.addTypes(outputType);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
|
||
|
/// but avg_pool operator has its own builder as it has additional parameters
|
||
|
/// not part of the unary ops.
|
||
|
static void
|
||
|
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||
|
Type outputType, Value input,
|
||
|
DenseArrayAttr kernel, DenseArrayAttr stride,
|
||
|
DenseArrayAttr pad, TypeAttr acc_type) {
|
||
|
result.addOperands(input);
|
||
|
result.addAttribute("kernel", kernel);
|
||
|
result.addAttribute("stride", stride);
|
||
|
result.addAttribute("pad", pad);
|
||
|
result.addAttribute("acc_type", acc_type);
|
||
|
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
|
||
|
if (quantAttr)
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.types.push_back(outputType);
|
||
|
}
|
||
|
|
||
|
/// This builder is called on single-parameter unary operators that have scale
|
||
|
/// relationship between their input and output, expressed by the
|
||
|
/// UnaryOpQuantizationAttr.
|
||
|
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
|
||
|
OperationState &result, Type outputType,
|
||
|
Value input) {
|
||
|
result.addOperands(input);
|
||
|
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
|
||
|
if (quantAttr)
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.types.push_back(outputType);
|
||
|
}
|
||
|
|
||
|
/// This builder is called on TOSA pad operator that needs to create its own
|
||
|
/// OptionalAttr quantization_attr parameter to scale the padding values
|
||
|
/// correctly. No pad_const is interpreted as zero-padding.
|
||
|
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
|
||
|
Type outputType, Value input,
|
||
|
Value paddings) {
|
||
|
result.addOperands({input, paddings});
|
||
|
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
|
||
|
if (quantAttr)
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.types.push_back(outputType);
|
||
|
}
|
||
|
|
||
|
/// This builder is called on TOSA pad operator when an explicit pad_const
|
||
|
/// value is passed in. It also optionally constructs quantization_attr.
|
||
|
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
|
||
|
OperationState &result,
|
||
|
Type outputType, Value input,
|
||
|
Value paddings,
|
||
|
Value padConst) {
|
||
|
result.addOperands({input, paddings, padConst});
|
||
|
auto quantAttr = buildPadOpQuantizationAttr(builder, input);
|
||
|
if (quantAttr)
|
||
|
result.addAttribute("quantization_info", quantAttr);
|
||
|
result.types.push_back(outputType);
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TOSA Operator Return Type Inference.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
/// Computes the broadcast shape of all \p operands into \p outShape.
///
/// Every operand must be ranked. Shapes are right-aligned against the largest
/// rank. Per dimension, a size of 1 takes the other side's size, and otherwise
/// both sizes must be equal. Returns failure for an unranked operand or for
/// conflicting sizes; in particular a dynamic size only merges with 1 or
/// another dynamic size.
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  // Pass 1: require ranks and find the widest one.
  int64_t outRank = 0;
  for (int opIdx = 0, numOps = operands.size(); opIdx != numOps; ++opIdx) {
    auto shape = operands.getShape(opIdx);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for invalid
      // operands and for ranked tensors.
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  // Pass 2: merge each operand, right-aligned, into the running result.
  outShape.resize(outRank, 1);
  for (int opIdx = 0, numOps = operands.size(); opIdx != numOps; ++opIdx) {
    auto shape = operands.getShape(opIdx);
    auto offset = outShape.size() - shape.getRank();

    for (size_t d = 0, rank = shape.getRank(); d < rank; ++d) {
      int64_t &outDim = outShape[d + offset];
      int64_t inDim = shape.getDimSize(d);
      if (outDim == 1)
        outDim = inDim;
      else if (inDim != 1 && outDim != inDim)
        return failure();
    }
  }

  return success();
}
|
||
|
|
||
|
/// Infers the tosa.argmax result shape: the input shape with the `axis`
/// dimension removed. An unranked input yields an unranked result.
LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  const int64_t rank = inputShape.getRank();
  SmallVector<int64_t> outShape;
  // For a rank-0 input, `rank - 1` would wrap to SIZE_MAX when passed to the
  // unsigned reserve() parameter and abort with a capacity overflow. Type
  // inference can run before the verifier rejects such an input.
  if (rank > 0)
    outShape.reserve(rank - 1);
  for (int i = 0; i < rank; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
|
||
|
|
||
|
/// Infers the two (real and imaginary) tosa.rfft2d result shapes.
///
/// The input must be a ranked rank-3 tensor. Both results keep input dims 0
/// and 1, and dim 2 becomes `W / 2 + 1`, where W is the input's last
/// dimension, or stays dynamic when W is dynamic.
LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());

  // Dims 0..2 are read below; any other rank would index past the shape.
  if (!inputShape.hasRank() || inputShape.getRank() != 3)
    return failure();

  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

  // Real and imaginary outputs share the same shape.
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

  return success();
}
|
||
|
|
||
|
/// Infers the tosa.fft2d result shapes: the real output mirrors the real
/// input's shape, and the imaginary output mirrors the imaginary input's.
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor realShape(adaptor.getInputReal().getType());
  ShapeAdaptor imagShape(adaptor.getInputImag().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(realShape));
  inferredReturnShapes.push_back(ShapedTypeComponents(imagShape));
  return success();
}
|
||
|
|
||
|
/// Infers the tosa.concat result shape.
///
/// Every non-axis dimension takes the first static size seen among the ranked
/// inputs, and conflicting static sizes are rejected. The axis dimension is
/// the sum of the inputs' sizes along it, or dynamic if any input is unranked
/// or dynamic there. If no input is ranked, the result is unranked. The element
/// type comes from the first input.
///
/// Errors are reported through \p location for the following cases:
///  - an empty input list;
///  - ranked inputs of differing rank;
///  - an axis outside [0, rank);
///  - mismatched non-axis sizes.
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();

  // The element type is read from input 0 below, so there must be one.
  if (adaptor.getInput1().empty())
    return emitOptionalError(location,
                             "Cannot concat an empty list of tensors");

  // Infer all non-axis dimension sizes by reducing based on ranked inputs.
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // The first ranked input fixes the output rank. Later ranked inputs must
    // match it; otherwise indexing outputShape with their dimensions (or
    // indexing them with `axis` below) would go out of bounds.
    const int64_t rank = operandShape.getRank();
    if (!hasRankedInput)
      outputShape.resize(rank, ShapedType::kDynamic);
    else if (rank != static_cast<int64_t>(outputShape.size()))
      return emitOptionalError(location,
                               "Cannot concat tensors with different ranks");

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = rank; i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // `axis` indexes both outputShape and every ranked operand shape below.
  if (axis < 0 || axis >= static_cast<int64_t>(outputShape.size()))
    return emitOptionalError(location, "Concat axis ", axis,
                             " is out of range for tensors of rank ",
                             static_cast<int64_t>(outputShape.size()));

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
|
||
|
|
||
|
/// Infers the tosa.equal result: an i1 tensor with the broadcast shape of the
/// operands. When that shape cannot be resolved, the result is an unranked i1
/// tensor; inference itself never fails.
LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto boolType = IntegerType::get(context, /*width=*/1);

  llvm::SmallVector<int64_t> broadcastShape;
  if (succeeded(resolveBroadcastShape(operands, broadcastShape)))
    inferredReturnShapes.push_back(
        ShapedTypeComponents(broadcastShape, boolType));
  else
    inferredReturnShapes.push_back(ShapedTypeComponents(boolType));

  return success();
}
|
||
|
|
||
|
/// Return types are compatible when each side has exactly one type and the
/// two types have compatible shapes.
bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  const bool singleResultEach = l.size() == 1 && r.size() == 1;
  if (!singleResultEach)
    return false;
  return succeeded(verifyCompatibleShape(l.front(), r.front()));
}
|
||
|
|
||
|
/// Infers the rank-2 tosa.fully_connected result shape.
///
/// Dim 0 comes from input dim 0. Dim 1 comes from weight dim 0, falling back
/// to bias dim 0 when the weight leaves it dynamic. Anything unknown stays
/// dynamic.
LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FullyConnectedOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  ShapeAdaptor biasShape(adaptor.getBias().getType());

  SmallVector<int64_t> outShape(2, ShapedType::kDynamic);

  if (inputShape.hasRank())
    outShape[0] = inputShape.getDimSize(0);

  if (weightShape.hasRank())
    outShape[1] = weightShape.getDimSize(0);

  if (biasShape.hasRank() && outShape[1] == ShapedType::kDynamic)
    outShape[1] = biasShape.getDimSize(0);

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
|
||
|
|
||
|
/// tosa.fully_connected reuses the shared convolution verifier.
LogicalResult FullyConnectedOp::verify() {
  return verifyConvOp<FullyConnectedOp>(*this);
}
|
||
|
|
||
|
/// Infers the rank-3 tosa.matmul result shape.
///
/// Dims 0 and 1 come from A. Dim 0 falls back to B's dim 0 when A leaves it
/// dynamic, and dim 2 comes from B's dim 2. Anything unknown stays dynamic.
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatMulOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape(adaptor.getA().getType());
  ShapeAdaptor rhsShape(adaptor.getB().getType());

  SmallVector<int64_t> outShape(3, ShapedType::kDynamic);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    if (outShape[0] == ShapedType::kDynamic)
      outShape[0] = rhsShape.getDimSize(0);
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
|
||
|
|
||
|
/// Infers the tosa.pad result shape.
///
/// The cases are as follows:
///  - If both input and padding are unranked, the result is unranked.
///  - If only the input is unranked, the result rank is the padding operand's
///    static dim 0 (or unranked if that is dynamic).
///  - If the padding is not a constant, the result has the input's rank with
///    all dims dynamic.
///  - Otherwise each static input dim i grows by padding values [2*i] and
///    [2*i + 1], and dynamic dims stay dynamic.
///
/// A constant padding holding fewer than 2 * rank values is reported as an
/// error through \p location.
LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank from the
  // padding shape's first dim.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings)
    paddingValues.push_back(val.getSExtValue());

  // Input dim i reads padding values at 2*i and 2*i+1; reject a constant that
  // is too short instead of indexing past its end.
  const int64_t rank = inputShape.getRank();
  if (static_cast<int64_t>(paddingValues.size()) < 2 * rank)
    return emitOptionalError(
        location, "expected padding to hold 2 values per input dimension (",
        2 * rank, "), but got ", static_cast<int64_t>(paddingValues.size()));

  outputShape.reserve(rank);
  for (int i = 0; i < rank; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
|
||
|
|
||
|
/// Translates a TOSA attribute shape, in which -1 denotes an unknown extent,
/// into an MLIR shape that uses ShapedType::kDynamic for unknown extents.
/// Every other entry is copied through unchanged.
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> mlirShape;
  mlirShape.reserve(shape.size());
  for (int64_t extent : shape)
    mlirShape.push_back(extent == -1 ? ShapedType::kDynamic : extent);
  return mlirShape;
}
|
||
|
|
||
|
/// Infers the result shape of tosa.slice. The result shape is the `size`
/// attribute verbatim, with -1 entries mapped to dynamic extents.
LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    SliceOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto resultDims = convertToMlirShape(adaptor.getSize());
  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
  return success();
}
|
||
|
|
||
|
/// Verifies tosa.slice. When the input is a ranked tensor, both the `start`
/// and `size` attributes must hold exactly one entry per input dimension.
LogicalResult tosa::SliceOp::verify() {
  auto rankedInput = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  // Without a rank there is nothing to compare the attributes against.
  if (!rankedInput)
    return success();

  const auto inputRank = static_cast<size_t>(rankedInput.getRank());
  if (getStart().size() != inputRank)
    return emitOpError(
        "length of start attribute is not equal rank of input shape");
  if (getSize().size() != inputRank)
    return emitOpError(
        "length of size attribute is not equal rank of input shape");
  return success();
}
|
||
|
|
||
|
/// Infers the result shape of tosa.table. The result mirrors the input's
/// shape; an unranked input yields an unranked result.
LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    // Copy the input dimensions straight into the single result slot.
    inferredReturnShapes.resize(1);
    inputShape.getDims(inferredReturnShapes[0]);
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  }
  return success();
}
|
||
|
|
||
|
/// Infers the result shape of tosa.tile.
///
/// Each static input extent is multiplied by the matching entry of
/// `multiples`, and dynamic extents stay dynamic. An unranked input produces
/// a result of rank `multiples.size()` with all-dynamic extents. Inference
/// fails when a ranked input's rank differs from the number of multiples.
LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayRef<int64_t> multiples = adaptor.getMultiples();
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    SmallVector<int64_t> unknownDims(multiples.size(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(unknownDims));
    return success();
  }

  const int rank = inputShape.getRank();
  if (static_cast<size_t>(rank) != multiples.size())
    return failure();

  SmallVector<int64_t> tiledDims;
  tiledDims.reserve(multiples.size());
  for (int idx = 0; idx < rank; ++idx) {
    const int64_t extent = inputShape.getDimSize(idx);
    tiledDims.push_back(ShapedType::isDynamic(extent) ? extent
                                                      : extent * multiples[idx]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(tiledDims));
  return success();
}
|
||
|
|
||
|
/// Verifies rank consistency for tosa.tile.
///
/// `multiples` must hold one entry per dimension: it is checked against the
/// input rank when known, and otherwise against the output rank. Ranked
/// input and output must also agree on rank.
LogicalResult tosa::TileOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput1().getType());
  auto outputType = llvm::cast<ShapedType>(getType());
  const size_t numMultiples = getMultiples().size();

  if (!inputType.hasRank()) {
    // Fall back to the output rank when the input rank is unknown.
    if (outputType.hasRank() &&
        static_cast<size_t>(outputType.getRank()) != numMultiples)
      return emitOpError("expect 'multiples' array to have length ")
             << outputType.getRank() << " but got " << numMultiples << ".";
    return success();
  }

  if (static_cast<size_t>(inputType.getRank()) != numMultiples)
    return emitOpError("expect 'multiples' array to have length ")
           << inputType.getRank() << " but got " << numMultiples << ".";
  if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
    return emitOpError("expect same input and output tensor rank.");
  return success();
}
|
||
|
|
||
|
/// Reshape result types are compatible when each side holds exactly one type
/// and both share the same element type. Shapes are deliberately not
/// compared.
bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  const bool exactlyOneEach = l.size() == 1 && r.size() == 1;
  return exactlyOneEach &&
         getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
}
|
||
|
|
||
|
/// Infers the result type of tosa.reshape from its `new_shape` attribute.
///
/// `new_shape` uses -1 for an unknown extent. When the input shape is fully
/// static, each unknown extent is resolved as
///   numElements(input) / product(static extents of new_shape).
/// Otherwise `new_shape` is taken as-is. The result element type is the
/// input's element type.
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
  llvm::SmallVector<int64_t> newShapeValue =
      convertToMlirShape(adaptor.getNewShape());

  // We cannot infer from the total number of elements so we must take the
  // shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(
        ShapedTypeComponents(newShapeValue, inputType));
    return success();
  }

  // Determine the number of elements covered by the slice of all static
  // dimensions. This allows us to infer the length of the remaining dynamic
  // dimension.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (int64_t val : newShapeValue) {
    if (!ShapedType::isDynamic(val))
      staticMul *= val;
  }

  // A zero (or otherwise non-positive) static extent would make the division
  // below a divide-by-zero or produce a meaningless extent. Zero-sized
  // dimensions are rejected by ReshapeOp::verify, but inference can see such
  // IR first, so leave unknown extents dynamic in that case.
  if (staticMul > 0) {
    for (int64_t &val : newShapeValue) {
      if (ShapedType::isDynamic(val))
        val = numElements / staticMul;
    }
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(newShapeValue, inputType));
  return success();
}
|
||
|
|
||
|
/// Verifies tosa.reshape.
///
/// Neither the input nor the result may contain a zero-sized dimension. When
/// both shapes are fully static, they must hold the same number of elements.
mlir::LogicalResult tosa::ReshapeOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput1().getType());
  auto outputType = llvm::cast<ShapedType>(getType());

  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  // Element counts are only comparable when both shapes are fully known.
  if (!inputType.hasStaticShape() || !outputType.hasStaticShape())
    return mlir::success();

  const int64_t inputElementsNum = inputType.getNumElements();
  const int64_t outputElementsNum = outputType.getNumElements();
  if (inputElementsNum == outputElementsNum)
    return mlir::success();
  return emitOpError() << "Cannot reshape " << inputElementsNum
                       << " elements into " << outputElementsNum;
}
|
||
|
|
||
|
/// Reads the permutation of this tosa.transpose into `perms`, sign-extending
/// each element. This only succeeds when the `perms` operand matches a
/// constant; otherwise it fails and leaves `perms` untouched. The values are
/// not validated here.
LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector<int64_t> &perms) {
  DenseIntElementsAttr permsAttr;
  if (!matchPattern(getPerms(), m_Constant(&permsAttr)))
    return failure();

  perms.clear();
  for (const APInt &permValue : permsAttr.getValues<APInt>())
    perms.push_back(permValue.getSExtValue());
  return success();
}
|
||
|
|
||
|
/// Infers the result shape of tosa.transpose.
///
/// - Output is unranked when the input rank or the permutation length is
///   unknown.
/// - A rank-0 input gives a rank-0 output.
/// - If every input extent is equal, the output has the same extents
///   regardless of the permutation.
/// - Otherwise, a constant permutation gives
///   output[i] = input[perms[i]], and a non-constant one yields all-dynamic
///   extents.
/// Fails on a rank-0 permutation tensor, on a length/rank mismatch, or on a
/// constant permutation entry outside [0, rank).
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor permsShape(adaptor.getPerms().getType());

  // We cannot infer anything from a rank-0 "permutation" tensor.
  if (permsShape.hasRank() && permsShape.getRank() == 0)
    return failure();

  // If input rank and permutation length is unknown, the output rank is
  // unknown.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // This would imply the number of permutations does not match the rank of the
  // input which is illegal.
  if (permsShape.getDimSize(0) != inputShape.getRank())
    return failure();

  const int rank = inputShape.getRank();
  SmallVector<int64_t> outputShape;
  // Rank-0 means no permutations matter.
  if (rank == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1; i < rank; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(rank, inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(rank, ShapedType::kDynamic);
  // If the permutations are a constant we can directly determine the output
  // shape.
  DenseIntElementsAttr attr;
  if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) &&
      attr.getType().getRank() == 1) {
    ShapeAdaptor permShape = attr;
    // Constant permutation must be the same length as the input rank.
    if (permShape.getRank() != rank)
      return emitOptionalError(location,
                               "constant permutation must be the same length"
                               " as the input rank");

    // Constant permutation values must be within [0, rank). A negative entry
    // would otherwise be used below as an out-of-bounds dimension index.
    for (int i = 0; i < rank; i++) {
      const int64_t source = permShape.getDimSize(i);
      if (source < 0 || source >= rank)
        return failure();
    }

    for (int i = 0; i < rank; i++)
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
|
||
|
|
||
|
/// Verifies tosa.transpose.
///
/// - The permutation tensor must be rank 1.
/// - When its length is static, it must match the input and output ranks.
/// - The input and output ranks must agree.
/// - A constant permutation must be a valid permutation of [0, length).
LogicalResult tosa::TransposeOp::verify() {
  TensorType inputType = getInput1().getType();
  TensorType permType = getPerms().getType();
  TensorType outputType = getOutput().getType();

  if (permType.hasRank() && permType.getRank() != 1)
    return emitOpError()
           << "expected permutation tensor to be rank 1 but got rank "
           << permType.getRank();

  // From here on, a ranked permutation tensor is rank 1, so dim 0 is its
  // length.
  if (inputType.hasRank() && permType.hasRank() && !permType.isDynamicDim(0) &&
      permType.getDimSize(0) != inputType.getRank())
    return emitOpError() << "expected permutation tensor dim 0 to have size "
                         << inputType.getRank()
                         << " (input rank) but got size "
                         << permType.getDimSize(0);

  if (inputType.hasRank() && outputType.hasRank() &&
      inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  if (outputType.hasRank() && permType.hasRank() &&
      !permType.isDynamicDim(0) &&
      permType.getDimSize(0) != outputType.getRank())
    return emitOpError() << "expected permutation tensor dim 0 to have size "
                         << outputType.getRank()
                         << " (output rank) but got size "
                         << permType.getDimSize(0);

  SmallVector<int64_t> constantPerms;
  if (succeeded(getConstantPerms(constantPerms))) {
    // Assert that the permutation tensor has a rank, which means that the rank
    // has been verified above.
    assert(permType.hasRank() &&
           "Unexpectedly found permutation tensor without rank");

    // Range-check every entry explicitly so that negative or too-large values
    // get a diagnostic. isPermutationVector is then only needed to reject
    // duplicates.
    const int64_t numPerms = static_cast<int64_t>(constantPerms.size());
    const bool allInRange =
        llvm::all_of(constantPerms, [numPerms](int64_t perm) {
          return perm >= 0 && perm < numPerms;
        });
    if (!allInRange || !isPermutationVector(constantPerms))
      return emitOpError() << "expected valid permutation tensor";
  }
  return success();
}
|
||
|
|
||
|
/// Infers the rank-3 result shape of tosa.gather.
///
/// Result dims 0 and 2 come from `values` dims 0 and 2. Result dim 1 comes
/// from `indices` dim 1, and `indices` dim 0 fills result dim 0 if `values`
/// left it unknown. Extents that no operand determines stay dynamic.
LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    GatherOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> resultDims(3, ShapedType::kDynamic);

  ShapeAdaptor valuesShape(adaptor.getValues().getType());
  if (valuesShape.hasRank()) {
    resultDims[0] = valuesShape.getDimSize(0);
    resultDims[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (ShapedType::isDynamic(resultDims[0]))
      resultDims[0] = indicesShape.getDimSize(0);
    // Nothing earlier sets dim 1, so it is always taken from the indices.
    resultDims[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
  return success();
}
|
||
|
|
||
|
/// Infers the rank-4 result shape of tosa.resize.
///
/// Dims 0 and 3 are copied from the input. Dims 1 and 2 are computed from
/// the input extents and the `scale`, `offset` and `border` attributes.
/// Fails when the input is unranked or either input dim 1 or dim 2 is
/// dynamic.
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  const int64_t leadingDim = inputShape.getDimSize(0);
  const int64_t trailingDim = inputShape.getDimSize(3);
  const int64_t inputHeight = inputShape.getDimSize(1);
  const int64_t inputWidth = inputShape.getDimSize(2);
  if (ShapedType::isDynamic(inputHeight) || ShapedType::isDynamic(inputWidth))
    return failure();

  llvm::ArrayRef<int64_t> scale = adaptor.getScale();
  llvm::ArrayRef<int64_t> offset = adaptor.getOffset();
  llvm::ArrayRef<int64_t> border = adaptor.getBorder();

  // ((in - 1) * numerator - offset + border) / denominator + 1
  auto resizedExtent = [](int64_t in, int64_t numerator, int64_t denominator,
                          int64_t off, int64_t bord) {
    return ((in - 1) * numerator - off + bord) / denominator + 1;
  };

  llvm::SmallVector<int64_t, 4> resultDims = {
      leadingDim,
      resizedExtent(inputHeight, scale[0], scale[1], offset[0], border[0]),
      resizedExtent(inputWidth, scale[2], scale[3], offset[1], border[1]),
      trailingDim};

  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
  return success();
}
|
||
|
|
||
|
/// Infers the rank-3 result shape of tosa.scatter.
///
/// `values_in` supplies all three extents when ranked. Any extent still
/// unknown is then filled from `indices` (dim 0) and `input` (dims 0 and 2),
/// in that order.
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ScatterOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> resultDims(3, ShapedType::kDynamic);

  ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
  if (valuesInShape.hasRank()) {
    for (int dim = 0; dim < 3; ++dim)
      resultDims[dim] = valuesInShape.getDimSize(dim);
  }

  ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank() && ShapedType::isDynamic(resultDims[0]))
    resultDims[0] = indicesShape.getDimSize(0);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    if (ShapedType::isDynamic(resultDims[0]))
      resultDims[0] = inputShape.getDimSize(0);
    if (ShapedType::isDynamic(resultDims[2]))
      resultDims[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
  return success();
}
|
||
|
|
||
|
/// Shared shape inference for the tosa.reduce_* ops.
///
/// The result keeps the operand's shape, except that dimension `axis` becomes
/// 1. When the operand is unranked, or `axis` does not name one of its
/// dimensions (negative or >= rank), only the element type is inferred.
///
/// \param operandShape shape of the reduced operand.
/// \param inputType    element type to give the result.
/// \param axis         reduction axis attribute.
static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  int64_t axisVal = axis.getValue().getSExtValue();
  // verifyReduceOp rejects negative axes, but inference may be invoked on
  // unverified IR. Guard here too so outputShape is never indexed out of
  // bounds.
  if (!operandShape.hasRank() || axisVal < 0 ||
      operandShape.getRank() <= axisVal) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
|
||
|
|
||
|
#define COMPATIBLE_RETURN_TYPES(OP) \
|
||
|
bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \
|
||
|
if (l.size() != r.size() || l.size() != 1) \
|
||
|
return false; \
|
||
|
if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \
|
||
|
return false; \
|
||
|
return succeeded(verifyCompatibleShape(l[0], r[0])); \
|
||
|
}
|
||
|
|
||
|
#define REDUCE_SHAPE_INFER(OP) \
|
||
|
LogicalResult OP::inferReturnTypeComponents( \
|
||
|
MLIRContext *context, ::std::optional<Location> location, \
|
||
|
OP::Adaptor adaptor, \
|
||
|
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
|
||
|
Type inputType = \
|
||
|
llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
|
||
|
ShapeAdaptor inputShape(adaptor.getInput().getType()); \
|
||
|
const Properties &prop = adaptor.getProperties(); \
|
||
|
return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \
|
||
|
inferredReturnShapes); \
|
||
|
} \
|
||
|
COMPATIBLE_RETURN_TYPES(OP)
|
||
|
|
||
|
REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
|
||
|
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
|
||
|
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
|
||
|
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
|
||
|
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
|
||
|
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
|
||
|
#undef REDUCE_SHAPE_INFER
|
||
|
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
|
||
|
#undef COMPATIBLE_RETURN_TYPES
|
||
|
|
||
|
template <typename T>
|
||
|
static LogicalResult verifyReduceOp(T op) {
|
||
|
// All TOSA reduce Ops have input, output and axis.
|
||
|
TensorType inputType = op.getInput().getType();
|
||
|
TensorType outputType = op.getOutput().getType();
|
||
|
int32_t reduceAxis = op.getAxis();
|
||
|
|
||
|
if (reduceAxis < 0) {
|
||
|
op.emitOpError("reduce axis must not be negative");
|
||
|
return failure();
|
||
|
}
|
||
|
if (inputType.hasRank()) {
|
||
|
int64_t inputRank = inputType.getRank();
|
||
|
// We allow for a special case where the input/output shape has rank 0 and
|
||
|
// axis is also 0.
|
||
|
if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
|
||
|
op.emitOpError("expect input tensor rank (")
|
||
|
<< inputRank << ") to be larger than reduce axis (" << reduceAxis
|
||
|
<< ")";
|
||
|
return failure();
|
||
|
}
|
||
|
}
|
||
|
if (outputType.hasRank()) {
|
||
|
int64_t outputRank = outputType.getRank();
|
||
|
if (inputType.hasRank() && outputRank != inputType.getRank()) {
|
||
|
op.emitOpError(
|
||
|
"expect output tensor rank to be equal to input tensor rank");
|
||
|
return failure();
|
||
|
}
|
||
|
if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
|
||
|
op.emitOpError("expect output tensor rank (")
|
||
|
<< outputRank << ") to be larger than reduce axis (" << reduceAxis
|
||
|
<< ")";
|
||
|
return failure();
|
||
|
}
|
||
|
// We can only verify the reduced dimension size to be 1 if this is not the
|
||
|
// special case of output rank == 0.
|
||
|
if (outputRank != 0) {
|
||
|
auto outputShape = outputType.getShape();
|
||
|
if (!outputType.isDynamicDim(reduceAxis) &&
|
||
|
outputShape[reduceAxis] != 1) {
|
||
|
op.emitOpError("expect reduced dimension size to be 1, got ")
|
||
|
<< outputShape[reduceAxis];
|
||
|
return failure();
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
// Every tosa.reduce_* op is checked by the shared verifyReduceOp helper.
LogicalResult tosa::ReduceAllOp::verify() {
  return verifyReduceOp<tosa::ReduceAllOp>(*this);
}
LogicalResult tosa::ReduceAnyOp::verify() {
  return verifyReduceOp<tosa::ReduceAnyOp>(*this);
}
LogicalResult tosa::ReduceMaxOp::verify() {
  return verifyReduceOp<tosa::ReduceMaxOp>(*this);
}
LogicalResult tosa::ReduceMinOp::verify() {
  return verifyReduceOp<tosa::ReduceMinOp>(*this);
}
LogicalResult tosa::ReduceProdOp::verify() {
  return verifyReduceOp<tosa::ReduceProdOp>(*this);
}
LogicalResult tosa::ReduceSumOp::verify() {
  return verifyReduceOp<tosa::ReduceSumOp>(*this);
}
|
||
|
|
||
|
/// Shape inference shared by the elementwise / n-ary TOSA ops.
///
/// The result takes the shape computed by resolveBroadcastShape over all
/// operands. If that fails, the result is left unranked. Inference itself
/// always succeeds.
static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> broadcastDims;
  const bool resolved =
      succeeded(resolveBroadcastShape(operands, broadcastDims));
  inferredReturnShapes.push_back(resolved
                                     ? ShapedTypeComponents(broadcastDims)
                                     : ShapedTypeComponents());
  return success();
}
|
||
|
|
||
|
#define NARY_SHAPE_INFER(OP) \
|
||
|
LogicalResult OP::inferReturnTypeComponents( \
|
||
|
MLIRContext *context, ::std::optional<Location> location, \
|
||
|
ValueShapeRange operands, DictionaryAttr attributes, \
|
||
|
OpaqueProperties properties, RegionRange regions, \
|
||
|
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
|
||
|
return NAryInferReturnTypes(operands, inferredReturnShapes); \
|
||
|
}
|
||
|
|
||
|
NARY_SHAPE_INFER(tosa::AbsOp)
|
||
|
NARY_SHAPE_INFER(tosa::AddOp)
|
||
|
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
|
||
|
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
|
||
|
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
|
||
|
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
|
||
|
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
|
||
|
NARY_SHAPE_INFER(tosa::CastOp)
|
||
|
NARY_SHAPE_INFER(tosa::CeilOp)
|
||
|
NARY_SHAPE_INFER(tosa::ClampOp)
|
||
|
NARY_SHAPE_INFER(tosa::ClzOp)
|
||
|
NARY_SHAPE_INFER(tosa::DivOp)
|
||
|
NARY_SHAPE_INFER(tosa::ExpOp)
|
||
|
NARY_SHAPE_INFER(tosa::FloorOp)
|
||
|
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
|
||
|
NARY_SHAPE_INFER(tosa::GreaterOp)
|
||
|
NARY_SHAPE_INFER(tosa::IdentityOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogicalAndOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogicalNotOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogicalOrOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
|
||
|
NARY_SHAPE_INFER(tosa::LogicalXorOp)
|
||
|
NARY_SHAPE_INFER(tosa::MaximumOp)
|
||
|
NARY_SHAPE_INFER(tosa::MinimumOp)
|
||
|
NARY_SHAPE_INFER(tosa::MulOp)
|
||
|
NARY_SHAPE_INFER(tosa::NegateOp)
|
||
|
NARY_SHAPE_INFER(tosa::PowOp)
|
||
|
NARY_SHAPE_INFER(tosa::ReciprocalOp)
|
||
|
NARY_SHAPE_INFER(tosa::RescaleOp)
|
||
|
NARY_SHAPE_INFER(tosa::ReverseOp)
|
||
|
NARY_SHAPE_INFER(tosa::RsqrtOp)
|
||
|
NARY_SHAPE_INFER(tosa::SelectOp)
|
||
|
NARY_SHAPE_INFER(tosa::SubOp)
|
||
|
NARY_SHAPE_INFER(tosa::TanhOp)
|
||
|
NARY_SHAPE_INFER(tosa::ErfOp)
|
||
|
NARY_SHAPE_INFER(tosa::SigmoidOp)
|
||
|
#undef PRED_SHAPE_INFER
|
||
|
|
||
|
/// Shared rank-4 shape inference for the 2-D pooling ops.
///
/// Dims 0 and 3 are copied from the input. Each static spatial extent (input
/// dims 1 and 2) becomes
///   (in + pad_before + pad_after - kernel) / stride + 1,
/// where `pad` is ordered [dim1 before, dim1 after, dim2 before, dim2 after].
/// An unranked input yields a rank-4 all-dynamic result.
static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  // Only the rank is known when the input is unranked. `!inputShape` alone is
  // false for an unranked type, which would then call getDimSize below, so
  // test the rank explicitly.
  if (!inputShape || !inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for pooling layer.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int64_t height = inputShape.getDimSize(1);
  int64_t width = inputShape.getDimSize(2);

  if (!ShapedType::isDynamic(height)) {
    int64_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(width)) {
    int64_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
|
||
|
|
||
|
/// Infers the rank-4 result shape of tosa.conv2d.
///
/// - Result dim 0 comes from input dim 0.
/// - Result dim 3 comes from weight dim 0, or from bias dim 0 if the weight
///   leaves it unknown.
/// - Result dims 1 and 2 are computed from input dims 1/2 and weight dims
///   1/2 using the `pad`, `dilation` and `stride` attributes.
LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);

  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // The bias only fills the output-channel extent if the weight did not.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3]))
    outputShape[3] = biasShape.getDimSize(0);

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> padding = adaptor.getPad();

  // Output extent of one spatial dimension: the padded input minus the
  // dilated kernel footprint, divided by the stride, plus one.
  auto convolvedExtent = [](int64_t in, int64_t padBefore, int64_t padAfter,
                            int64_t kernel, int64_t dil, int64_t str) {
    const int64_t paddedIn = in + padBefore + padAfter;
    const int64_t dilatedKernel = (kernel - 1) * dil + 1;
    return (paddedIn - dilatedKernel) / str + 1;
  };

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight))
    outputShape[1] = convolvedExtent(inputHeight, padding[0], padding[1],
                                     weightHeight, dilation[0], stride[0]);

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth))
    outputShape[2] = convolvedExtent(inputWidth, padding[2], padding[3],
                                     weightWidth, dilation[1], stride[1]);

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
|
||
|
|
||
|
/// Infers the rank-5 result shape of tosa.conv3d.
///
/// - Result dim 0 comes from input dim 0.
/// - Result dim 4 comes from weight dim 0, or from bias dim 0 if the weight
///   leaves it unknown.
/// - Result dims 1-3 are computed from input dims 1-3 and weight dims 1-3
///   using the `pad`, `dilation` and `stride` attributes.
///
/// All arithmetic is done in int64_t, like Conv2D and DepthwiseConv2D.
/// Previously it used int32_t, which truncated large extents.
LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    Conv3DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamic);

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t inputDepth = ShapedType::kDynamic;

  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;
  int64_t weightDepth = ShapedType::kDynamic;

  // Input shape describes input depth/height/width and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputDepth = inputShape.getDimSize(1);
    inputHeight = inputShape.getDimSize(2);
    inputWidth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter depth/height/width and the output
  // channels.
  ShapeAdaptor weightShape(adaptor.getWeight().getType());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightDepth = weightShape.getDimSize(1);
    weightHeight = weightShape.getDimSize(2);
    weightWidth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4]))
    outputShape[4] = biasShape.getDimSize(0);

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();

  // Output extent of one spatial dimension, computed in 64-bit arithmetic.
  auto convolvedExtent = [](int64_t in, int64_t padBefore, int64_t padAfter,
                            int64_t kernel, int64_t dil, int64_t str) {
    const int64_t inputSize = in + padBefore + padAfter;
    const int64_t filterSize = (kernel - 1) * dil + 1;
    const int64_t unstridedResult = inputSize - filterSize + 1;
    return (unstridedResult - 1) / str + 1;
  };

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth))
    outputShape[1] = convolvedExtent(inputDepth, pad[0], pad[1], weightDepth,
                                     dilation[0], stride[0]);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight))
    outputShape[2] = convolvedExtent(inputHeight, pad[2], pad[3], weightHeight,
                                     dilation[1], stride[1]);

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth))
    outputShape[3] = convolvedExtent(inputWidth, pad[4], pad[5], weightWidth,
                                     dilation[2], stride[2]);

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
|
||
|
|
||
|
/// Infers the result shape of tosa.avg_pool2d via poolingInferReturnTypes,
/// using the op's `kernel`, `stride` and `pad` properties.
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    AvgPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &props = adaptor.getProperties();
  return poolingInferReturnTypes(ShapeAdaptor(adaptor.getInput().getType()),
                                 props.kernel, props.stride, props.pad,
                                 inferredReturnShapes);
}

/// Infers the result shape of tosa.max_pool2d via poolingInferReturnTypes,
/// using the op's `kernel`, `stride` and `pad` properties.
LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MaxPool2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const Properties &props = adaptor.getProperties();
  return poolingInferReturnTypes(ShapeAdaptor(adaptor.getInput().getType()),
                                 props.kernel, props.stride, props.pad,
                                 inferredReturnShapes);
}
|
||
|
|
||
|
/// Infers the result shape of tosa.depthwise_conv2d.
///
/// The result is built as a rank-4 shape [N, OH, OW, C * M]; any extent that
/// cannot be derived stays ShapedType::kDynamic.
///   - N, H, W, C are read from input dims 0..3.
///   - KH, KW, C, M are read from weight dims 0..3. The weight's C is used only
///     when the input's C is unknown.
///   - The channel extent C * M falls back to bias dim 0 when it is unknown.
///   - OH / OW use the usual dilated, padded, strided convolution formula, with
///     pad[0..1] applied to the height and pad[2..3] to the width.
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    DepthwiseConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> resultDims(4, ShapedType::kDynamic);

  int64_t inH = ShapedType::kDynamic;
  int64_t inW = ShapedType::kDynamic;
  int64_t inC = ShapedType::kDynamic;
  int64_t kernelH = ShapedType::kDynamic;
  int64_t kernelW = ShapedType::kDynamic;
  int64_t multiplier = ShapedType::kDynamic;

  ShapeAdaptor input(adaptor.getInput().getType());
  if (input.hasRank()) {
    resultDims[0] = input.getDimSize(0);
    inH = input.getDimSize(1);
    inW = input.getDimSize(2);
    inC = input.getDimSize(3);
  }

  ShapeAdaptor weight(adaptor.getWeight().getType());
  if (weight.hasRank()) {
    kernelH = weight.getDimSize(0);
    kernelW = weight.getDimSize(1);
    // The input's channel count wins; the weight only fills a gap.
    if (ShapedType::isDynamic(inC))
      inC = weight.getDimSize(2);
    multiplier = weight.getDimSize(3);
  }

  // Each input channel produces `multiplier` output channels.
  if (!ShapedType::isDynamic(inC) && !ShapedType::isDynamic(multiplier))
    resultDims[3] = inC * multiplier;

  ShapeAdaptor bias(adaptor.getBias().getType());
  if (bias.hasRank() && ShapedType::isDynamic(resultDims[3]))
    resultDims[3] = bias.getDimSize(0);

  llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
  llvm::ArrayRef<int64_t> pad = adaptor.getPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  // floor-free form of ((padded - effectiveKernel + 1) - 1) / stride + 1.
  auto spatialExtent = [](int64_t extent, int64_t padBefore, int64_t padAfter,
                          int64_t kernel, int64_t dil, int64_t step) {
    int64_t padded = extent + padBefore + padAfter;
    int64_t effectiveKernel = (kernel - 1) * dil + 1;
    return (padded - effectiveKernel) / step + 1;
  };

  if (!ShapedType::isDynamic(inH) && !ShapedType::isDynamic(kernelH))
    resultDims[1] =
        spatialExtent(inH, pad[0], pad[1], kernelH, dilation[0], stride[0]);

  if (!ShapedType::isDynamic(inW) && !ShapedType::isDynamic(kernelW))
    resultDims[2] =
        spatialExtent(inW, pad[2], pad[3], kernelW, dilation[1], stride[1]);

  inferredReturnShapes.push_back(ShapedTypeComponents(resultDims));
  return success();
}
|
||
|
|
||
|
/// Verifies tosa.depthwise_conv2d by delegating to the common verifyConvOp
/// helper.
LogicalResult DepthwiseConv2DOp::verify() {
  return verifyConvOp(*this);
}
|
||
|
|
||
|
/// Infers the result shape of tosa.transpose_conv2d.
///
/// The result is treated as [N, OH, OW, OC]. Extents fixed by the `out_shape`
/// attribute take precedence. Any entry it leaves dynamic is filled in as
/// follows:
///   - N from input dim 0;
///   - OC from filter dim 0, otherwise from bias dim 0;
///   - OH / OW from (in - 1) * stride + out_pad_before + out_pad_after +
///     kernel, using input dims 1/2 and filter dims 1/2, with out_pad[0..1]
///     applied to the height and out_pad[2..3] to the width.
/// Extents that cannot be derived remain ShapedType::kDynamic.
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TransposeConv2DOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // outputShape is mutable: it starts as out_shape and is refined below.
  // NOTE(review): indexing [0..3] assumes out_shape always has four entries;
  // confirm that the op definition or verifier enforces this.
  llvm::SmallVector<int64_t> outputShape =
      convertToMlirShape(adaptor.getOutShape());

  int64_t inputWidth = ShapedType::kDynamic;
  int64_t inputHeight = ShapedType::kDynamic;
  int64_t weightWidth = ShapedType::kDynamic;
  int64_t weightHeight = ShapedType::kDynamic;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Filter shape describes the output channels and the kernel height/width.
  ShapeAdaptor weightShape(adaptor.getFilter().getType());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels. This must read the bias
  // operand: the input's dim 0 is the batch size, not the channel count.
  ShapeAdaptor biasShape(adaptor.getBias().getType());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
|
||
|
|
||
|
/// Infers the result shapes of the if op from the values its branches yield.
///
/// Every block in either region whose terminator is a tosa.yield contributes.
/// Knowledge for each result is seeded from the first yield's operand types.
/// It is then combined with ValueKnowledge::meet across every yield, including
/// the first. A meet that evaluates to false (presumably an incompatible
/// combination) leaves the previously accumulated knowledge untouched.
///
/// Fails if no tosa.yield is found or if two yields disagree on operand count.
LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> terminators;
  for (Region *branch : adaptor.getRegions()) {
    for (Block &block : *branch) {
      auto yield = dyn_cast<tosa::YieldOp>(block.getTerminator());
      if (yield)
        terminators.push_back(yield);
    }
  }
  if (terminators.empty())
    return failure();

  tosa::YieldOp seed = terminators.front();
  llvm::SmallVector<ValueKnowledge> knowledge;
  knowledge.reserve(seed.getNumOperands());
  for (Value yielded : seed.getOperands())
    knowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(yielded.getType()));

  for (tosa::YieldOp yield : terminators) {
    if (yield.getNumOperands() != knowledge.size())
      return failure();

    size_t idx = 0;
    for (Value yielded : yield.getOperands()) {
      ValueKnowledge merged = ValueKnowledge::meet(
          knowledge[idx],
          ValueKnowledge::getKnowledgeFromType(yielded.getType()));
      if (merged)
        knowledge[idx] = merged;
      ++idx;
    }
  }

  for (const ValueKnowledge &entry : knowledge)
    inferredReturnShapes.push_back(entry.getShapedTypeComponents());
  return success();
}
|
||
|
|
||
|
/// Infers the result shapes of the while op from the values yielded by the
/// blocks of its body region.
///
/// Every body block whose terminator is a tosa.yield contributes. Knowledge
/// for each result is seeded from the first yield's operand types (not from
/// the op's own operands). It is then combined with ValueKnowledge::meet
/// across every yield. A meet that evaluates to false (presumably an
/// incompatible combination) leaves the previously accumulated knowledge
/// untouched.
///
/// Fails if the body has no tosa.yield terminator, which makes the op invalid,
/// or if two yields disagree on operand count.
LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> terminators;
  for (Block &block : adaptor.getBody()) {
    auto yield = dyn_cast<tosa::YieldOp>(block.getTerminator());
    if (yield)
      terminators.push_back(yield);
  }
  if (terminators.empty())
    return failure();

  tosa::YieldOp seed = terminators.front();
  llvm::SmallVector<ValueKnowledge> knowledge;
  knowledge.reserve(seed.getNumOperands());
  for (Value yielded : seed.getOperands())
    knowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(yielded.getType()));

  for (tosa::YieldOp yield : terminators) {
    if (yield.getNumOperands() != knowledge.size())
      return failure();

    size_t idx = 0;
    for (Value yielded : yield.getOperands()) {
      ValueKnowledge merged = ValueKnowledge::meet(
          knowledge[idx],
          ValueKnowledge::getKnowledgeFromType(yielded.getType()));
      if (merged)
        knowledge[idx] = merged;
      ++idx;
    }
  }

  for (const ValueKnowledge &entry : knowledge)
    inferredReturnShapes.push_back(entry.getShapedTypeComponents());
  return success();
}
|
||
|
|
||
|
/// Returns the shape to unroll over when the result is a vector type, and
/// std::nullopt for any other result type.
std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  auto vectorType = llvm::dyn_cast<VectorType>(getType());
  if (!vectorType)
    return std::nullopt;
  return llvm::to_vector<4>(vectorType.getShape());
}
|
||
|
|
||
|
// The parse and print methods of IfOp follow the SCF dialect's implementation.
|
||
|
/// Parses the custom form of the if op:
///   %cond [-> (result-types)] { then-region } [else { else-region }] [attrs]
///
/// Both regions are always added to `result`. The 'else' region stays empty
/// when no `else` keyword follows the 'then' region. The condition operand is
/// resolved as a rank-0 tensor of i1.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(2);
  Region *thenBranch = result.addRegion();
  Region *elseBranch = result.addRegion();

  Builder &builder = parser.getBuilder();
  Type condType = RankedTensorType::get({}, builder.getIntegerType(1));
  OpAsmParser::UnresolvedOperand condOperand;

  // Condition, optional result types, then the mandatory 'then' region.
  if (parser.parseOperand(condOperand) ||
      parser.resolveOperand(condOperand, condType, result.operands) ||
      parser.parseOptionalArrowTypeList(result.types) ||
      parser.parseRegion(*thenBranch, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("else")) &&
      parser.parseRegion(*elseBranch, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();

  return parser.parseOptionalAttrDict(result.attributes);
}
|
||
|
|
||
|
/// Prints the custom form parsed by IfOp::parse.
///
/// Region terminators are printed only when the op has results, so that the
/// yielded values stay visible. The 'else' region is printed only if it is
/// non-empty.
void IfOp::print(OpAsmPrinter &p) {
  const bool hasResults = !getResults().empty();

  p << " " << getCond();
  if (hasResults)
    p << " -> (" << getResultTypes() << ")";

  p << ' ';
  p.printRegion(getThenBranch(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/hasResults);

  Region &elseBranch = getElseBranch();
  if (!elseBranch.empty()) {
    p << " else ";
    p.printRegion(elseBranch, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/hasResults);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}
|
||
|
|
||
|
/// Verifies tosa.reverse:
///   - the axis is non-negative;
///   - the axis is below the rank of each ranked input/output tensor, except
///     that axis 0 is accepted for rank-0 tensors;
///   - a ranked output has the same rank as a ranked input.
LogicalResult ReverseOp::verify() {
  TensorType inType = getInput().getType();
  TensorType outType = getOutput().getType();
  const int32_t axis = getAxis();

  if (axis < 0)
    return emitOpError("expected non-negative reverse axis");

  // True when `axis` does not name a dimension of a tensor with this rank,
  // with the rank-0 / axis-0 combination tolerated as a special case.
  auto axisOutOfRange = [axis](int64_t rank) {
    return axis >= rank && !(axis == 0 && rank == 0);
  };

  if (inType.hasRank() && axisOutOfRange(inType.getRank()))
    return emitOpError("expect input tensor rank (")
           << inType.getRank() << ") to be larger than reverse axis (" << axis
           << ")";

  if (!outType.hasRank())
    return success();

  const int64_t outRank = outType.getRank();
  if (inType.hasRank() && outRank != inType.getRank())
    return emitOpError(
        "expect output tensor rank to be equal to input tensor rank");
  if (axisOutOfRange(outRank))
    return emitOpError("expect output tensor rank (")
           << outRank << ") to be larger than reverse axis (" << axis << ")";

  return success();
}
|
||
|
|
||
|
// The parse and print methods of WhileOp follow the SCF dialect's implementation.
|
||
|
/// Parses the custom form of the while op:
///   [(%arg = %init, ...)] : (input-types) -> result-types
///     { cond-region } do { body-region } [attributes attr-dict]
///
/// The function type's inputs type the initial operands and the 'cond'
/// region's entry arguments; its results become the op's result types.
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> blockArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> initOperands;
  Region *condRegion = result.addRegion();
  Region *bodyRegion = result.addRegion();

  OptionalParseResult assignList =
      parser.parseOptionalAssignmentList(blockArgs, initOperands);
  if (assignList.has_value() && failed(assignList.value()))
    return failure();

  SMLoc typeLoc = parser.getCurrentLocation();
  FunctionType fnType;
  if (failed(parser.parseColonType(fnType)))
    return failure();
  result.addTypes(fnType.getResults());

  if (fnType.getNumInputs() != initOperands.size())
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << initOperands.size() << " got "
           << fnType.getNumInputs() << ")";

  if (failed(parser.resolveOperands(initOperands, fnType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Give each region argument the type of the matching input.
  unsigned inputIdx = 0;
  for (OpAsmParser::Argument &arg : blockArgs)
    arg.type = fnType.getInput(inputIdx++);

  if (parser.parseRegion(*condRegion, blockArgs) ||
      parser.parseKeyword("do") || parser.parseRegion(*bodyRegion))
    return failure();
  return parser.parseOptionalAttrDictWithKeyword(result.attributes);
}
|
||
|
|
||
|
static void printInitializationList(OpAsmPrinter &parser,
|
||
|
Block::BlockArgListType blocksArgs,
|
||
|
ValueRange initializers,
|
||
|
StringRef prefix = "") {
|
||
|
assert(blocksArgs.size() == initializers.size() &&
|
||
|
"expected same length of arguments and initializers");
|
||
|
if (initializers.empty())
|
||
|
return;
|
||
|
|
||
|
parser << prefix << '(';
|
||
|
llvm::interleaveComma(
|
||
|
llvm::zip(blocksArgs, initializers), parser,
|
||
|
[&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
|
||
|
parser << ")";
|
||
|
}
|
||
|
|
||
|
void WhileOp::print(OpAsmPrinter &parser) {
|
||
|
printInitializationList(parser, getCond().front().getArguments(), getInputs(),
|
||
|
" ");
|
||
|
parser << " : ";
|
||
|
parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
|
||
|
parser << ' ';
|
||
|
parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
|
||
|
parser << " do ";
|
||
|
parser.printRegion(getBody());
|
||
|
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
|
||
|
}
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TOSA Attribute Definitions.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#define GET_ATTRDEF_CLASSES
|
||
|
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
|
||
|
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
// TOSA Operator Definitions.
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#define GET_OP_CLASSES
|
||
|
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
|