//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This transformation pass legalizes Tosa operations to the Linalg dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" namespace mlir { #define GEN_PASS_DEF_TOSATOLINALG #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { struct TosaToLinalg : public impl::TosaToLinalgBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry .insert(); } void runOnOperation() override { RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect(); // Not every TOSA op can be legalized to linalg. target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.addLegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); FunctionOpInterface func = getOperation(); mlir::tosa::populateTosaToLinalgConversionPatterns(&patterns); if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr mlir::tosa::createTosaToLinalg() { return std::make_unique(); } void mlir::tosa::addTosaToLinalgPasses( OpPassManager &pm, const TosaToLinalgOptions &options, const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions, tosa::TosaValidationOptions const &validationOptions) { // Optional decompositions are designed to benefit linalg. if (!options.disableTosaDecompositions) pm.addNestedPass(tosa::createTosaOptionalDecompositions()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(tosa::createTosaInferShapesPass()); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addNestedPass( tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions)); pm.addNestedPass(createCanonicalizerPass()); // TODO: Remove pass that operates on const tensor and enable optionality pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass( {options.aggressiveReduceConstant})); pm.addNestedPass(tosa::createTosaMakeBroadcastablePass()); pm.addPass(tosa::createTosaValidation(validationOptions)); pm.addNestedPass(tosa::createTosaToLinalg()); } //===----------------------------------------------------------------------===// // Pipeline registration. //===----------------------------------------------------------------------===// void mlir::tosa::registerTosaToLinalgPipelines() { PassPipelineRegistration<>( "tosa-to-linalg-pipeline", "The default pipeline for converting TOSA operators to the equivalent " "operations using the tensor operations in LinAlg as well as LinAlg " "named operations.", [](OpPassManager &pm) { TosaToLinalgOptions tosaToLinalgOptions; TosaToLinalgNamedOptions tosaToLinalgNamedOptions; tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions, tosaToLinalgNamedOptions, /* validationOptions = */ {tosa::TosaProfileEnum::BaseInference, /* StrictOperationSpecAlignment = */ true, tosa::TosaLevelEnum::EightK}); }); }