bolt/deps/llvm-18.1.8/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp
2025-02-14 19:21:04 +01:00

175 lines
7.3 KiB
C++

//===- ExpandRealloc.cpp - Expand memref.realloc ops into its components --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_EXPANDREALLOC
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
namespace {
/// The `realloc` operation performs a conditional allocation and copy to
/// increase the size of a buffer if necessary. This pattern converts the
/// `realloc` operation into this sequence of simpler operations.
///
/// Example of an expansion:
/// ```mlir
/// %realloc = memref.realloc %alloc (%size) : memref<?xf32> to memref<?xf32>
/// ```
/// is expanded to
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %dim = memref.dim %alloc, %c0 : memref<?xf32>
/// %is_old_smaller = arith.cmpi ult, %dim, %size
/// %realloc = scf.if %is_old_smaller -> (memref<?xf32>) {
///   %new_alloc = memref.alloc(%size) : memref<?xf32>
///   %subview = memref.subview %new_alloc[0] [%dim] [1]
///   memref.copy %alloc, %subview
///   memref.dealloc %alloc
///   scf.yield %new_alloc : memref<?xf32>
/// } else {
///   %reinterpret_cast = memref.reinterpret_cast %alloc to
///     offset: [0], sizes: [%size], strides: [1]
///   scf.yield %reinterpret_cast : memref<?xf32>
/// }
/// ```
///
/// The `memref.dealloc` of the old buffer is only emitted when the pattern
/// was constructed with `emitDeallocs == true`.
struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> {
  /// \param ctx          Context the pattern is registered in.
  /// \param emitDeallocs Whether the expansion deallocates the old buffer
  ///                     after its contents were copied to the new one.
  ExpandReallocOpPattern(MLIRContext *ctx, bool emitDeallocs)
      : OpRewritePattern(ctx), emitDeallocs(emitDeallocs) {}

  /// Replaces `op` with an `scf.if` that either allocates a bigger buffer and
  /// copies the old contents over, or reinterprets the existing buffer with
  /// the requested size. Always returns success.
  LogicalResult matchAndRewrite(memref::ReallocOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    // Both accessors are already statically typed as `MemRefType`, so no
    // additional cast is needed to query shape or layout.
    MemRefType sourceType = op.getSource().getType();
    MemRefType resultType = op.getType();

    // Invariants this pattern relies on; they are only checked in
    // assert-enabled builds.
    assert(resultType.getRank() == 1 &&
           "result MemRef must have exactly one rank");
    assert(sourceType.getRank() == 1 &&
           "source MemRef must have exactly one rank");
    assert(resultType.getLayout().isIdentity() &&
           "result MemRef must have identity layout (or none)");
    assert(sourceType.getLayout().isIdentity() &&
           "source MemRef must have identity layout (or none)");

    // Get the size of the original buffer: a `memref.dim` when it is only
    // known at runtime, otherwise the static extent as an index attribute.
    int64_t inputSize = sourceType.getDimSize(0);
    OpFoldResult currSize;
    if (ShapedType::isDynamic(inputSize)) {
      Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc,
                                                      rewriter.getIndexAttr(0));
      currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero)
                     .getResult();
    } else {
      currSize = rewriter.getIndexAttr(inputSize);
    }

    // Get the requested size that the new buffer should have.
    int64_t outputSize = resultType.getDimSize(0);
    OpFoldResult targetSize = ShapedType::isDynamic(outputSize)
                                  ? OpFoldResult{op.getDynamicResultSize()}
                                  : rewriter.getIndexAttr(outputSize);

