163 lines
7.3 KiB
C++
163 lines
7.3 KiB
C++
//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
|
|
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Operation.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::bufferization;
|
|
|
|
/// Returns true if the type of `v` is a `BaseMemRefType`. Used to decide which
/// forwarded branch operands need an accompanying ownership indicator.
static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); }
|
|
|
|
namespace {
|
|
/// While CondBranchOp also implement the BranchOpInterface, we add a
|
|
/// special-case implementation here because the BranchOpInterface does not
|
|
/// offer all of the functionallity we need to insert dealloc oeprations in an
|
|
/// efficient way. More precisely, there is no way to extract the branch
|
|
/// condition without casting to CondBranchOp specifically. It is still
|
|
/// possible to implement deallocation for cases where we don't know to which
|
|
/// successor the terminator branches before the actual branch happens by
|
|
/// inserting auxiliary blocks and putting the dealloc op there, however, this
|
|
/// can lead to less efficient code.
|
|
/// This function inserts two dealloc operations (one for each successor) and
|
|
/// adjusts the dealloc conditions according to the branch condition, then the
|
|
/// ownerships of the retained MemRefs are updated by combining the result
|
|
/// values of the two dealloc operations.
|
|
///
|
|
/// Example:
|
|
/// ```
|
|
/// ^bb1:
|
|
/// <more ops...>
|
|
/// cf.cond_br cond, ^bb2(<forward-to-bb2>), ^bb3(<forward-to-bb2>)
|
|
/// ```
|
|
/// becomes
|
|
/// ```
|
|
/// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1)
|
|
/// // let r0 = getMemrefsToRetain(bb1, bb2, <forward-to-bb2>)
|
|
/// // let r1 = getMemrefsToRetain(bb1, bb3, <forward-to-bb3>)
|
|
/// ^bb1:
|
|
/// <more ops...>
|
|
/// let thenCond = map(c, (c) -> arith.andi cond, c)
|
|
/// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c)
|
|
/// o0 = bufferization.dealloc m if thenCond retain r0
|
|
/// o1 = bufferization.dealloc m if elseCond retain r1
|
|
/// // replace ownership(r0) with o0 element-wise
|
|
/// // replace ownership(r1) with o1 element-wise
|
|
/// // let ownership0 := (r) -> o in o0 corresponding to r
|
|
/// // let ownership1 := (r) -> o in o1 corresponding to r
|
|
/// // let cmn := intersection(r0, r1)
|
|
/// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)):
|
|
/// forall r in r0: replace ownership0(r) with arith.select cond, a, b)
|
|
/// forall r in r1: replace ownership1(r) with arith.select cond, a, b)
|
|
/// cf.cond_br cond, ^bb2(<forward-to-bb2>, o0), ^bb3(<forward-to-bb3>, o1)
|
|
/// ```
|
|
struct CondBranchOpInterface
|
|
: public BufferDeallocationOpInterface::ExternalModel<CondBranchOpInterface,
|
|
cf::CondBranchOp> {
|
|
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
|
|
const DeallocationOptions &options) const {
|
|
OpBuilder builder(op);
|
|
auto condBr = cast<cf::CondBranchOp>(op);
|
|
|
|
// The list of memrefs to deallocate in this block is independent of which
|
|
// branch is taken.
|
|
SmallVector<Value> memrefs, conditions;
|
|
if (failed(state.getMemrefsAndConditionsToDeallocate(
|
|
builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions)))
|
|
return failure();
|
|
|
|
// Helper lambda to factor out common logic for inserting the dealloc
|
|
// operations for each successor.
|
|
auto insertDeallocForBranch =
|
|
[&](Block *target, MutableOperandRange destOperands,
|
|
const std::function<Value(Value)> &conditionModifier,
|
|
DenseMap<Value, Value> &mapping) -> DeallocOp {
|
|
SmallVector<Value> toRetain;
|
|
state.getMemrefsToRetain(condBr->getBlock(), target,
|
|
OperandRange(destOperands), toRetain);
|
|
SmallVector<Value> adaptedConditions(
|
|
llvm::map_range(conditions, conditionModifier));
|
|
auto deallocOp = builder.create<bufferization::DeallocOp>(
|
|
condBr.getLoc(), memrefs, adaptedConditions, toRetain);
|
|
state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock());
|
|
for (auto [retained, ownership] : llvm::zip(
|
|
deallocOp.getRetained(), deallocOp.getUpdatedConditions())) {
|
|
state.updateOwnership(retained, ownership, condBr->getBlock());
|
|
mapping[retained] = ownership;
|
|
}
|
|
SmallVector<Value> replacements, ownerships;
|
|
for (OpOperand &operand : destOperands) {
|
|
replacements.push_back(operand.get());
|
|
if (isMemref(operand.get())) {
|
|
assert(mapping.contains(operand.get()) &&
|
|
"Should be contained at this point");
|
|
ownerships.push_back(mapping[operand.get()]);
|
|
}
|
|
}
|
|
replacements.append(ownerships);
|
|
destOperands.assign(replacements);
|
|
return deallocOp;
|
|
};
|
|
|
|
// Call the helper lambda and make sure the dealloc conditions are properly
|
|
// modified to reflect the branch condition as well.
|
|
DenseMap<Value, Value> thenMapping, elseMapping;
|
|
DeallocOp thenTakenDeallocOp = insertDeallocForBranch(
|
|
condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(),
|
|
[&](Value cond) {
|
|
return builder.create<arith::AndIOp>(condBr.getLoc(), cond,
|
|
condBr.getCondition());
|
|
},
|
|
thenMapping);
|
|
DeallocOp elseTakenDeallocOp = insertDeallocForBranch(
|
|
condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(),
|
|
[&](Value cond) {
|
|
Value trueVal = builder.create<arith::ConstantOp>(
|
|
condBr.getLoc(), builder.getBoolAttr(true));
|
|
Value negation = builder.create<arith::XOrIOp>(
|
|
condBr.getLoc(), trueVal, condBr.getCondition());
|
|
return builder.create<arith::AndIOp>(condBr.getLoc(), cond, negation);
|
|
},
|
|
elseMapping);
|
|
|
|
// We specifically need to update the ownerships of values that are retained
|
|
// in both dealloc operations again to get a combined 'Unique' ownership
|
|
// instead of an 'Unknown' ownership.
|
|
SmallPtrSet<Value, 16> thenValues(thenTakenDeallocOp.getRetained().begin(),
|
|
thenTakenDeallocOp.getRetained().end());
|
|
SetVector<Value> commonValues;
|
|
for (Value val : elseTakenDeallocOp.getRetained()) {
|
|
if (thenValues.contains(val))
|
|
commonValues.insert(val);
|
|
}
|
|
|
|
for (Value retained : commonValues) {
|
|
state.resetOwnerships(retained, condBr->getBlock());
|
|
Value combinedOwnership = builder.create<arith::SelectOp>(
|
|
condBr.getLoc(), condBr.getCondition(), thenMapping[retained],
|
|
elseMapping[retained]);
|
|
state.updateOwnership(retained, combinedOwnership, condBr->getBlock());
|
|
}
|
|
|
|
return condBr.getOperation();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, ControlFlowDialect *dialect) {
|
|
CondBranchOp::attachInterface<CondBranchOpInterface>(*ctx);
|
|
});
|
|
}
|