//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Insert reshape ops on a binary op's inputs where needed to match ranks.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
} // namespace tosa
} // namespace mlir

using namespace mlir;
using namespace mlir::tosa;

namespace {

/// Common code to create a reshape op where necessary to make the ranks of two
/// operands equal. input1 and input2 are updated when a rank has changed; the
/// caller is expected to use them to rewrite the original operator with the
/// RESHAPE now in the graph.
/// Returns failure when (1) no reshape is needed, or (2) outputType is ranked
/// and its rank disagrees with the reshaped inputs.
LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
                                   RankedTensorType outputType, Value &input1,
                                   Value &input2) {
  auto input1Ty = dyn_cast<RankedTensorType>(input1.getType());
  auto input2Ty = dyn_cast<RankedTensorType>(input2.getType());

  if (!input1Ty || !input2Ty) {
    return rewriter.notifyMatchFailure(loc, "input not a ranked tensor");
  }

  int64_t input1Rank = input1Ty.getRank();
  int64_t input2Rank = input2Ty.getRank();

  if (input1Rank == input2Rank)
    return rewriter.notifyMatchFailure(
        loc, "cannot rewrite as it's already correct");

  Value input1Copy = input1;
  Value input2Copy = input2;
  if (EqualizeRanks(rewriter, loc, input1Copy, input2Copy).failed()) {
    return rewriter.notifyMatchFailure(loc, "failed to reshape inputs");
  }

  // Verify the rank agrees with the output type if the output type is ranked.
  if (outputType) {
    if (outputType.getRank() !=
            llvm::cast<RankedTensorType>(input1Copy.getType()).getRank() ||
        outputType.getRank() !=
            llvm::cast<RankedTensorType>(input2Copy.getType()).getRank())
      return rewriter.notifyMatchFailure(
          loc, "the reshaped type doesn't agree with the ranked output type");
  }

  input1 = input1Copy;
  input2 = input2Copy;
  return success();
}

template <typename OpTy>
struct ConvertTosaOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaBinaryOp.getInput1();
    Value input2 = tosaBinaryOp.getInput2();
    Value output = tosaBinaryOp.getResult();

    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return failure();

    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, input1,
                                      input2);

    return success();
  }
};
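// For illustration, a hypothetical rewrite of a rank-mismatched tosa.add
// (operand names, shapes, and the printed IR syntax are examples, not taken
// from any existing test):
//
//   %0 = tosa.add %a, %b : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
//
// becomes, after the lower-ranked operand is reshaped up to rank 2 with a
// leading unit dimension:
//
//   %r = tosa.reshape %b {new_shape = array<i64: 1, 3>}
//            : (tensor<3xf32>) -> tensor<1x3xf32>
//   %0 = tosa.add %a, %r
//            : (tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32>
//
// Broadcasting the size-1 dimension itself remains the op's job; this pattern
// only equalizes operand ranks.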
// The MulOp has an extra parameter 'shift' not present in other elementwise
// binary ops, which necessitates special handling of its builder.
template <>
struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
  using OpRewritePattern<tosa::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaBinaryOp.getInput1();
    Value input2 = tosaBinaryOp.getInput2();
    int32_t shift = tosaBinaryOp.getShift();
    Value output = tosaBinaryOp.getResult();
    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return failure();

    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, input1,
                                             input2, shift);

    return success();
  }
};

// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
// other elementwise binary ops, which necessitates special handling of its
// builder.
template <>
struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
    : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
  using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaBinaryOp.getInput1();
    Value input2 = tosaBinaryOp.getInput2();
    int32_t round = tosaBinaryOp.getRound();
    Value output = tosaBinaryOp.getResult();
    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return failure();

    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
        tosaBinaryOp, outputType, input1, input2, round);

    return success();
  }
};

template <>
struct ConvertTosaOp<tosa::SelectOp> : public OpRewritePattern<tosa::SelectOp> {
  using OpRewritePattern<tosa::SelectOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::SelectOp tosaOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaOp.getPred();
    Value input2 = tosaOp.getOnTrue();
    Value input3 = tosaOp.getOnFalse();
    Value output = tosaOp.getResult();

    auto outputType = dyn_cast<RankedTensorType>(output.getType());
    if (!outputType)
      return rewriter.notifyMatchFailure(tosaOp, "output not a ranked tensor");

    // Apply rank equalization to each pair of inputs separately, chaining the
    // pairs so that all three operands are equalized in a single rewrite.
    bool reshaped1 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
                                          input1, input2)
                         .succeeded();

    bool reshaped2 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
                                          input1, input3)
                         .succeeded();

    bool reshaped3 = reshapeLowerToHigher(rewriter, tosaOp.getLoc(), outputType,
                                          input2, input3)
                         .succeeded();

    if (!reshaped1 && !reshaped2 && !reshaped3)
      return rewriter.notifyMatchFailure(
          tosaOp,
          "cannot rewrite as the rank of all operands is already aligned");

    int32_t result1Rank = cast<RankedTensorType>(input1.getType()).getRank();
    int32_t result2Rank = cast<RankedTensorType>(input2.getType()).getRank();
    int32_t result3Rank = cast<RankedTensorType>(input3.getType()).getRank();
    int32_t outputRank = outputType.getRank();

    if ((result1Rank != result2Rank) || (result2Rank != result3Rank) ||
        (result1Rank != outputRank))
      return rewriter.notifyMatchFailure(
          tosaOp, "not all ranks are aligned with each other");

    rewriter.replaceOpWithNewOp<tosa::SelectOp>(tosaOp, outputType, input1,
                                                input2, input3);

    return success();
  }
};
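// For illustration, a hypothetical select with mismatched operand ranks
// (shapes are examples): %pred : tensor<3xi1>, %onTrue : tensor<2x3xf32>,
// %onFalse : tensor<3xf32>. The first pairwise pass reshapes %pred against
// %onTrue so both reach rank 2; the second reshapes %onFalse against the
// now-rank-2 %pred; the third finds equal ranks and is a no-op. All three
// operands thus reach rank 2 before the select is rebuilt above.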
} // namespace

namespace {
/// Pass that enables broadcast by making all input tensors have the same
/// number of dimensions. Inserts RESHAPE operations on the lower-ranked
/// operand.
struct TosaMakeBroadcastable
    : public tosa::impl::TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    RewritePatternSet patterns(func.getContext());
    MLIRContext *ctx = func.getContext();
    // Add the generated patterns to the list.
    patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::IntDivOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
  return std::make_unique<TosaMakeBroadcastable>();
}
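// Example invocation (assuming the standard registration of this pass on
// func::FuncOp under the "tosa-make-broadcastable" flag):
//
//   mlir-opt --tosa-make-broadcastable input.mlir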