    // Only allocate a new buffer and copy over the values in the old buffer if
    // the old buffer is smaller than the requested size.
    Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize);
    Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize);
    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                                lhs, rhs);

    // "Then" region: the old buffer is too small.
    auto buildGrowBranch = [&](OpBuilder &builder, Location loc) {
      // Allocate the new buffer. If it is a dynamic memref we need to pass
      // an additional operand for the size at runtime, otherwise the static
      // size is encoded in the result type.
      SmallVector<Value> dynamicSizeOperands;
      if (op.getDynamicResultSize())
        dynamicSizeOperands.push_back(op.getDynamicResultSize());
      Value newAlloc = builder.create<memref::AllocOp>(
          loc, resultType, dynamicSizeOperands, op.getAlignmentAttr());

      // Take a subview of the new (bigger) buffer such that we can copy the
      // old values over (the copy operation requires both operands to have
      // the same shape).
      Value subview = builder.create<memref::SubViewOp>(
          loc, newAlloc, ArrayRef<OpFoldResult>{builder.getIndexAttr(0)},
          ArrayRef<OpFoldResult>{currSize},
          ArrayRef<OpFoldResult>{builder.getIndexAttr(1)});
      builder.create<memref::CopyOp>(loc, op.getSource(), subview);

      // Insert the deallocation of the old buffer only if requested.
      if (emitDeallocs)
        builder.create<memref::DeallocOp>(loc, op.getSource());

      builder.create<scf::YieldOp>(loc, newAlloc);
    };

    // "Else" region: the old buffer is already big enough.
    auto buildReuseBranch = [&](OpBuilder &builder, Location loc) {
      // We need to reinterpret-cast here because either the input or output
      // type might be static, which means we need to cast from static to
      // dynamic or vice-versa. If both are static and the original buffer
      // is already bigger than the requested size, the cast represents a
      // subview operation.
      Value casted = builder.create<memref::ReinterpretCastOp>(
          loc, resultType, op.getSource(), builder.getIndexAttr(0),
          ArrayRef<OpFoldResult>{targetSize},
          ArrayRef<OpFoldResult>{builder.getIndexAttr(1)});
      builder.create<scf::YieldOp>(loc, casted);
    };

    auto ifOp =
        rewriter.create<scf::IfOp>(loc, cond, buildGrowBranch, buildReuseBranch);

    rewriter.replaceOp(op, ifOp.getResult(0));
    return success();
  }

private:
  /// Whether the old buffer is deallocated after being copied.
  const bool emitDeallocs;
};
/// Pass that makes `memref.realloc` illegal and rewrites it with the patterns
/// registered by `populateExpandReallocPatterns`, using a partial dialect
/// conversion over the operation the pass runs on.
struct ExpandReallocPass
    : public memref::impl::ExpandReallocBase<ExpandReallocPass> {
  /// Stores `emitDeallocs` in the pass option of the same name.
  ExpandReallocPass(bool emitDeallocs) {
    this->emitDeallocs.setValue(emitDeallocs);
  }

  /// Runs the conversion and signals pass failure if it does not succeed.
  void runOnOperation() override {
    MLIRContext &context = getContext();

    // Collect the rewrite patterns, honoring the current option value.
    RewritePatternSet reallocPatterns(&context);
    memref::populateExpandReallocPatterns(reallocPatterns,
                                          emitDeallocs.getValue());

    // Everything in arith, scf and memref is legal, except `memref.realloc`
    // itself, which is what the patterns have to eliminate.
    ConversionTarget conversionTarget(context);
    conversionTarget.addLegalDialect<arith::ArithDialect, scf::SCFDialect,
                                     memref::MemRefDialect>();
    conversionTarget.addIllegalOp<memref::ReallocOp>();

    LogicalResult status = applyPartialConversion(
        getOperation(), conversionTarget, std::move(reallocPatterns));
    if (succeeded(status))
      return;
    signalPassFailure();
  }
};
} // namespace
/// Adds the `memref.realloc` expansion pattern to `patterns`. `emitDeallocs`
/// is forwarded to the pattern and controls whether it emits a
/// `memref.dealloc` for the old buffer.
void mlir::memref::populateExpandReallocPatterns(RewritePatternSet &patterns,
                                                 bool emitDeallocs) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ExpandReallocOpPattern>(context, emitDeallocs);
}
/// Creates a new instance of the realloc expansion pass, configured with the
/// given `emitDeallocs` setting.
std::unique_ptr<Pass> mlir::memref::createExpandReallocPass(bool emitDeallocs) {
  std::unique_ptr<Pass> pass =
      std::make_unique<ExpandReallocPass>(emitDeallocs);
  return pass;
}