//===- ExpandRealloc.cpp - Expand memref.realloc ops into it's components -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { namespace memref { #define GEN_PASS_DEF_EXPANDREALLOC #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace memref } // namespace mlir using namespace mlir; namespace { /// The `realloc` operation performs a conditional allocation and copy to /// increase the size of a buffer if necessary. This pattern converts the /// `realloc` operation into this sequence of simpler operations. /// Example of an expansion: /// ```mlir /// %realloc = memref.realloc %alloc (%size) : memref to memref /// ``` /// is expanded to /// ```mlir /// %c0 = arith.constant 0 : index /// %dim = memref.dim %alloc, %c0 : memref /// %is_old_smaller = arith.cmpi ult, %dim, %arg1 /// %realloc = scf.if %is_old_smaller -> (memref) { /// %new_alloc = memref.alloc(%size) : memref /// %subview = memref.subview %new_alloc[0] [%dim] [1] /// memref.copy %alloc, %subview /// memref.dealloc %alloc /// scf.yield %alloc_0 : memref /// } else { /// %reinterpret_cast = memref.reinterpret_cast %alloc to /// offset: [0], sizes: [%size], strides: [1] /// scf.yield %reinterpret_cast : memref /// } /// ``` struct ExpandReallocOpPattern : public OpRewritePattern { ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs) : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {} LogicalResult matchAndRewrite(memref::ReallocOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); assert(op.getType().getRank() == 1 && "result MemRef must have exactly one rank"); assert(op.getSource().getType().getRank() == 1 && "source MemRef must have exactly one rank"); assert(op.getType().getLayout().isIdentity() && "result MemRef must have identity layout (or none)"); assert(op.getSource().getType().getLayout().isIdentity() && "source MemRef must have identity layout (or none)"); // Get the size of the original buffer. int64_t inputSize = op.getSource().getType().cast().getDimSize(0); OpFoldResult currSize = rewriter.getIndexAttr(inputSize); if (ShapedType::isDynamic(inputSize)) { Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc, rewriter.getIndexAttr(0)); currSize = rewriter.create(loc, op.getSource(), dimZero) .getResult(); } // Get the requested size that the new buffer should have. int64_t outputSize = op.getResult().getType().cast().getDimSize(0); OpFoldResult targetSize = ShapedType::isDynamic(outputSize) ? OpFoldResult{op.getDynamicResultSize()} : rewriter.getIndexAttr(outputSize); // Only allocate a new buffer and copy over the values in the old buffer if // the old buffer is smaller than the requested size. Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize); Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize); Value cond = rewriter.create(loc, arith::CmpIPredicate::ult, lhs, rhs); auto ifOp = rewriter.create( loc, cond, [&](OpBuilder &builder, Location loc) { // Allocate the new buffer. If it is a dynamic memref we need to pass // an additional operand for the size at runtime, otherwise the static // size is encoded in the result type. SmallVector dynamicSizeOperands; if (op.getDynamicResultSize()) dynamicSizeOperands.push_back(op.getDynamicResultSize()); Value newAlloc = builder.create( loc, op.getResult().getType(), dynamicSizeOperands, op.getAlignmentAttr()); // Take a subview of the new (bigger) buffer such that we can copy the // old values over (the copy operation requires both operands to have // the same shape). Value subview = builder.create( loc, newAlloc, ArrayRef{rewriter.getIndexAttr(0)}, ArrayRef{currSize}, ArrayRef{rewriter.getIndexAttr(1)}); builder.create(loc, op.getSource(), subview); // Insert the deallocation of the old buffer only if requested // (enabled by default). if (emitDeallocs) builder.create(loc, op.getSource()); builder.create(loc, newAlloc); }, [&](OpBuilder &builder, Location loc) { // We need to reinterpret-cast here because either the input or output // type might be static, which means we need to cast from static to // dynamic or vice-versa. If both are static and the original buffer // is already bigger than the requested size, the cast represents a // subview operation. Value casted = builder.create( loc, op.getResult().getType().cast(), op.getSource(), rewriter.getIndexAttr(0), ArrayRef{targetSize}, ArrayRef{rewriter.getIndexAttr(1)}); builder.create(loc, casted); }); rewriter.replaceOp(op, ifOp.getResult(0)); return success(); } private: const bool emitDeallocs; }; struct ExpandReallocPass : public memref::impl::ExpandReallocBase { ExpandReallocPass(bool emitDeallocs) : memref::impl::ExpandReallocBase() { this->emitDeallocs.setValue(emitDeallocs); } void runOnOperation() override { MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue()); ConversionTarget target(ctx); target.addLegalDialect(); target.addIllegalOp(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns, bool emitDeallocs) { patterns.add(patterns.getContext(), emitDeallocs); } std::unique_ptr mlir::memref::createExpandReallocPass(bool emitDeallocs) { return std::make_unique(emitDeallocs); }