397 lines
16 KiB
C++
397 lines
16 KiB
C++
|
//===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// This file contains the implementation of the core LICM algorithm.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
|
||
|
|
||
|
#include "mlir/IR/Operation.h"
|
||
|
#include "mlir/IR/PatternMatch.h"
|
||
|
#include "mlir/Interfaces/LoopLikeInterface.h"
|
||
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||
|
#include "mlir/Interfaces/SubsetOpInterface.h"
|
||
|
#include "llvm/Support/Debug.h"
|
||
|
#include <queue>
|
||
|
|
||
|
#define DEBUG_TYPE "licm"
|
||
|
|
||
|
using namespace mlir;
|
||
|
|
||
|
/// Checks whether the given op can be hoisted by checking that
|
||
|
/// - the op and none of its contained operations depend on values inside of the
|
||
|
/// loop (by means of calling definedOutside).
|
||
|
/// - the op has no side-effects.
|
||
|
static bool canBeHoisted(Operation *op,
|
||
|
function_ref<bool(OpOperand &)> condition) {
|
||
|
// Do not move terminators.
|
||
|
if (op->hasTrait<OpTrait::IsTerminator>())
|
||
|
return false;
|
||
|
|
||
|
// Walk the nested operations and check that all used values are either
|
||
|
// defined outside of the loop or in a nested region, but not at the level of
|
||
|
// the loop body.
|
||
|
auto walkFn = [&](Operation *child) {
|
||
|
for (OpOperand &operand : child->getOpOperands()) {
|
||
|
// Ignore values defined in a nested region.
|
||
|
if (op->isAncestor(operand.get().getParentRegion()->getParentOp()))
|
||
|
continue;
|
||
|
if (!condition(operand))
|
||
|
return WalkResult::interrupt();
|
||
|
}
|
||
|
return WalkResult::advance();
|
||
|
};
|
||
|
return !op->walk(walkFn).wasInterrupted();
|
||
|
}
|
||
|
|
||
|
static bool canBeHoisted(Operation *op,
|
||
|
function_ref<bool(Value)> definedOutside) {
|
||
|
return canBeHoisted(
|
||
|
op, [&](OpOperand &operand) { return definedOutside(operand.get()); });
|
||
|
}
|
||
|
|
||
|
size_t mlir::moveLoopInvariantCode(
|
||
|
ArrayRef<Region *> regions,
|
||
|
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
|
||
|
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
|
||
|
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
|
||
|
size_t numMoved = 0;
|
||
|
|
||
|
for (Region *region : regions) {
|
||
|
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
|
||
|
<< *region->getParentOp() << "\n");
|
||
|
|
||
|
std::queue<Operation *> worklist;
|
||
|
// Add top-level operations in the loop body to the worklist.
|
||
|
for (Operation &op : region->getOps())
|
||
|
worklist.push(&op);
|
||
|
|
||
|
auto definedOutside = [&](Value value) {
|
||
|
return isDefinedOutsideRegion(value, region);
|
||
|
};
|
||
|
|
||
|
while (!worklist.empty()) {
|
||
|
Operation *op = worklist.front();
|
||
|
worklist.pop();
|
||
|
// Skip ops that have already been moved. Check if the op can be hoisted.
|
||
|
if (op->getParentRegion() != region)
|
||
|
continue;
|
||
|
|
||
|
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
|
||
|
if (!shouldMoveOutOfRegion(op, region) ||
|
||
|
!canBeHoisted(op, definedOutside))
|
||
|
continue;
|
||
|
|
||
|
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
|
||
|
moveOutOfRegion(op, region);
|
||
|
++numMoved;
|
||
|
|
||
|
// Since the op has been moved, we need to check its users within the
|
||
|
// top-level of the loop body.
|
||
|
for (Operation *user : op->getUsers())
|
||
|
if (user->getParentRegion() == region)
|
||
|
worklist.push(user);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return numMoved;
|
||
|
}
|
||
|
|
||
|
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
|
||
|
return moveLoopInvariantCode(
|
||
|
loopLike.getLoopRegions(),
|
||
|
[&](Value value, Region *) {
|
||
|
return loopLike.isDefinedOutsideOfLoop(value);
|
||
|
},
|
||
|
[&](Operation *op, Region *) {
|
||
|
return isMemoryEffectFree(op) && isSpeculatable(op);
|
||
|
},
|
||
|
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
|
||
|
}
|
||
|
|
||
|
namespace {
|
||
|
/// Helper data structure that keeps track of equivalent/disjoint subset ops.
|
||
|
class MatchingSubsets {
|
||
|
public:
|
||
|
/// Insert a subset op.
|
||
|
void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
|
||
|
allSubsetOps.push_back(op);
|
||
|
if (!collectHoistableOps)
|
||
|
return;
|
||
|
if (auto extractionOp =
|
||
|
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
|
||
|
insertExtractionOp(extractionOp);
|
||
|
if (auto insertionOp =
|
||
|
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
|
||
|
insertInsertionOp(insertionOp);
|
||
|
}
|
||
|
|
||
|
/// Return a range of matching extraction-insertion subset ops. If there is no
|
||
|
/// matching extraction/insertion op, the respective value is empty. Ops are
|
||
|
/// skipped if there are other subset ops that are not guaranteed to operate
|
||
|
/// on disjoint subsets.
|
||
|
auto getHoistableSubsetOps() {
|
||
|
return llvm::make_filter_range(
|
||
|
llvm::zip(extractions, insertions), [&](auto pair) {
|
||
|
auto [extractionOp, insertionOp] = pair;
|
||
|
// Hoist only if the extracted and inserted values have the same type.
|
||
|
if (extractionOp && insertionOp &&
|
||
|
extractionOp->getResult(0).getType() !=
|
||
|
insertionOp.getSourceOperand().get().getType())
|
||
|
return false;
|
||
|
// Hoist only if there are no conflicting subset ops.
|
||
|
return allDisjoint(extractionOp, insertionOp);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
/// Populate subset ops starting from the given region iter_arg. Return
|
||
|
/// "failure" if non-subset ops are found along the path to the loop yielding
|
||
|
/// op or if there is no single path to the tied yielded operand. If
|
||
|
/// `collectHoistableOps` is set to "false", subset ops are gathered
|
||
|
/// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
|
||
|
LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
|
||
|
BlockArgument iterArg,
|
||
|
bool collectHoistableOps = true);
|
||
|
|
||
|
private:
|
||
|
/// Helper function for equivalence of tensor values. Since only insertion
|
||
|
/// subset ops (that are also destination style ops) are followed when
|
||
|
/// traversing the SSA use-def chain, all tensor values are equivalent.
|
||
|
static bool isEquivalent(Value v1, Value v2) { return true; }
|
||
|
|
||
|
/// Return "true" if the subsets of the given extraction and insertion ops
|
||
|
/// are operating disjoint from the subsets that all other known subset ops
|
||
|
/// are operating on.
|
||
|
bool allDisjoint(SubsetExtractionOpInterface extractionOp,
|
||
|
SubsetInsertionOpInterface insertionOp) const {
|
||
|
for (SubsetOpInterface other : allSubsetOps) {
|
||
|
if (other == extractionOp || other == insertionOp)
|
||
|
continue;
|
||
|
if (extractionOp &&
|
||
|
!other.operatesOnDisjointSubset(extractionOp, isEquivalent))
|
||
|
return false;
|
||
|
if (insertionOp &&
|
||
|
!other.operatesOnDisjointSubset(insertionOp, isEquivalent))
|
||
|
return false;
|
||
|
}
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
/// Insert a subset extraction op. If the subset is equivalent to an existing
|
||
|
/// subset insertion op, pair them up. (If there is already a paired up subset
|
||
|
/// extraction op, overwrite the subset extraction op.)
|
||
|
void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
|
||
|
for (auto it : llvm::enumerate(insertions)) {
|
||
|
if (!it.value())
|
||
|
continue;
|
||
|
auto other = cast<SubsetOpInterface>(it.value().getOperation());
|
||
|
if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
|
||
|
extractions[it.index()] = extractionOp;
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
// There is no known equivalent insertion op. Create a new entry.
|
||
|
extractions.push_back(extractionOp);
|
||
|
insertions.push_back({});
|
||
|
}
|
||
|
|
||
|
/// Insert a subset insertion op. If the subset is equivalent to an existing
|
||
|
/// subset extraction op, pair them up. (If there is already a paired up
|
||
|
/// subset insertion op, overwrite the subset insertion op.)
|
||
|
void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
|
||
|
for (auto it : llvm::enumerate(extractions)) {
|
||
|
if (!it.value())
|
||
|
continue;
|
||
|
auto other = cast<SubsetOpInterface>(it.value().getOperation());
|
||
|
if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
|
||
|
insertions[it.index()] = insertionOp;
|
||
|
return;
|
||
|
}
|
||
|
}
|
||
|
// There is no known equivalent extraction op. Create a new entry.
|
||
|
extractions.push_back({});
|
||
|
insertions.push_back(insertionOp);
|
||
|
}
|
||
|
|
||
|
SmallVector<SubsetExtractionOpInterface> extractions;
|
||
|
SmallVector<SubsetInsertionOpInterface> insertions;
|
||
|
SmallVector<SubsetOpInterface> allSubsetOps;
|
||
|
};
|
||
|
} // namespace
|
||
|
|
||
|
/// If the given value has a single use by an op that is a terminator, return
|
||
|
/// that use. Otherwise, return nullptr.
|
||
|
static OpOperand *getSingleTerminatorUse(Value value) {
|
||
|
if (!value.hasOneUse())
|
||
|
return nullptr;
|
||
|
OpOperand &use = *value.getUses().begin();
|
||
|
if (use.getOwner()->hasTrait<OpTrait::IsTerminator>())
|
||
|
return &use;
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
LogicalResult
|
||
|
MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
|
||
|
BlockArgument iterArg,
|
||
|
bool collectHoistableOps) {
|
||
|
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
|
||
|
Value value = iterArg;
|
||
|
|
||
|
// Traverse use-def chain. Subset ops can be hoisted only if all ops along the
|
||
|
// use-def chain starting from the region iter_arg are subset extraction or
|
||
|
// subset insertion ops. The chain must terminate at the corresponding yield
|
||
|
// operand (e.g., no swapping of iter_args).
|
||
|
OpOperand *yieldedOperand = nullptr;
|
||
|
// Iterate until the single use of the current SSA value is a terminator,
|
||
|
// which is expected to be the yielding operation of the loop.
|
||
|
while (!(yieldedOperand = getSingleTerminatorUse(value))) {
|
||
|
Value nextValue = {};
|
||
|
|
||
|
for (OpOperand &use : value.getUses()) {
|
||
|
if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
|
||
|
// Subset ops in nested loops are collected to check if there are only
|
||
|
// disjoint subset ops, but such subset ops are not subject to hoisting.
|
||
|
// To hoist subset ops from nested loops, the hoisting transformation
|
||
|
// should be run on the nested loop.
|
||
|
auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
|
||
|
if (!nestedIterArg)
|
||
|
return failure();
|
||
|
// Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
|
||
|
// use-def chain starting at `nestedIterArg` and terminating in the
|
||
|
// tied, yielding operand.
|
||
|
if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
|
||
|
/*collectHoistableOps=*/false)))
|
||
|
return failure();
|
||
|
nextValue = nestedLoop.getTiedLoopResult(&use);
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
|
||
|
if (!subsetOp)
|
||
|
return failure();
|
||
|
insert(subsetOp);
|
||
|
|
||
|
if (auto insertionOp =
|
||
|
dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
|
||
|
// The value must be used as a destination. (In case of a source, the
|
||
|
// entire tensor would be read, which would prevent any hoisting.)
|
||
|
if (&use != &insertionOp.getDestinationOperand())
|
||
|
return failure();
|
||
|
// There must be a single use-def chain from the region iter_arg to the
|
||
|
// terminator. I.e., only one insertion op. Branches are not supported.
|
||
|
if (nextValue)
|
||
|
return failure();
|
||
|
nextValue = insertionOp.getUpdatedDestination();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Nothing can be hoisted if the chain does not continue with loop yielding
|
||
|
// op or a subset insertion op.
|
||
|
if (!nextValue)
|
||
|
return failure();
|
||
|
value = nextValue;
|
||
|
}
|
||
|
|
||
|
// Hoist only if the SSA use-def chain ends in the yielding terminator of the
|
||
|
// loop and the yielded value is the `idx`-th operand. (I.e., there is no
|
||
|
// swapping yield.)
|
||
|
if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
|
||
|
return failure();
|
||
|
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
|
||
|
/// loop-like op and index into loop-invariant subset locations. Return the
|
||
|
/// newly created loop op (that has extra iter_args) or the original loop op if
|
||
|
/// nothing was hoisted.
|
||
|
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
|
||
|
LoopLikeOpInterface loopLike,
|
||
|
BlockArgument iterArg) {
|
||
|
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
|
||
|
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
|
||
|
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
|
||
|
MatchingSubsets subsets;
|
||
|
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
|
||
|
return loopLike;
|
||
|
|
||
|
// Hoist all matching extraction-insertion pairs one-by-one.
|
||
|
for (auto it : subsets.getHoistableSubsetOps()) {
|
||
|
auto extractionOp = std::get<0>(it);
|
||
|
auto insertionOp = std::get<1>(it);
|
||
|
|
||
|
// Ops cannot be hoisted if they depend on loop-variant values.
|
||
|
if (extractionOp) {
|
||
|
if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
|
||
|
return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
|
||
|
&operand == &extractionOp.getSourceOperand();
|
||
|
}))
|
||
|
extractionOp = {};
|
||
|
}
|
||
|
if (insertionOp) {
|
||
|
if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
|
||
|
return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
|
||
|
&operand == &insertionOp.getSourceOperand() ||
|
||
|
&operand == &insertionOp.getDestinationOperand();
|
||
|
}))
|
||
|
insertionOp = {};
|
||
|
}
|
||
|
|
||
|
// Only hoist extraction-insertion pairs for now. Standalone extractions/
|
||
|
// insertions that are loop-invariant could be hoisted, but there may be
|
||
|
// easier ways to canonicalize the IR.
|
||
|
if (extractionOp && insertionOp) {
|
||
|
// Create a new loop with an additional iter_arg.
|
||
|
NewYieldValuesFn newYieldValuesFn =
|
||
|
[&](OpBuilder &b, Location loc,
|
||
|
ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
|
||
|
return {insertionOp.getSourceOperand().get()};
|
||
|
};
|
||
|
FailureOr<LoopLikeOpInterface> newLoop =
|
||
|
loopLike.replaceWithAdditionalYields(
|
||
|
rewriter, extractionOp.getResult(),
|
||
|
/*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
|
||
|
if (failed(newLoop))
|
||
|
return loopLike;
|
||
|
loopLike = *newLoop;
|
||
|
|
||
|
// Hoist the extraction/insertion ops.
|
||
|
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
|
||
|
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
|
||
|
OpResult newLoopResult = loopLike.getLoopResults()->back();
|
||
|
extractionOp->moveBefore(loopLike);
|
||
|
insertionOp->moveAfter(loopLike);
|
||
|
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
|
||
|
insertionOp.getDestinationOperand().get());
|
||
|
extractionOp.getSourceOperand().set(
|
||
|
loopLike.getTiedLoopInit(iterArg)->get());
|
||
|
rewriter.replaceAllUsesWith(loopResult,
|
||
|
insertionOp.getUpdatedDestination());
|
||
|
insertionOp.getSourceOperand().set(newLoopResult);
|
||
|
insertionOp.getDestinationOperand().set(loopResult);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return loopLike;
|
||
|
}
|
||
|
|
||
|
LoopLikeOpInterface
|
||
|
mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
|
||
|
LoopLikeOpInterface loopLike) {
|
||
|
// Note: As subset ops are getting hoisted, the number of region iter_args
|
||
|
// increases. This can enable further hoisting opportunities on the new
|
||
|
// iter_args.
|
||
|
for (int64_t i = 0;
|
||
|
i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) {
|
||
|
loopLike = hoistSubsetAtIterArg(rewriter, loopLike,
|
||
|
loopLike.getRegionIterArgs()[i]);
|
||
|
}
|
||
|
return loopLike;
|
||
|
}
|