bolt/deps/llvm-18.1.8/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
2025-02-14 19:21:04 +01:00

533 lines
16 KiB
C++

//===- TosaValidation.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Validate that TOSA dialect input matches the specification for the given
// requirements.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAVALIDATION
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir
using namespace mlir;
using namespace mlir::tosa;
namespace {
/// Constant-operand check for tosa.pad.
///
/// For a tosa::PadOp, requires the `padding` operand (and the `pad_const`
/// operand when one is present) to match a constant. Any other operation
/// passes trivially. Emits an op error and returns failure on a violation.
static LogicalResult checkConstantOperandPad(Operation *op) {
  auto pad = dyn_cast<tosa::PadOp>(op);
  if (!pad)
    return success();

  DenseElementsAttr paddingValues;
  const bool paddingIsConstant =
      matchPattern(pad.getPadding(), m_Constant(&paddingValues));
  if (!paddingIsConstant)
    return op->emitOpError("padding of pad is not constant");

  // pad_const is optional: when it is absent the op is assumed to be
  // zero-padding and there is nothing further to verify.
  if (auto padConstOperand = pad.getPadConst()) {
    DenseElementsAttr padConstValue;
    if (!matchPattern(padConstOperand, m_Constant(&padConstValue)))
      return op->emitOpError("pad_const of pad is not constant");
  }
  return success();
}
/// Constant-operand check for tosa.transpose.
///
/// For a tosa::TransposeOp, requires the `perms` operand to match a constant;
/// every other operation passes trivially. Emits an op error and returns
/// failure on a violation.
static LogicalResult checkConstantOperandTranspose(Operation *op) {
  auto transpose = dyn_cast<tosa::TransposeOp>(op);
  if (!transpose)
    return success();

  DenseElementsAttr permValues;
  const bool permsAreConstant =
      matchPattern(transpose.getPerms(), m_Constant(&permValues));
  if (permsAreConstant)
    return success();
  return op->emitOpError("perms of transpose is not constant");
}
/// Constant-operand check for tosa.fully_connected.
///
/// For a tosa::FullyConnectedOp, requires both the `weight` and the `bias`
/// operands to match constants (weight is checked first). Every other
/// operation passes trivially. Emits an op error and returns failure on a
/// violation.
static LogicalResult checkConstantOperandFullyConnected(Operation *op) {
  auto fullyConnected = dyn_cast<tosa::FullyConnectedOp>(op);
  if (!fullyConnected)
    return success();

  DenseElementsAttr weightValues;
  const bool weightIsConstant =
      matchPattern(fullyConnected.getWeight(), m_Constant(&weightValues));
  if (!weightIsConstant)
    return op->emitOpError("weight of fully_connected is not constant");

  DenseElementsAttr biasValues;
  const bool biasIsConstant =
      matchPattern(fullyConnected.getBias(), m_Constant(&biasValues));
  if (!biasIsConstant)
    return op->emitOpError("bias of fully_connected is not constant");

  return success();
}
/// Set of upper bounds that a TOSA level places on operations.
///
/// Each field is the maximum value accepted by the corresponding level check
/// in the validation pass; a value-initialized TosaLevel has every bound set
/// to zero.
struct TosaLevel {
  int32_t MAX_RANK = 0;
  int32_t MAX_KERNEL = 0;
  int32_t MAX_STRIDE = 0;
  int32_t MAX_SCALE = 0;
  // @todo: MAX_LOG2_SIZE value and checks

  /// Field-wise equality. Declared const so that it can be used when the
  /// left-hand side is a const object, such as the constexpr level constants
  /// defined below.
  bool operator==(const TosaLevel &rhs) const {
    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
  }
};

/// Bounds for the "EightK" level, in field order:
/// MAX_RANK, MAX_KERNEL, MAX_STRIDE, MAX_SCALE.
static constexpr TosaLevel TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 256};
/// All-zero bounds; the validation pass treats this value as "no level
/// checking requested".
static constexpr TosaLevel TOSA_LEVEL_NONE = {0, 0, 0, 0};
//===----------------------------------------------------------------------===//
// TOSA Validation Pass.
//===----------------------------------------------------------------------===//
/// Pass that validates TOSA operations.
///
/// Three families of checks are implemented here:
///  * constant-operand checks (run only when StrictOperationSpecAlignment is
///    set), dispatched through `constCheckers`;
///  * level checks, which compare ranks, kernel sizes, strides and scales
///    against the bounds held in `tosaLevel`;
///  * variable checks, which track declared variables in `variablesMap` and
///    compare the types of reads/writes against the declaration.
struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
public:
  explicit TosaValidation() { populateConstantOperandChecks(); }

  /// Construct the pass and copy the profile, strictness and level options.
  explicit TosaValidation(const TosaValidationOptions &options)
      : TosaValidation() {
    this->profile = options.profile;
    this->StrictOperationSpecAlignment = options.StrictOperationSpecAlignment;
    this->level = options.level;
  }

  void runOnOperation() final;

  /// Run every registered constant-operand checker on `op`, stopping at the
  /// first one that fails.
  LogicalResult applyConstantOperandCheck(Operation *op) {
    for (auto &checker : constCheckers) {
      if (failed(checker(op)))
        return failure();
    }
    return success();
  }

  LogicalResult applyLevelCheck(Operation *op);

  // check variable read/write data types against variable declarations
  LogicalResult applyVariableCheck(Operation *op);

