//===- BufferDeallocationSimplification.cpp -------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for optimizing `bufferization.dealloc` operations // that requires more analysis than what can be supported by regular // canonicalization patterns. // //===----------------------------------------------------------------------===// #include "mlir/Analysis/AliasAnalysis.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { namespace bufferization { #define GEN_PASS_DEF_BUFFERDEALLOCATIONSIMPLIFICATION #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc" } // namespace bufferization } // namespace mlir using namespace mlir; using namespace mlir::bufferization; //===----------------------------------------------------------------------===// // Helpers //===----------------------------------------------------------------------===// static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter) { if (deallocOp.getMemrefs() == memrefs && deallocOp.getConditions() == conditions) return failure(); rewriter.modifyOpInPlace(deallocOp, [&]() { deallocOp.getMemrefsMutable().assign(memrefs); deallocOp.getConditionsMutable().assign(conditions); }); return success(); } /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. static Value getViewBase(Value value) { while (auto viewLikeOp = value.getDefiningOp()) value = viewLikeOp.getViewSource(); return value; } /// Return "true" if the given values are guaranteed to be different (and /// non-aliasing) allocations based on the fact that one value is the result /// of an allocation and the other value is a block argument of a parent block. /// Note: This is a best-effort analysis that will eventually be replaced by a /// proper "is same allocation" analysis. This function may return "false" even /// though the two values are distinct allocations. static bool distinctAllocAndBlockArgument(Value v1, Value v2) { Value v1Base = getViewBase(v1); Value v2Base = getViewBase(v2); auto areDistinct = [](Value v1, Value v2) { if (Operation *op = v1.getDefiningOp()) if (hasEffect(op, v1)) if (auto bbArg = dyn_cast(v2)) if (bbArg.getOwner()->findAncestorOpInBlock(*op)) return true; return false; }; return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base); } /// Checks if `memref` may potentially alias a MemRef in `otherList`. It is /// often a requirement of optimization patterns that there cannot be any /// aliasing memref in order to perform the desired simplification. static bool potentiallyAliasesMemref(AliasAnalysis &analysis, ValueRange otherList, Value memref) { for (auto other : otherList) { if (distinctAllocAndBlockArgument(other, memref)) continue; if (!analysis.alias(other, memref).isNo()) return true; } return false; } //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// namespace { /// Remove values from the `memref` operand list that are also present in the /// `retained` list (or a guaranteed alias of it) because they will never /// actually be deallocated. However, we also need to be certain about which /// other memrefs in the `retained` list can alias, i.e., there must not by any /// may-aliasing memref. This is necessary because the `dealloc` operation is /// defined to return one `i1` value per memref in the `retained` list which /// represents the disjunction of the condition values corresponding to all /// aliasing values in the `memref` list. In particular, this means that if /// there is some value R in the `retained` list which aliases with a value M in /// the `memref` list (but can only be staticaly determined to may-alias) and M /// is also present in the `retained` list, then it would be illegal to remove M /// because the result corresponding to R would be computed incorrectly /// afterwards. Because we require an alias analysis, this pattern cannot be /// applied as a regular canonicalization pattern. /// /// Example: /// ```mlir /// %0:3 = bufferization.dealloc (%m0 : ...) if (%cond0) /// retain (%m0, %r0, %r1 : ...) /// ``` /// is canonicalized to /// ```mlir /// // bufferization.dealloc without memrefs and conditions returns %false for /// // every retained value /// %0:3 = bufferization.dealloc retain (%m0, %r0, %r1 : ...) /// %1 = arith.ori %0#0, %cond0 : i1 /// // replace %0#0 with %1 /// ``` /// given that `%r0` and `%r1` may not alias with `%m0`. struct RemoveDeallocMemrefsContainedInRetained : public OpRewritePattern { RemoveDeallocMemrefsContainedInRetained(MLIRContext *context, AliasAnalysis &aliasAnalysis) : OpRewritePattern(context), aliasAnalysis(aliasAnalysis) {} /// The passed 'memref' must not have a may-alias relation to any retained /// memref, and at least one must-alias relation. If there is no must-aliasing /// memref in the retain list, we cannot simply remove the memref as there /// could be situations in which it actually has to be deallocated. If it's /// no-alias, then just proceed, if it's must-alias we need to update the /// updated condition returned by the dealloc operation for that alias. LogicalResult handleOneMemref(DeallocOp deallocOp, Value memref, Value cond, PatternRewriter &rewriter) const { rewriter.setInsertionPointAfter(deallocOp); // Check that there is no may-aliasing memref and that at least one memref // in the retain list aliases (because otherwise it might have to be // deallocated in some situations and can thus not be dropped). bool atLeastOneMustAlias = false; for (Value retained : deallocOp.getRetained()) { AliasResult analysisResult = aliasAnalysis.alias(retained, memref); if (analysisResult.isMay()) return failure(); if (analysisResult.isMust() || analysisResult.isPartial()) atLeastOneMustAlias = true; } if (!atLeastOneMustAlias) return failure(); // Insert arith.ori operations to update the corresponding dealloc result // values to incorporate the condition of the must-aliasing memref such that // we can remove that operand later on. for (auto [i, retained] : llvm::enumerate(deallocOp.getRetained())) { Value updatedCondition = deallocOp.getUpdatedConditions()[i]; AliasResult analysisResult = aliasAnalysis.alias(retained, memref); if (analysisResult.isMust() || analysisResult.isPartial()) { auto disjunction = rewriter.create( deallocOp.getLoc(), updatedCondition, cond); rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(), disjunction); } } return success(); } LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { // There must not be any duplicates in the retain list anymore because we // would miss updating one of the result values otherwise. DenseSet retained(deallocOp.getRetained().begin(), deallocOp.getRetained().end()); if (retained.size() != deallocOp.getRetained().size()) return failure(); SmallVector newMemrefs, newConditions; for (auto [memref, cond] : llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { if (succeeded(handleOneMemref(deallocOp, memref, cond, rewriter))) continue; if (auto extractOp = memref.getDefiningOp()) if (succeeded(handleOneMemref(deallocOp, extractOp.getOperand(), cond, rewriter))) continue; newMemrefs.push_back(memref); newConditions.push_back(cond); } // Return failure if we don't change anything such that we don't run into an // infinite loop of pattern applications. return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, rewriter); } private: AliasAnalysis &aliasAnalysis; }; /// Remove memrefs from the `retained` list which are guaranteed to not alias /// any memref in the `memrefs` list. The corresponding result value can be /// replaced with `false` in that case according to the operation description. /// /// Example: /// ```mlir /// %0:2 = bufferization.dealloc (%m : memref<2xi32>) if (%cond) /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) /// return %0#0, %0#1 /// ``` /// can be canonicalized to the following given that `%r0` and `%r1` do not /// alias `%m`: /// ```mlir /// bufferization.dealloc (%m : memref<2xi32>) if (%cond) /// return %false, %false /// ``` struct RemoveRetainedMemrefsGuaranteedToNotAlias : public OpRewritePattern { RemoveRetainedMemrefsGuaranteedToNotAlias(MLIRContext *context, AliasAnalysis &aliasAnalysis) : OpRewritePattern(context), aliasAnalysis(aliasAnalysis) {} LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { SmallVector newRetainedMemrefs, replacements; for (auto retainedMemref : deallocOp.getRetained()) { if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(), retainedMemref)) { newRetainedMemrefs.push_back(retainedMemref); replacements.push_back({}); continue; } replacements.push_back(rewriter.create( deallocOp.getLoc(), rewriter.getBoolAttr(false))); } if (newRetainedMemrefs.size() == deallocOp.getRetained().size()) return failure(); auto newDeallocOp = rewriter.create( deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(), newRetainedMemrefs); int i = 0; for (auto &repl : replacements) { if (!repl) repl = newDeallocOp.getUpdatedConditions()[i++]; } rewriter.replaceOp(deallocOp, replacements); return success(); } private: AliasAnalysis &aliasAnalysis; }; /// Split off memrefs to separate dealloc operations to reduce the number of /// runtime checks required and enable further canonicalization of the new and /// simpler dealloc operations. A memref can be split off if it is guaranteed to /// not alias with any other memref in the `memref` operand list. The results /// of the old and the new dealloc operation have to be combined by computing /// the element-wise disjunction of them. /// /// Example: /// ```mlir /// %0:2 = bufferization.dealloc (%m0, %m1 : memref<2xi32>, memref<2xi32>) /// if (%cond0, %cond1) /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) /// return %0#0, %0#1 /// ``` /// Given that `%m0` is guaranteed to never alias with `%m1`, the above IR is /// canonicalized to the following, thus reducing the number of runtime alias /// checks by 1 and potentially enabling further canonicalization of the new /// split-up dealloc operations. /// ```mlir /// %0:2 = bufferization.dealloc (%m0 : memref<2xi32>) if (%cond0) /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) /// %1:2 = bufferization.dealloc (%m1 : memref<2xi32>) if (%cond1) /// retain (%r0, %r1 : memref<2xi32>, memref<2xi32>) /// %2 = arith.ori %0#0, %1#0 /// %3 = arith.ori %0#1, %1#1 /// return %2, %3 /// ``` struct SplitDeallocWhenNotAliasingAnyOther : public OpRewritePattern { SplitDeallocWhenNotAliasingAnyOther(MLIRContext *context, AliasAnalysis &aliasAnalysis) : OpRewritePattern(context), aliasAnalysis(aliasAnalysis) {} LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { Location loc = deallocOp.getLoc(); if (deallocOp.getMemrefs().size() <= 1) return failure(); SmallVector remainingMemrefs, remainingConditions; SmallVector> updatedConditions; for (int64_t i = 0, e = deallocOp.getMemrefs().size(); i < e; ++i) { Value memref = deallocOp.getMemrefs()[i]; Value cond = deallocOp.getConditions()[i]; SmallVector otherMemrefs(deallocOp.getMemrefs()); otherMemrefs.erase(otherMemrefs.begin() + i); // Check if `memref` can split off into a separate bufferization.dealloc. if (potentiallyAliasesMemref(aliasAnalysis, otherMemrefs, memref)) { // `memref` alias with other memrefs, do not split off. remainingMemrefs.push_back(memref); remainingConditions.push_back(cond); continue; } // Create new bufferization.dealloc op for `memref`. auto newDeallocOp = rewriter.create(loc, memref, cond, deallocOp.getRetained()); updatedConditions.push_back( llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()))); } // Fail if no memref was split off. if (remainingMemrefs.size() == deallocOp.getMemrefs().size()) return failure(); // Create bufferization.dealloc op for all remaining memrefs. auto newDeallocOp = rewriter.create( loc, remainingMemrefs, remainingConditions, deallocOp.getRetained()); // Bit-or all conditions. SmallVector replacements = llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())); for (auto additionalConditions : updatedConditions) { assert(replacements.size() == additionalConditions.size() && "expected same number of updated conditions"); for (int64_t i = 0, e = replacements.size(); i < e; ++i) { replacements[i] = rewriter.create( loc, replacements[i], additionalConditions[i]); } } rewriter.replaceOp(deallocOp, replacements); return success(); } private: AliasAnalysis &aliasAnalysis; }; /// Check for every retained memref if a must-aliasing memref exists in the /// 'memref' operand list with constant 'true' condition. If so, we can replace /// the operation result corresponding to that retained memref with 'true'. If /// this condition holds for all retained memrefs we can also remove the /// aliasing memrefs and their conditions since they will never be deallocated /// due to the must-alias and we don't need them to compute the result value /// anymore since it got replaced with 'true'. /// /// Example: /// ```mlir /// %0:2 = bufferization.dealloc (%arg0, %arg1, %arg2 : ...) /// if (%true, %true, %true) /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) /// ``` /// becomes /// ```mlir /// %0:2 = bufferization.dealloc (%arg2 : memref<2xi32>) if (%true) /// retain (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) /// // replace %0#0 with %true /// // replace %0#1 with %true /// ``` /// Note that the dealloc operation will still have the result values, but they /// don't have uses anymore. struct RetainedMemrefAliasingAlwaysDeallocatedMemref : public OpRewritePattern { RetainedMemrefAliasingAlwaysDeallocatedMemref(MLIRContext *context, AliasAnalysis &aliasAnalysis) : OpRewritePattern(context), aliasAnalysis(aliasAnalysis) {} LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { BitVector aliasesWithConstTrueMemref(deallocOp.getRetained().size()); SmallVector newMemrefs, newConditions; for (auto [memref, cond] : llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { bool canDropMemref = false; for (auto [i, retained, res] : llvm::enumerate( deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { if (!matchPattern(cond, m_One())) continue; AliasResult analysisResult = aliasAnalysis.alias(retained, memref); if (analysisResult.isMust() || analysisResult.isPartial()) { rewriter.replaceAllUsesWith(res, cond); aliasesWithConstTrueMemref[i] = true; canDropMemref = true; continue; } // TODO: once our alias analysis is powerful enough we can remove the // rest of this loop body auto extractOp = memref.getDefiningOp(); if (!extractOp) continue; AliasResult extractAnalysisResult = aliasAnalysis.alias(retained, extractOp.getOperand()); if (extractAnalysisResult.isMust() || extractAnalysisResult.isPartial()) { rewriter.replaceAllUsesWith(res, cond); aliasesWithConstTrueMemref[i] = true; canDropMemref = true; } } if (!canDropMemref) { newMemrefs.push_back(memref); newConditions.push_back(cond); } } if (!aliasesWithConstTrueMemref.all()) return failure(); return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, rewriter); } private: AliasAnalysis &aliasAnalysis; }; } // namespace //===----------------------------------------------------------------------===// // BufferDeallocationSimplificationPass //===----------------------------------------------------------------------===// namespace { /// The actual buffer deallocation pass that inserts and moves dealloc nodes /// into the right positions. Furthermore, it inserts additional clones if /// necessary. It uses the algorithm described at the top of the file. struct BufferDeallocationSimplificationPass : public bufferization::impl::BufferDeallocationSimplificationBase< BufferDeallocationSimplificationPass> { void runOnOperation() override { AliasAnalysis &aliasAnalysis = getAnalysis(); RewritePatternSet patterns(&getContext()); patterns.add(&getContext(), aliasAnalysis); populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr mlir::bufferization::createBufferDeallocationSimplificationPass() { return std::make_unique(); }