//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALG #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir using namespace mlir; static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { if (!OpTrait::hasElementwiseMappableTraits(op)) return false; // TODO: The conversion pattern can be made to work for `any_of` here, but // it's more complex as it requires tracking which operands are scalars. return llvm::all_of(op->getOperandTypes(), [](Type type) { return isa(type); }); } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over /// the result types and return a list of values such that, for each result type /// `t` and value `v` at the same index `idx`: /// 1. `v.getType() == t` /// 2. If an operand of `op` has type `t`, let `operand_first` be the first /// such operand. Then`v == operand_first`. /// 3. Otherwise, v is a newly created `tensor::EmptyOp` with: /// a. Static and dynamic dims extracted from the first operand of `op`. /// b. Elemental type equal to the elemental type of `t`. /// /// This is sufficient because ElementwiseMappable guarantees that "The static /// types of all vector (resp. tensor) operands and results must have the same /// shape". static SmallVector getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { assert(isElementwiseMappableOpOnRankedTensors(op)); Location loc = op->getLoc(); ValueRange operands = op->getOperands(); TypeRange rankedTensorTypes = op->getResultTypes(); SmallVector res; res.reserve(rankedTensorTypes.size()); for (Type t : rankedTensorTypes) { // Try to find an operand with type matching the result tensor. bool found = false; for (Value v : operands) { if (v.getType() == t) { found = true; res.push_back(v); break; } } if (found) continue; // Extract static / dynamic shape mix from the first operand. res.push_back(b.create( loc, tensor::getMixedSizes(b, loc, operands.front()), cast(t).getElementType())); } return res; } namespace { struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { if (!isElementwiseMappableOpOnRankedTensors(op)) return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); auto rank = cast(op->getResult(0).getType()).getRank(); SmallVector indexingMaps( op->getNumResults() + op->getNumOperands(), rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp( op, /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/op->getOperands(), /*outputs=*/outputs, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, /*bodyBuilder=*/ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { auto resultTypes = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { return cast(type).getElementType(); })); auto *scalarOp = builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); builder.create(loc, scalarOp->getResults()); }); return success(); } }; } // namespace void mlir::linalg::populateElementwiseToLinalgConversionPatterns( RewritePatternSet &patterns) { patterns.add( patterns.getContext()); } namespace { class ConvertElementwiseToLinalgPass : public impl::ConvertElementwiseToLinalgBase< ConvertElementwiseToLinalgPass> { void runOnOperation() final { auto *func = getOperation(); auto *context = &getContext(); ConversionTarget target(*context); RewritePatternSet patterns(context); mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns); target.markUnknownOpDynamicallyLegal([](Operation *op) { return !isElementwiseMappableOpOnRankedTensors(op); }); if (failed(applyPartialConversion(func, target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr mlir::createConvertElementwiseToLinalgPass() { return std::make_unique(); }