private:
  /// Register the per-operation constant-operand checkers.
  void populateConstantOperandChecks() {
    constCheckers.emplace_back(checkConstantOperandPad);
    constCheckers.emplace_back(checkConstantOperandTranspose);
    constCheckers.emplace_back(checkConstantOperandFullyConnected);
  }

  /// Shared implementation of the scalar level checks: emits
  /// "failed level check: <checkDesc>" on `op` and returns false when
  /// `v > limit`, otherwise returns true.
  ///
  /// The value is taken as int64_t because the callers pass 64-bit attribute
  /// values and products of them; a narrower parameter would silently
  /// truncate large values and let them pass the check.
  bool levelCheckLimit(Operation *op, int64_t v, int64_t limit,
                       const std::string &checkDesc) {
    if (v > limit) {
      op->emitOpError() << "failed level check: " << checkDesc;
      return false;
    }
    return true;
  }

  /// Level check `v` against MAX_KERNEL.
  bool levelCheckKernel(Operation *op, int64_t v,
                        const std::string &checkDesc) {
    return levelCheckLimit(op, v, tosaLevel.MAX_KERNEL, checkDesc);
  }

  /// Level check `v` against MAX_STRIDE.
  bool levelCheckStride(Operation *op, int64_t v,
                        const std::string &checkDesc) {
    return levelCheckLimit(op, v, tosaLevel.MAX_STRIDE, checkDesc);
  }

  /// Level check `v` against MAX_SCALE.
  bool levelCheckScale(Operation *op, int64_t v, const std::string &checkDesc) {
    return levelCheckLimit(op, v, tosaLevel.MAX_SCALE, checkDesc);
  }

  /// Level check `multiplier * dim` against MAX_KERNEL, where `dim` is a
  /// dimension read from a tensor shape. A dynamic dimension has no static
  /// extent to compare (and must not be fed into the multiplication), so it is
  /// accepted without a check.
  bool levelCheckKernelDim(Operation *op, int64_t multiplier, int64_t dim,
                           const std::string &checkDesc) {
    if (ShapedType::isDynamic(dim))
      return true;
    return levelCheckKernel(op, multiplier * dim, checkDesc);
  }

  /// Level check the rank of `v` against MAX_RANK. Values whose type is not a
  /// ShapedType pass; unranked shaped types fail with a dedicated message.
  bool levelCheckRank(Operation *op, Value v, const std::string &checkDesc) {
    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
      if (!type.hasRank()) {
        op->emitOpError() << "failed level check: unranked tensor";
        return false;
      }
      if (type.getRank() > tosaLevel.MAX_RANK) {
        op->emitOpError() << "failed level check: " << checkDesc;
        return false;
      }
    }
    return true;
  }

  /// If `op` is a `T`, level check the ranks of all of its operands and
  /// results; any other operation passes.
  template <typename T>
  bool levelCheckRanksFor(Operation *op) {
    if (!isa<T>(op))
      return true;
    // level check ranks of all operands and results
    for (auto v : op->getOperands()) {
      if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
        return false;
    }
    for (auto v : op->getResults()) {
      if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
        return false;
    }
    return true;
  }

  /// Apply the rank level check to every operation kind listed below.
  bool levelCheckRanks(Operation *op) {
#define CHECK_RANKS_FOR(tosaOp)                                                \
  if (!levelCheckRanksFor<tosaOp##Op>(op))                                     \
    return false;

    // tensor operators:
    CHECK_RANKS_FOR(ArgMax);
    // all activation functions:
    CHECK_RANKS_FOR(Clamp);
    CHECK_RANKS_FOR(Sigmoid);
    CHECK_RANKS_FOR(Tanh);
    // all elementwise binary operators:
    CHECK_RANKS_FOR(Add);
    CHECK_RANKS_FOR(ArithmeticRightShift);
    CHECK_RANKS_FOR(BitwiseAnd);
    CHECK_RANKS_FOR(BitwiseOr);
    CHECK_RANKS_FOR(BitwiseXor);
    CHECK_RANKS_FOR(Div);
    CHECK_RANKS_FOR(LogicalAnd);
    CHECK_RANKS_FOR(LogicalLeftShift);
    CHECK_RANKS_FOR(LogicalRightShift);
    CHECK_RANKS_FOR(LogicalOr);
    CHECK_RANKS_FOR(LogicalXor);
    CHECK_RANKS_FOR(Maximum);
    CHECK_RANKS_FOR(Minimum);
    CHECK_RANKS_FOR(Mul);
    CHECK_RANKS_FOR(Pow);
    CHECK_RANKS_FOR(Sub);
    CHECK_RANKS_FOR(Table);
    // all elementwise unary operators:
    CHECK_RANKS_FOR(Abs);
    CHECK_RANKS_FOR(BitwiseNot);
    CHECK_RANKS_FOR(Ceil);
    CHECK_RANKS_FOR(Clz);
    CHECK_RANKS_FOR(Exp);
    CHECK_RANKS_FOR(Floor);
    CHECK_RANKS_FOR(Log);
    CHECK_RANKS_FOR(LogicalNot);
    CHECK_RANKS_FOR(Negate);
    CHECK_RANKS_FOR(Reciprocal);
    CHECK_RANKS_FOR(Rsqrt);
    // all elementwise ternary operators:
    CHECK_RANKS_FOR(Select);
    // all comparison operators:
    CHECK_RANKS_FOR(Equal);
    CHECK_RANKS_FOR(Greater);
    CHECK_RANKS_FOR(GreaterEqual);
    // all reduction operators:
    CHECK_RANKS_FOR(ReduceAll);
    CHECK_RANKS_FOR(ReduceAny);
    CHECK_RANKS_FOR(ReduceMax);
    CHECK_RANKS_FOR(ReduceMin);
    CHECK_RANKS_FOR(ReduceProd);
    CHECK_RANKS_FOR(ReduceSum);
    // all data layout operators:
    CHECK_RANKS_FOR(Concat);
    CHECK_RANKS_FOR(Pad);
    CHECK_RANKS_FOR(Reshape);
    CHECK_RANKS_FOR(Reverse);
    CHECK_RANKS_FOR(Slice);
    CHECK_RANKS_FOR(Tile);
    CHECK_RANKS_FOR(Transpose);
    // all type conversion operators:
    CHECK_RANKS_FOR(Cast);
    CHECK_RANKS_FOR(Rescale);
    // all data nodes operators:
    CHECK_RANKS_FOR(Const);
    CHECK_RANKS_FOR(Identity);

#undef CHECK_RANKS_FOR
    return true;
  }

  // Pool Op: level check kernel/stride/pad values
  template <typename T>
  bool levelCheckPool(Operation *op) {
    auto poolOp = dyn_cast<T>(op);
    if (!poolOp)
      return true;
    for (auto k : poolOp.getKernel()) {
      if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL"))
        return false;
    }
    for (auto s : poolOp.getStride()) {
      if (!levelCheckStride(op, s, "stride <= MAX_STRIDE"))
        return false;
    }
    for (auto p : poolOp.getPad()) {
      if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL"))
        return false;
    }
    return true;
  }

  // Conv Op: level check dilation/stride/pad values
  template <typename T>
  bool levelCheckConv(Operation *op) {
    auto convOp = dyn_cast<T>(op);
    if (!convOp)
      return true;

    for (auto k : convOp.getDilation()) {
      if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL"))
        return false;
    }
    for (auto p : convOp.getPad()) {
      if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL"))
        return false;
    }
    for (auto s : convOp.getStride()) {
      if (!levelCheckStride(op, s, "stride <= MAX_STRIDE"))
        return false;
    }

    // The dilated-kernel checks need the static shape of the weight operand
    // (operand 1). As in levelCheckRank, a shape may only be inspected when
    // the type is ranked; without a rank there is nothing to check here.
    auto weightType = dyn_cast<ShapedType>(op->getOperand(1).getType());
    if (!weightType || !weightType.hasRank())
      return true;

    auto dilation = convOp.getDilation();
    auto shape = weightType.getShape();
    if (isa<tosa::Conv2DOp>(op)) {
      assert(shape.size() == 4);
      assert(dilation.size() == 2);
      if (!levelCheckKernelDim(op, dilation[0], shape[1],
                               "dilation_y * KH <= MAX_KERNEL)") ||
          !levelCheckKernelDim(op, dilation[1], shape[2],
                               "dilation_x * KW <= MAX_KERNEL)"))
        return false;
    } else if (isa<tosa::Conv3DOp>(op)) {
      assert(shape.size() == 5);
      assert(dilation.size() == 3);
      if (!levelCheckKernelDim(op, dilation[0], shape[1],
                               "dilation_d * KD <= MAX_KERNEL)") ||
          !levelCheckKernelDim(op, dilation[1], shape[2],
                               "dilation_y * KH <= MAX_KERNEL)") ||
          !levelCheckKernelDim(op, dilation[2], shape[3],
                               "dilation_x * KW <= MAX_KERNEL)"))
        return false;
    } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
      assert(shape.size() == 4);
      assert(dilation.size() == 2);
      if (!levelCheckKernelDim(op, dilation[0], shape[0],
                               "dilation_y * KH <= MAX_KERNEL)") ||
          !levelCheckKernelDim(op, dilation[1], shape[1],
                               "dilation_x * KW <= MAX_KERNEL)"))
        return false;
    }
    return true;
  }

  // FFT op: level check H, W in input shape [N,H,W]
  template <typename T>
  bool levelCheckFFT(Operation *op) {
    if (!isa<T>(op))
      return true;
    for (auto v : op->getOperands()) {
      auto type = dyn_cast<ShapedType>(v.getType());
      // Only ranked shaped operands have a shape that can be inspected.
      if (!type || !type.hasRank())
        continue;
      auto shape = type.getShape();
      assert(shape.size() == 3);
      if (!levelCheckKernelDim(op, 1, shape[1], "H <= MAX_KERNEL") ||
          !levelCheckKernelDim(op, 1, shape[2], "W <= MAX_KERNEL"))
        return false;
    }
    return true;
  }

  // TransposeConv2d op: level check kH/kW, outpad, and stride
  bool levelCheckTransposeConv2d(Operation *op) {
    auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op);
    if (!transpose)
      return true;

    auto filterType = dyn_cast<ShapedType>(transpose.getFilter().getType());
    // Only a ranked filter has a shape that can be inspected.
    if (filterType && filterType.hasRank()) {
      auto shape = filterType.getShape();
      assert(shape.size() == 4);
      // level check kernel sizes for kH and KW
      if (!levelCheckKernelDim(op, 1, shape[1], "KH <= MAX_KERNEL") ||
          !levelCheckKernelDim(op, 1, shape[2], "KW <= MAX_KERNEL"))
        return false;
    }
    for (auto p : transpose.getOutPad()) {
      if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL"))
        return false;
    }
    for (auto s : transpose.getStride()) {
      if (!levelCheckStride(op, s, "stride <= MAX_STRIDE"))
        return false;
    }
    return true;
  }

  // Resize op: level check max scales
  bool levelCheckResize(Operation *op) {
    auto resize = dyn_cast<tosa::ResizeOp>(op);
    if (!resize)
      return true;

    // Keep the full 64-bit attribute values: narrowing them would let
    // out-of-range scales wrap around and pass the check.
    auto scale = resize.getScale();
    const int64_t scaleYN = scale[0];
    const int64_t scaleYD = scale[1];
    const int64_t scaleXN = scale[2];
    const int64_t scaleXD = scale[3];

    // Integer division by zero is undefined behaviour; report it as a failed
    // check instead of evaluating the ratio.
    if (scaleYD == 0 || scaleXD == 0) {
      op->emitOpError()
          << "failed level check: scale denominator must be non-zero";
      return false;
    }

    if (!levelCheckScale(op, scaleYN / scaleYD,
                         "scale_y_n/scale_y_d <= MAX_SCALE") ||
        !levelCheckScale(op, scaleXN / scaleXD,
                         "scale_x_n/scale_x_d <= MAX_SCALE"))
      return false;
    return true;
  }

  // configure profile and level values from pass options profileName and
  // levelName
  void configLevelAndProfile() {
    tosaLevel = TOSA_LEVEL_NONE;
    if (level == TosaLevelEnum::EightK) {
      tosaLevel = TOSA_LEVEL_EIGHTK;
    }
  }

  bool CheckVariable(Operation *op);
  bool CheckVariableReadOrWrite(Operation *op);

  // Checkers run by applyConstantOperandCheck, in registration order.
  SmallVector<std::function<LogicalResult(Operation *)>> constCheckers;
  // Bounds used by the level checks; set by configLevelAndProfile.
  TosaLevel tosaLevel;
  // Declared variable name -> declared type, filled by CheckVariable.
  DenseMap<StringAttr, mlir::Type> variablesMap;
};
/// Run all level checks on `op`.
///
/// Succeeds immediately when the configured level equals TOSA_LEVEL_NONE.
/// Otherwise the rank checks run first, followed by the per-operation checks
/// in the order listed below; evaluation stops at the first failing check.
LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
  // no need to do level checks
  if (tosaLevel == TOSA_LEVEL_NONE)
    return success();

  if (!levelCheckRanks(op))
    return failure();

  // additional level checks from spec 0.70
  const bool allChecksPassed = levelCheckPool<tosa::AvgPool2dOp>(op) &&
                               levelCheckConv<tosa::Conv2DOp>(op) &&
                               levelCheckConv<tosa::Conv3DOp>(op) &&
                               levelCheckConv<tosa::DepthwiseConv2DOp>(op) &&
                               levelCheckFFT<tosa::FFT2dOp>(op) &&
                               levelCheckPool<tosa::MaxPool2dOp>(op) &&
                               levelCheckFFT<tosa::RFFT2dOp>(op) &&
                               levelCheckTransposeConv2d(op) &&
                               levelCheckResize(op);
  return success(allChecksPassed);
}
/// Returns true when `type` may be used with a variable declared as
/// `declaredType`.
inline bool CompatibleTypes(const mlir::Type &type,
                            const mlir::Type &declaredType) {
  // for now, simply use type equality comparison
  const bool sameType = (type == declaredType);
  return sameType;
}
/// Record a variable declaration.
///
/// For a tosa VariableOp, reads its "name" and "type" attributes and stores
/// the declared type in `variablesMap`. Emits an op error and returns false
/// when the name is already present; returns true otherwise, including for
/// every operation that is not a VariableOp.
bool TosaValidation::CheckVariable(Operation *op) {
  if (!isa<mlir::tosa::VariableOp>(op))
    return true;

  auto varName = cast<mlir::StringAttr>(op->getAttr("name"));
  if (variablesMap.find(varName) != variablesMap.end()) {
    op->emitOpError() << "name has already been declared";
    return false;
  }

  mlir::Type declaredType =
      cast<mlir::TypeAttr>(op->getAttr("type")).getValue();
  variablesMap[varName] = declaredType;
  return true;
}
/// Check a variable read or write against its declaration.
///
/// For a tosa VariableReadOp or VariableWriteOp, looks up the "name"
/// attribute in `variablesMap` and requires every operand type and every
/// result type to be compatible with the declared type. Emits an op error and
/// returns false on the first violation; any other operation returns true.
bool TosaValidation::CheckVariableReadOrWrite(Operation *op) {
  if (!isa<mlir::tosa::VariableReadOp, mlir::tosa::VariableWriteOp>(op))
    return true;

  auto varName = cast<mlir::StringAttr>(op->getAttr("name"));
  auto declaration = variablesMap.find(varName);
  if (declaration == variablesMap.end()) {
    op->emitOpError() << "name has not been declared";
    return false;
  }

  const mlir::Type declaredType = declaration->second;
  for (auto operand : op->getOperands()) {
    if (CompatibleTypes(operand.getType(), declaredType))
      continue;
    op->emitOpError() << "operand type does not equal variable type";
    return false;
  }
  for (auto result : op->getResults()) {
    if (CompatibleTypes(result.getType(), declaredType))
      continue;
    op->emitOpError() << "result type does not equal variable type";
    return false;
  }
  return true;
}
/// Run the variable declaration check and then the read/write check on `op`;
/// the second check is skipped when the first one fails.
LogicalResult TosaValidation::applyVariableCheck(Operation *op) {
  const bool variableChecksPassed =
      CheckVariable(op) && CheckVariableReadOrWrite(op);
  return success(variableChecksPassed);
}
void TosaValidation::runOnOperation() {
configLevelAndProfile();
getOperation().walk([&](Operation *op) {
for (Value operand : op->getOperands()) {
if ((profile == TosaProfileEnum::BaseInference) &&
isa<FloatType>(getElementTypeOrSelf(operand))) {
return signalPassFailure();
}
if (getElementTypeOrSelf(operand).isF64()) {
return signalPassFailure();
}
}
// Some uses of TOSA rely on the constant operands of particular
// operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op)))
signalPassFailure();
// do level checks
if (failed(applyLevelCheck(op)))
signalPassFailure();
// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
});
}
} // namespace