//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements functions concerned with hoisting invariant operations // in the context of Linalg transformations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using llvm::dbgs; #define DEBUG_TYPE "linalg-hoisting" #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") using namespace mlir; using namespace mlir::linalg; static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, LoopLikeOpInterface loop) { Value source = transferRead.getSource(); // Skip view-like Ops and retrive the actual soruce Operation while (auto srcOp = dyn_cast_or_null(source.getDefiningOp())) source = srcOp.getViewSource(); llvm::SmallVector users(source.getUsers().begin(), source.getUsers().end()); llvm::SmallDenseSet processed; while (!users.empty()) { Operation *user = users.pop_back_val(); // If the user has already been processed skip. if (!processed.insert(user).second) continue; if (auto viewLike = dyn_cast(user)) { users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa(user)) continue; if (!loop->isAncestor(user)) continue; return false; } return true; } void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { bool changed = true; while (changed) { changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interrupting the function walk. func.walk( [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { if (!isa(transferRead.getShapedType())) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *transferRead.getOperation() << "\n"); auto loop = dyn_cast(transferRead->getParentOp()); LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() << "\n"); if (!isa_and_nonnull(loop)) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() << "\n"); SetVector forwardSlice; getForwardSlice(transferRead.getOperation(), &forwardSlice); // Look for the last TransferWriteOp in the forwardSlice of // `transferRead` that operates on the same memref. vector::TransferWriteOp transferWrite; for (auto *sliceOp : llvm::reverse(forwardSlice)) { auto candidateWrite = dyn_cast(sliceOp); if (!candidateWrite || candidateWrite.getSource() != transferRead.getSource()) continue; transferWrite = candidateWrite; } // All operands of the TransferRead must be defined outside of the loop. for (auto operand : transferRead.getOperands()) if (!loop.isDefinedOutsideOfLoop(operand)) return WalkResult::advance(); // Only hoist transfer_read / transfer_write pairs and singleton // transfer_reads for now. if (!transferWrite) { // Make sure there are no other accesses to the memref before // hoisting transfer_read. if (noAliasingUseInLoop(transferRead, loop)) loop.moveOutOfLoop(transferRead); return WalkResult::advance(); } LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() << "\n"); // Approximate aliasing by checking that: // 1. indices, vector type and permutation map are the same (i.e., the // transfer_read/transfer_write ops are matching), // 2. source operands for transfer.{read|write} do not originate from // Ops implementing ViewLikeOpInterface. // 3. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. if (transferRead.getIndices() != transferWrite.getIndices() || transferRead.getVectorType() != transferWrite.getVectorType() || transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); auto *source = transferRead.getSource().getDefiningOp(); if (source && isa_and_nonnull(source)) return WalkResult::advance(); source = transferWrite.getSource().getDefiningOp(); if (source && isa_and_nonnull(source)) return WalkResult::advance(); // TODO: may want to memoize this information for performance but it // likely gets invalidated often. DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); for (auto &use : transferRead.getSource().getUses()) { if (!loop->isAncestor(use.getOwner())) continue; if (use.getOwner() == transferRead.getOperation() || use.getOwner() == transferWrite.getOperation()) continue; if (auto transferWriteUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( cast(*transferWrite), cast(*transferWriteUse), /*testDynamicValueUsingBounds=*/true)) return WalkResult::advance(); } else if (auto transferReadUse = dyn_cast(use.getOwner())) { if (!vector::isDisjointTransferSet( cast(*transferWrite), cast(*transferReadUse), /*testDynamicValueUsingBounds=*/true)) return WalkResult::advance(); } else { // Unknown use, we cannot prove that it doesn't alias with the // transferRead/transferWrite operations. return WalkResult::advance(); } } // Hoist read before. loop.moveOutOfLoop(transferRead); // Hoist write after. transferWrite->moveAfter(loop); // Rewrite `loop` with new yields by cloning and erase the original loop. IRRewriter rewriter(transferRead.getContext()); NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc, ArrayRef newBBArgs) { return SmallVector{transferWrite.getVector()}; }; auto maybeNewLoop = loop.replaceWithAdditionalYields( rewriter, transferRead.getVector(), /*replaceInitOperandUsesInLoop=*/true, yieldFn); if (failed(maybeNewLoop)) return WalkResult::interrupt(); transferWrite.getVectorMutable().assign( maybeNewLoop->getOperation()->getResults().back()); changed = true; // Need to interrupt and restart because erasing the loop messes up // the walk. return WalkResult::interrupt(); }); } }