175 lines
7.3 KiB
C++
175 lines
7.3 KiB
C++
//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
|
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
namespace mlir {
|
|
namespace memref {
|
|
#define GEN_PASS_DEF_EXPANDREALLOC
|
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
|
|
} // namespace memref
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// The `realloc` operation performs a conditional allocation and copy to
|
|
/// increase the size of a buffer if necessary. This pattern converts the
|
|
/// `realloc` operation into this sequence of simpler operations.
|
|
|
|
/// Example of an expansion:
|
|
/// ```mlir
|
|
/// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
|
|
/// ```
|
|
/// is expanded to
|
|
/// ```mlir
|
|
/// %c0 = arith.constant 0 : index
|
|
/// %dim = memref.dim %alloc, %c0 : memref<?xf32>
|
|
/// %is_old_smaller = arith.cmpi ult, %dim, %arg1
|
|
/// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
|
|
/// %new_alloc = memref.alloc(%size) : memref<?xf32>
|
|
/// %subview = memref.subview %new_alloc[0] [%dim] [1]
|
|
/// memref.copy %alloc, %subview
|
|
/// memref.dealloc %alloc
|
|
/// scf.yield %alloc_0 : memref<?xf32>
|
|
/// } else {
|
|
/// %reinterpret_cast = memref.reinterpret_cast %alloc to
|
|
/// offset: [0], sizes: [%size], strides: [1]
|
|
/// scf.yield %reinterpret_cast : memref<?xf32>
|
|
/// }
|
|
/// ```
|
|
struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
|
|
ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
|
|
: OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}
|
|
|
|
LogicalResult matchAndRewrite(memref::ReallocOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
Location loc = op.getLoc();
|
|
assert(op.getType().getRank() == 1 &&
|
|
"result MemRef must have exactly one rank");
|
|
assert(op.getSource().getType().getRank() == 1 &&
|
|
"source MemRef must have exactly one rank");
|
|
assert(op.getType().getLayout().isIdentity() &&
|
|
"result MemRef must have identity layout (or none)");
|
|
assert(op.getSource().getType().getLayout().isIdentity() &&
|
|
"source MemRef must have identity layout (or none)");
|
|
|
|
// Get the size of the original buffer.
|
|
int64_t inputSize =
|
|
op.getSource().getType().cast<BaseMemRefType>().getDimSize(0);
|
|
OpFoldResult currSize = rewriter.getIndexAttr(inputSize);
|
|
if (ShapedType::isDynamic(inputSize)) {
|
|
Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
|
|
rewriter.getIndexAttr(0));
|
|
currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
|
|
.getResult();
|
|
}
|
|
|
|
// Get the requested size that the new buffer should have.
|
|
int64_t outputSize =
|
|
op.getResult().getType().cast<BaseMemRefType>().getDimSize(0);
|
|
OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
|
|
? OpFoldResult{op.getDynamicResultSize()}
|
|
: rewriter.getIndexAttr(outputSize);
|
|
|
|
// Only allocate a new buffer and copy over the values in the old buffer if
|
|
// the old buffer is smaller than the requested size.
|
|
Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
|
|
Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
|
|
Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
|
lhs, rhs);
|
|
auto ifOp = rewriter.create<scf::IfOp>(
|
|
loc, cond,
|
|
[&](OpBuilder &builder, Location loc) {
|
|
// Allocate the new buffer. If it is a dynamic memref we need to pass
|
|
// an additional operand for the size at runtime, otherwise the static
|
|
// size is encoded in the result type.
|
|
SmallVector<Value> dynamicSizeOperands;
|
|
if (op.getDynamicResultSize())
|
|
dynamicSizeOperands.push_back(op.getDynamicResultSize());
|
|
|
|
Value newAlloc = builder.create<memref::AllocOp>(
|
|
loc, op.getResult().getType(), dynamicSizeOperands,
|
|
op.getAlignmentAttr());
|
|
|
|
// Take a subview of the new (bigger) buffer such that we can copy the
|
|
// old values over (the copy operation requires both operands to have
|
|
// the same shape).
|
|
Value subview = builder.create<memref::SubViewOp>(
|
|
loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
|
|
ArrayRef<OpFoldResult>{currSize},
|
|
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
|
|
builder.create<memref::CopyOp>(loc, op.getSource(), subview);
|
|
|
|
// Insert the deallocation of the old buffer only if requested
|
|
// (enabled by default).
|
|
if (emitDeallocs)
|
|
builder.create<memref::DeallocOp>(loc, op.getSource());
|
|
|
|
builder.create<scf::YieldOp>(loc, newAlloc);
|
|
},
|
|
[&](OpBuilder &builder, Location loc) {
|
|
// We need to reinterpret-cast here because either the input or output
|
|
// type might be static, which means we need to cast from static to
|
|
// dynamic or vice-versa. If both are static and the original buffer
|
|
// is already bigger than the requested size, the cast represents a
|
|
// subview operation.
|
|
Value casted = builder.create<memref::ReinterpretCastOp>(
|
|
loc, op.getResult().getType().cast<MemRefType>(), op.getSource(),
|
|
rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize},
|
|
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)});
|
|
builder.create<scf::YieldOp>(loc, casted);
|
|
});
|
|
|
|
rewriter.replaceOp(op, ifOp.getResult(0));
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
const bool emitDeallocs;
|
|
};
|
|
|
|
struct ExpandReallocPass
|
|
: public memref::impl::ExpandReallocBase<ExpandReallocPass> {
|
|
ExpandReallocPass(bool emitDeallocs)
|
|
: memref::impl::ExpandReallocBase<ExpandReallocPass>() {
|
|
this->emitDeallocs.setValue(emitDeallocs);
|
|
}
|
|
void runOnOperation() override {
|
|
MLIRContext &ctx = getContext();
|
|
|
|
RewritePatternSet patterns(&ctx);
|
|
memref::populateExpandReallocPatterns(patterns, emitDeallocs.getValue());
|
|
ConversionTarget target(ctx);
|
|
|
|
target.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
|
|
memref::MemRefDialect>();
|
|
target.addIllegalOp<memref::ReallocOp>();
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Adds the `memref.realloc` expansion pattern to `patterns`. `emitDeallocs`
/// is forwarded to the pattern and decides whether it emits a
/// `memref.dealloc` for the old buffer.
void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
                                                 bool emitDeallocs) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ExpandReallocOpPattern>(context, emitDeallocs);
}
|
|
|
|
std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
|
|
return std::make_unique<ExpandReallocPass>(emitDeallocs);
|
|
}
|