//===- TosaTestPasses.cpp -------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Test passes to exercise TOSA helper functions. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define PASS_NAME "tosa-test-quant-utils" using namespace mlir; using namespace mlir::tosa; // This transformation converts quantized uint8 to quantized int8. The // construction of the new type invokes buildQTypeFromMinMax. Extracted from // TOSA legalization infrastructure. struct ConvertTosaNegateOp : public RewritePattern { explicit ConvertTosaNegateOp(MLIRContext *context) : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; LogicalResult ConvertTosaNegateOp::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto tosaNegateOp = cast(op); auto inputType = dyn_cast(tosaNegateOp.getInput1().getType()); // skip if input is not ranked tensor type if (!inputType) return failure(); // skip if it's not ranked tensor type. auto outputType = dyn_cast(tosaNegateOp.getResult().getType()); if (!outputType) return failure(); // skip if output is not per-tensor quantized type. auto outputElementType = dyn_cast(outputType.getElementType()); if (!outputElementType) return failure(); // skip if output is not uint8. if (outputElementType.isSigned() || outputElementType.getStorageTypeIntegralWidth() != 8) return failure(); double typeRangeMin = double(outputElementType.getStorageTypeMin() - outputElementType.getZeroPoint()) * outputElementType.getScale(); double typeRangeMax = double(outputElementType.getStorageTypeMax() - outputElementType.getZeroPoint()) * outputElementType.getScale(); bool narrowRange = outputElementType.getStorageTypeMin() == 1; auto dstQConstType = RankedTensorType::get( outputType.getShape(), buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(), rewriter.getF64FloatAttr(typeRangeMin), rewriter.getF64FloatAttr(typeRangeMax), rewriter.getI32IntegerAttr( outputElementType.getStorageTypeIntegralWidth()), 0, true /* signed */, rewriter.getBoolAttr(narrowRange))); ElementsAttr inputElems; if (!matchPattern(tosaNegateOp.getInput1(), m_Constant(&inputElems))) return failure(); auto newConstOp = rewriter.create(op->getLoc(), dstQConstType, inputElems); auto newNegateOp = rewriter.create( op->getLoc(), dstQConstType, newConstOp.getResult()); rewriter.replaceOp(op, {newNegateOp.getResult()}); return success(); } // This transformation modifies the quantized output of a test conv2d input and // appends a TOSA rescale after it. The rescale op requires the invocation of // computeMultiplierAndShift. From TOSA legalization infrastructure. struct ConvertTosaConv2DOp : public RewritePattern { explicit ConvertTosaConv2DOp(MLIRContext *context) : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; }; LogicalResult ConvertTosaConv2DOp::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { auto tosaConv2DOp = cast(op); auto inputType = dyn_cast(tosaConv2DOp.getInput().getType()); // skip if input is not ranked tensor type if (!inputType) return failure(); auto weightType = dyn_cast(tosaConv2DOp.getWeight().getType()); // skip if wt is not ranked tensor type if (!weightType) return failure(); // skip if it's not ranked tensor type. auto outputType = dyn_cast(tosaConv2DOp.getResult().getType()); if (!outputType) return failure(); auto inputQType = dyn_cast(inputType.getElementType()); auto weightQType = dyn_cast(weightType.getElementType()); auto outputQType = dyn_cast(outputType.getElementType()); // Works on quantized type only. if (!(inputQType && weightQType && outputQType)) return failure(); auto newTosaConv2DOpType = RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); auto newTosaConv2DOp = rewriter.create( op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.getInput(), tosaConv2DOp.getWeight(), tosaConv2DOp.getBias(), tosaConv2DOp.getPadAttr(), tosaConv2DOp.getStrideAttr(), tosaConv2DOp.getDilationAttr()); // Create rescale to quantized type double inputScale = inputQType.getScale(); double weightScale = weightQType.getScale(); double outputScale = outputQType.getScale(); int64_t outputZp = outputQType.getZeroPoint(); double opTensorScale = (inputScale * weightScale) / outputScale; int32_t multiplier; int32_t shift; // Obtain the quantized scale = multiplier and shift. computeMultiplierAndShift(opTensorScale, multiplier, shift, 32); auto newTosaRescaleOp = rewriter.create( op->getLoc(), outputType, newTosaConv2DOp.getResult(), rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp), rewriter.getDenseI32ArrayAttr({multiplier}), rewriter.getDenseI8ArrayAttr({static_cast(shift)}), rewriter.getBoolAttr(true), rewriter.getBoolAttr(true), rewriter.getBoolAttr(false)); rewriter.replaceOp(op, {newTosaRescaleOp.getResult()}); return success(); } namespace { struct TosaTestQuantUtilAPI : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TosaTestQuantUtilAPI) StringRef getArgument() const final { return PASS_NAME; } StringRef getDescription() const final { return "TOSA Test: Exercise the APIs in QuantUtils.cpp."; } void runOnOperation() override; }; void TosaTestQuantUtilAPI::runOnOperation() { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); auto func = getOperation(); patterns.add(ctx); patterns.add(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); } } // namespace namespace mlir { void registerTosaTestQuantUtilAPIPass() { PassRegistration(); } } // namespace mlir