//===----------- MultiBuffering.cpp ---------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements multi buffering transformation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" using namespace mlir; #define DEBUG_TYPE "memref-transforms" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") #define DBGSNL() (llvm::dbgs() << "\n") /// Return true if the op fully overwrite the given `buffer` value. static bool overrideBuffer(Operation *op, Value buffer) { auto copyOp = dyn_cast(op); if (!copyOp) return false; return copyOp.getTarget() == buffer; } /// Replace the uses of `oldOp` with the given `val` and for subview uses /// propagate the type change. Changing the memref type may require propagating /// it through subview ops so we cannot just do a replaceAllUse but need to /// propagate the type change and erase old subview ops. static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val) { SmallVector opsToDelete; SmallVector operandsToReplace; // Save the operand to replace / delete later (avoid iterator invalidation). // TODO: can we use an early_inc iterator? 
for (OpOperand &use : oldOp->getUses()) { // Non-subview ops will be replaced by `val`. auto subviewUse = dyn_cast(use.getOwner()); if (!subviewUse) { operandsToReplace.push_back(&use); continue; } // `subview(old_op)` is replaced by a new `subview(val)`. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(subviewUse); Type newType = memref::SubViewOp::inferRankReducedResultType( subviewUse.getType().getShape(), cast(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); Value newSubview = rewriter.create( subviewUse->getLoc(), cast(newType), val, subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); // Ouch recursion ... is this really necessary? replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); opsToDelete.push_back(use.getOwner()); } // Perform late replacement. // TODO: can we use an early_inc iterator? for (OpOperand *operand : operandsToReplace) { Operation *op = operand->getOwner(); rewriter.startOpModification(op); operand->set(val); rewriter.finalizeOpModification(op); } // Perform late op erasure. // TODO: can we use an early_inc iterator? for (Operation *op : opsToDelete) rewriter.eraseOp(op); } // Transformation to do multi-buffering/array expansion to remove dependencies // on the temporary allocation between consecutive loop iterations. // Returns success if the transformation happened and failure otherwise. // This is not a pattern as it requires propagating the new memref type to its // uses and requires updating subview ops. 
/// Multi-buffer `allocOp` by a factor of `multiBufferingFactor`.
///
/// A loop enclosing the users of `allocOp` is selected, a new alloc with an
/// extra leading dimension of size `multiBufferingFactor` is created in place
/// of `allocOp`, and all uses of the original buffer are redirected to a
/// rank-reduced subview selecting slice
/// `((iv - lb) floordiv step) mod multiBufferingFactor` of the new buffer.
///
/// When `skipOverrideAnalysis` is false, the candidate loop is the parent loop
/// of a user that fully overwrites the buffer and dominates all other users.
/// When it is true, the candidate loop is the parent loop of a user such that
/// every other non-dealloc user is nested in that loop.
///
/// Returns the new alloc on success. Returns failure, leaving the IR
/// untouched, when `multiBufferingFactor` is zero, when the alloc has a
/// dynamic shape, when a non-dealloc user is not inside a loop, when no
/// candidate loop is found, when the loop has no single induction variable /
/// lower bound / step / region, or when the alloc does not dominate the loop.
FailureOr<memref::AllocOp>
mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
                          unsigned multiBufferingFactor,
                          bool skipOverrideAnalysis) {
  LLVM_DEBUG(DBGS() << "Start multibuffering: " << allocOp << "\n");

  // A zero factor would produce an empty leading dimension and a modulo by
  // zero in the slice index computed below.
  if (multiBufferingFactor == 0) {
    LLVM_DEBUG(DBGS() << "Skip alloc: multi-buffering factor is zero\n");
    return failure();
  }

  // The multi-buffered alloc is created without dynamic-size operands and the
  // original extents are used as static subview sizes, so only statically
  // shaped allocations can be handled.
  if (!allocOp.getType().hasStaticShape()) {
    LLVM_DEBUG(DBGS() << "Skip alloc: dynamic shape is not supported\n");
    return failure();
  }

  DominanceInfo dom(allocOp->getParentOp());
  LoopLikeOpInterface candidateLoop;
  for (Operation *user : allocOp->getUsers()) {
    auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
    if (!parentLoop) {
      if (isa<memref::DeallocOp>(user)) {
        // Allow dealloc outside of any loop.
        // TODO: The whole precondition function here is very brittle and will
        // need to be rethought and isolated into a cleaner analysis.
        continue;
      }
      LLVM_DEBUG(DBGS() << "--no parent loop -> fail\n");
      LLVM_DEBUG(DBGS() << "----due to user: " << *user << "\n");
      return failure();
    }
    if (!skipOverrideAnalysis) {
      /// Make sure there is no loop-carried dependency on the allocation.
      if (!overrideBuffer(user, allocOp.getResult())) {
        LLVM_DEBUG(DBGS() << "--Skip user: found loop-carried dependence\n");
        continue;
      }
      // If this user doesn't dominate all the other users keep looking.
      if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
            return !dom.dominates(user, otherUser);
          })) {
        LLVM_DEBUG(
            DBGS() << "--Skip user: does not dominate all other users\n");
        continue;
      }
    } else {
      // Without the override analysis, only require that every other user
      // (deallocs excepted) is nested inside this user's parent loop.
      if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
            return !isa<memref::DeallocOp>(otherUser) &&
                   !parentLoop->isProperAncestor(otherUser);
          })) {
        LLVM_DEBUG(
            DBGS()
            << "--Skip user: not all other users are in the parent loop\n");
        continue;
      }
    }
    candidateLoop = parentLoop;
    break;
  }

  if (!candidateLoop) {
    LLVM_DEBUG(DBGS() << "Skip alloc: no candidate loop\n");
    return failure();
  }

  std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
  std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
  std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
  if (!inductionVar || !lowerBound || !singleStep ||
      !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
    LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
    return failure();
  }

  if (!dom.dominates(allocOp.getOperation(), candidateLoop)) {
    LLVM_DEBUG(DBGS() << "Skip alloc: does not dominate candidate loop\n");
    return failure();
  }

  LLVM_DEBUG(DBGS() << "Start multibuffering loop: " << candidateLoop << "\n");

  // 1. Construct the multi-buffered memref type: the original shape with a
  // leading dimension of size `multiBufferingFactor` and a default layout.
  ArrayRef<int64_t> originalShape = allocOp.getType().getShape();
  SmallVector<int64_t, 4> multiBufferedShape;
  multiBufferedShape.push_back(static_cast<int64_t>(multiBufferingFactor));
  llvm::append_range(multiBufferedShape, originalShape);
  LLVM_DEBUG(DBGS() << "--original type: " << allocOp.getType() << "\n");
  MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType())
                                .setShape(multiBufferedShape)
                                .setLayout(MemRefLayoutAttrInterface());
  LLVM_DEBUG(DBGS() << "--multi-buffered type: " << mbMemRefType << "\n");

  // 2. Create the multi-buffered alloc.
  Location loc = allocOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(allocOp);
  auto mbAlloc = rewriter.create<memref::AllocOp>(
      loc, mbMemRefType, ValueRange{}, allocOp->getAttrs());
  LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n");

  // 3. Within the loop, build the modular leading index (i.e. each loop
  //    iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
  rewriter.setInsertionPointToStart(
      &candidateLoop.getLoopRegions().front()->front());
  Value ivVal = *inductionVar;
  Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
  Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
  AffineExpr iv, lb, step;
  bindDims(rewriter.getContext(), iv, lb, step);
  Value bufferIndex = affine::makeComposedAffineApply(
      rewriter, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor,
      {ivVal, lbVal, stepVal});
  LLVM_DEBUG(DBGS() << "--multi-buffered indexing: " << bufferIndex << "\n");

  // 4. Build the subview accessing the particular slice, taking modular
  // rotation into account.
  int64_t mbMemRefTypeRank = mbMemRefType.getRank();
  IntegerAttr zero = rewriter.getIndexAttr(0);
  IntegerAttr one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(mbMemRefTypeRank, zero);
  SmallVector<OpFoldResult> sizes(mbMemRefTypeRank, one);
  SmallVector<OpFoldResult> strides(mbMemRefTypeRank, one);
  // Offset is [bufferIndex, 0 ... 0 ].
  offsets.front() = bufferIndex;
  // Sizes is [1, original_size_0 ... original_size_n ].
  for (int64_t i = 0, e = originalShape.size(); i != e; ++i)
    sizes[1 + i] = rewriter.getIndexAttr(originalShape[i]);
  // Strides is [1, 1 ... 1 ].
  auto dstMemref =
      cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
          originalShape, mbMemRefType, offsets, sizes, strides));
  Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc,
                                                     offsets, sizes, strides);
  LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n");

  // 5. Due to the recursive nature of replaceUsesAndPropagateType, we need to
  // handle dealloc uses separately: a dealloc must free the whole
  // multi-buffered alloc, not the per-iteration slice.
  for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) {
    auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner());
    if (!deallocOp)
      continue;
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(deallocOp);
    auto newDeallocOp =
        rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc);
    (void)newDeallocOp;
    LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n");
    rewriter.eraseOp(deallocOp);
  }

  // 6. RAUW with the particular slice, taking modular rotation into account.
  replaceUsesAndPropagateType(rewriter, allocOp, subview);

  // 7. Finally, erase the old allocOp.
  rewriter.eraseOp(allocOp);

  return mbAlloc;
}

/// Convenience overload that runs the transformation with a fresh
/// `IRRewriter` built on the context of `allocOp`.
FailureOr<memref::AllocOp>
mlir::memref::multiBuffer(memref::AllocOp allocOp,
                          unsigned multiBufferingFactor,
                          bool skipOverrideAnalysis) {
  IRRewriter rewriter(allocOp->getContext());
  return multiBuffer(rewriter, allocOp, multiBufferingFactor,
                     skipOverrideAnalysis);
}