bolt/deps/llvm-18.1.8/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
2025-02-14 19:21:04 +01:00

396 lines
16 KiB
C++

//===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the core LICM algorithm.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/Support/Debug.h"
#include <queue>
#define DEBUG_TYPE "licm"
using namespace mlir;
/// Decide whether `op` is a hoisting candidate with respect to `condition`.
///
/// Terminators are never candidates. Otherwise, `op` and every operation
/// nested inside of it are visited, and each operand is inspected:
/// - operands whose value is defined within `op` itself (i.e., in one of its
///   nested regions) are accepted unconditionally;
/// - every other operand must satisfy `condition`.
///
/// Returns "true" only if no operand was rejected. Note that this helper does
/// not look at side effects; callers have to check those separately.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(OpOperand &)> condition) {
  // Terminators have to stay where they are.
  if (op->hasTrait<OpTrait::IsTerminator>())
    return false;

  // An operand is acceptable if its value lives underneath `op` or if the
  // caller-provided predicate approves of it.
  auto isAcceptableOperand = [&](OpOperand &use) {
    Operation *definingScope = use.get().getParentRegion()->getParentOp();
    return op->isAncestor(definingScope) || condition(use);
  };

  // Abort the traversal at the first operation that has a rejected operand.
  WalkResult traversal = op->walk([&](Operation *nested) {
    return llvm::all_of(nested->getOpOperands(), isAcceptableOperand)
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  return !traversal.wasInterrupted();
}
/// Overload of `canBeHoisted` for predicates that only care about the value
/// held by an operand: each operand is unwrapped to its value before it is
/// handed to `definedOutside`.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  auto valueIsDefinedOutside = [&](OpOperand &use) {
    return definedOutside(use.get());
  };
  return canBeHoisted(op, valueIsDefinedOutside);
}
/// Generic LICM driver. For every region in `regions`, the ops directly
/// nested in the region are processed in FIFO order. An op is moved with
/// `moveOutOfRegion` if `shouldMoveOutOfRegion` approves of it and all of its
/// operands (including those of nested ops) are either defined within the op
/// or reported as defined outside by `isDefinedOutsideRegion`. Users of a moved
/// op that are still located directly in the region are revisited. Returns
/// the number of ops that were moved.
size_t mlir::moveLoopInvariantCode(
    ArrayRef<Region *> regions,
    function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
    function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
    function_ref<void(Operation *, Region *)> moveOutOfRegion) {
  size_t hoistedCount = 0;

  for (Region *loopRegion : regions) {
    LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
                            << *loopRegion->getParentOp() << "\n");

    // Bind the region so that the predicate fits the `canBeHoisted` helper.
    auto isInvariantValue = [&](Value value) {
      return isDefinedOutsideRegion(value, loopRegion);
    };

    // Seed the worklist with all ops at the top level of the region.
    std::queue<Operation *> pending;
    for (Operation &bodyOp : loopRegion->getOps())
      pending.push(&bodyOp);

    while (!pending.empty()) {
      Operation *candidate = pending.front();
      pending.pop();

      // An op that no longer lives directly in the region was already moved.
      if (candidate->getParentRegion() != loopRegion)
        continue;

      LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *candidate << "\n");
      const bool hoistable = shouldMoveOutOfRegion(candidate, loopRegion) &&
                             canBeHoisted(candidate, isInvariantValue);
      if (!hoistable)
        continue;

      LLVM_DEBUG(llvm::dbgs()
                 << "Moving loop-invariant op: " << *candidate << "\n");
      moveOutOfRegion(candidate, loopRegion);
      ++hoistedCount;

      // Moving the op may have turned its top-level users into candidates, so
      // they are queued (again).
      for (Operation *consumer : candidate->getUsers()) {
        if (consumer->getParentRegion() == loopRegion)
          pending.push(consumer);
      }
    }
  }

  return hoistedCount;
}
/// Run LICM on all loop regions of `loopLike`, using the loop-like interface
/// to query whether a value is defined outside of the loop and to move ops
/// out of the loop. Only ops that are both memory-effect free and speculatable
/// are considered. Returns the number of ops that were moved.
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
  auto isDefinedOutside = [&](Value value, Region *) {
    return loopLike.isDefinedOutsideOfLoop(value);
  };
  auto isSafeToMove = [&](Operation *op, Region *) {
    return isMemoryEffectFree(op) && isSpeculatable(op);
  };
  auto moveOut = [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); };

  return moveLoopInvariantCode(loopLike.getLoopRegions(), isDefinedOutside,
                               isSafeToMove, moveOut);
}
namespace {
/// Helper data structure that keeps track of equivalent/disjoint subset ops.
class MatchingSubsets {
public:
/// Insert a subset op.
void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
allSubsetOps.push_back(op);
if (!collectHoistableOps)
return;
if (auto extractionOp =
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
insertExtractionOp(extractionOp);
if (auto insertionOp =
dyn_cast<SubsetInsertionOpInterface>(op.getOperation()))
insertInsertionOp(insertionOp);
}
/// Return a range of matching extraction-insertion subset ops. If there is no
/// matching extraction/insertion op, the respective value is empty. Ops are
/// skipped if there are other subset ops that are not guaranteed to operate
/// on disjoint subsets.
auto getHoistableSubsetOps() {
return llvm::make_filter_range(
llvm::zip(extractions, insertions), [&](auto pair) {
auto [extractionOp, insertionOp] = pair;
// Hoist only if the extracted and inserted values have the same type.
if (extractionOp && insertionOp &&
extractionOp->getResult(0).getType() !=
insertionOp.getSourceOperand().get().getType())
return false;
// Hoist only if there are no conflicting subset ops.
return allDisjoint(extractionOp, insertionOp);
});
}
/// Populate subset ops starting from the given region iter_arg. Return
/// "failure" if non-subset ops are found along the path to the loop yielding
/// op or if there is no single path to the tied yielded operand. If
/// `collectHoistableOps` is set to "false", subset ops are gathered
/// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
BlockArgument iterArg,
bool collectHoistableOps = true);
private:
/// Helper function for equivalence of tensor values. Since only insertion
/// subset ops (that are also destination style ops) are followed when
/// traversing the SSA use-def chain, all tensor values are equivalent.
static bool isEquivalent(Value v1, Value v2) { return true; }
/// Return "true" if the subsets of the given extraction and insertion ops
/// are operating disjoint from the subsets that all other known subset ops
/// are operating on.
bool allDisjoint(SubsetExtractionOpInterface extractionOp,
SubsetInsertionOpInterface insertionOp) const {
for (SubsetOpInterface other : allSubsetOps) {
if (other == extractionOp || other == insertionOp)
continue;
if (extractionOp &&
!other.operatesOnDisjointSubset(extractionOp, isEquivalent))
return false;
if (insertionOp &&
!other.operatesOnDisjointSubset(insertionOp, isEquivalent))
return false;
}
return true;
}
/// Insert a subset extraction op. If the subset is equivalent to an existing
/// subset insertion op, pair them up. (If there is already a paired up subset
/// extraction op, overwrite the subset extraction op.)
void insertExtractionOp(SubsetExtractionOpInterface extractionOp) {
for (auto it : llvm::enumerate(insertions)) {
if (!it.value())
continue;
auto other = cast<SubsetOpInterface>(it.value().getOperation());
if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) {
extractions[it.index()] = extractionOp;
return;
}
}
// There is no known equivalent insertion op. Create a new entry.
extractions.push_back(extractionOp);
insertions.push_back({});
}
/// Insert a subset insertion op. If the subset is equivalent to an existing
/// subset extraction op, pair them up. (If there is already a paired up
/// subset insertion op, overwrite the subset insertion op.)
void insertInsertionOp(SubsetInsertionOpInterface insertionOp) {
for (auto it : llvm::enumerate(extractions)) {
if (!it.value())
continue;
auto other = cast<SubsetOpInterface>(it.value().getOperation());
if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) {
insertions[it.index()] = insertionOp;
return;
}
}
// There is no known equivalent extraction op. Create a new entry.
extractions.push_back({});
insertions.push_back(insertionOp);
}
SmallVector<SubsetExtractionOpInterface> extractions;
SmallVector<SubsetInsertionOpInterface> insertions;
SmallVector<SubsetOpInterface> allSubsetOps;
};
} // namespace
/// Return the only use of `value` if that use belongs to a terminator op.
/// Return nullptr if the value has zero or multiple uses, or if its single
/// user is not a terminator.
static OpOperand *getSingleTerminatorUse(Value value) {
  if (value.hasOneUse()) {
    OpOperand &onlyUse = *value.getUses().begin();
    Operation *user = onlyUse.getOwner();
    if (user->hasTrait<OpTrait::IsTerminator>())
      return &onlyUse;
  }
  return nullptr;
}
/// Collect the subset ops along the SSA use-def chain that starts at `iterArg`
/// and ends at the terminator of the loop body.
///
/// The chain is followed through subset insertion ops (via their destination
/// operand) and through nested loop-like ops (via the tied region iter_arg and
/// the tied loop result). Every user of a value on the chain must be a subset
/// op or a loop-like op.
///
/// \param loopLike            The loop that owns `iterArg`.
/// \param iterArg             The region iter_arg at which the chain starts.
/// \param collectHoistableOps If "false", the subset ops that are found are
///                            only recorded for disjointness checks and are
///                            not enumerated by `getHoistableSubsetOps`. This
///                            is used for subset ops inside of nested loops.
/// \returns "failure" if a non-subset op is found along the chain, if the
///          chain branches or does not continue, if a value is used as
///          something other than the destination of an insertion op, or if
///          the chain does not end in the yielded operand that is tied to
///          `iterArg`. Note that ops may already have been recorded in this
///          object when "failure" is returned.
LogicalResult
MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
                                            BlockArgument iterArg,
                                            bool collectHoistableOps) {
  assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
  Value value = iterArg;

  // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
  // use-def chain starting from the region iter_arg are subset extraction or
  // subset insertion ops. The chain must terminate at the corresponding yield
  // operand (e.g., no swapping of iter_args).
  OpOperand *yieldedOperand = nullptr;
  // Iterate until the single use of the current SSA value is a terminator,
  // which is expected to be the yielding operation of the loop.
  while (!(yieldedOperand = getSingleTerminatorUse(value))) {
    Value nextValue = {};

    for (OpOperand &use : value.getUses()) {
      if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
        // Subset ops in nested loops are collected to check if there are only
        // disjoint subset ops, but such subset ops are not subject to hoisting.
        // To hoist subset ops from nested loops, the hoisting transformation
        // should be run on the nested loop.
        auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
        if (!nestedIterArg)
          return failure();
        // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
        // use-def chain starting at `nestedIterArg` and terminating in the
        // tied, yielding operand.
        if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
                                              /*collectHoistableOps=*/false)))
          return failure();
        nextValue = nestedLoop.getTiedLoopResult(&use);
        continue;
      }

      auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
      if (!subsetOp)
        return failure();
      // Forward `collectHoistableOps`: when this function runs on a nested
      // loop, the ops must only be recorded for the disjointness analysis and
      // must not become hoisting candidates of the enclosing loop.
      insert(subsetOp, collectHoistableOps);

      if (auto insertionOp =
              dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
        // The value must be used as a destination. (In case of a source, the
        // entire tensor would be read, which would prevent any hoisting.)
        if (&use != &insertionOp.getDestinationOperand())
          return failure();
        // There must be a single use-def chain from the region iter_arg to the
        // terminator. I.e., only one insertion op. Branches are not supported.
        if (nextValue)
          return failure();
        nextValue = insertionOp.getUpdatedDestination();
      }
    }

    // Nothing can be hoisted if the chain does not continue with loop yielding
    // op or a subset insertion op.
    if (!nextValue)
      return failure();
    value = nextValue;
  }

  // Hoist only if the SSA use-def chain ends in the yielding terminator of the
  // loop and the yielded operand is the one tied to `iterArg`. (I.e., there is
  // no swapping yield.)
  if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
    return failure();

  return success();
}
/// Hoist matching extraction/insertion subset op pairs that operate on the
/// given region iter_arg of the given loop-like op and that index into
/// loop-invariant subset locations. The extraction op is moved in front of the
/// loop and the insertion op is moved behind the loop. For every hoisted pair,
/// the loop is replaced via `replaceWithAdditionalYields`.
///
/// Return the current loop op: the most recently created replacement loop, or
/// the original loop op if nothing was hoisted. If creating a replacement loop
/// fails, hoisting stops and the loop op obtained so far is returned.
static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
// Remember the position of `iterArg` among the region iter_args, so that the
// corresponding iter_arg can be looked up again after the loop was replaced.
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
// Gather the subset ops on the use-def chain of `iterArg`. If the chain has an
// unsupported structure, nothing is hoisted.
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;
// Hoist all matching extraction-insertion pairs one-by-one.
for (auto it : subsets.getHoistableSubsetOps()) {
// Local copies of the op handles: resetting them below does not modify the
// entries stored in `subsets`.
auto extractionOp = std::get<0>(it);
auto insertionOp = std::get<1>(it);
// Ops cannot be hoisted if they depend on loop-variant values.
// The source operand of the extraction op is exempt from this check; it is
// rewired to the loop init value further below.
if (extractionOp) {
if (!canBeHoisted(extractionOp, [&](OpOperand &operand) {
return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
&operand == &extractionOp.getSourceOperand();
}))
extractionOp = {};
}
// Likewise, the source and destination operands of the insertion op are
// exempt; both are rewired to loop results further below.
if (insertionOp) {
if (!canBeHoisted(insertionOp, [&](OpOperand &operand) {
return loopLike.isDefinedOutsideOfLoop(operand.get()) ||
&operand == &insertionOp.getSourceOperand() ||
&operand == &insertionOp.getDestinationOperand();
}))
insertionOp = {};
}
// Only hoist extraction-insertion pairs for now. Standalone extractions/
// insertions that are loop-invariant could be hoisted, but there may be
// easier ways to canonicalize the IR.
if (extractionOp && insertionOp) {
// Create a new loop with an additional iter_arg.
// The callback yields the value that the insertion op inserts. The
// extraction result is handed to `replaceWithAdditionalYields` (presumably
// as the init value of the new iter_arg; confirm against the
// LoopLikeOpInterface documentation).
NewYieldValuesFn newYieldValuesFn =
[&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
return {insertionOp.getSourceOperand().get()};
};
FailureOr<LoopLikeOpInterface> newLoop =
loopLike.replaceWithAdditionalYields(
rewriter, extractionOp.getResult(),
/*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn);
if (failed(newLoop))
return loopLike;
loopLike = *newLoop;
// Hoist the extraction/insertion ops.
// Re-fetch the iter_arg from the new loop by its remembered position.
iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
// The last loop result; presumably the one that belongs to the iter_arg that
// was just added.
OpResult newLoopResult = loopLike.getLoopResults()->back();
extractionOp->moveBefore(loopLike);
insertionOp->moveAfter(loopLike);
// Bypass the insertion op: all users of its result now use its destination
// operand instead.
rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
insertionOp.getDestinationOperand().get());
// The hoisted extraction op reads from the init value that is tied to
// `iterArg`.
extractionOp.getSourceOperand().set(
loopLike.getTiedLoopInit(iterArg)->get());
// Users of the loop result that is tied to `iterArg` now use the result of
// the hoisted insertion op.
rewriter.replaceAllUsesWith(loopResult,
insertionOp.getUpdatedDestination());
// The hoisted insertion op inserts the new loop result into the loop result
// that is tied to `iterArg`. The destination is set only after the
// replacement above, so this new use of `loopResult` is not rewritten by it.
insertionOp.getSourceOperand().set(newLoopResult);
insertionOp.getDestinationOperand().set(loopResult);
}
}
return loopLike;
}
/// Apply subset hoisting to every region iter_arg of `loopLike`, one iter_arg
/// after the other, and return the resulting loop op (which is the original
/// op if nothing was hoisted).
LoopLikeOpInterface
mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter,
                                LoopLikeOpInterface loopLike) {
  // The number of region iter_args is re-queried on every iteration: hoisting
  // adds iter_args to the loop, and the new iter_args may enable further
  // hoisting opportunities.
  int64_t argIdx = 0;
  while (argIdx < static_cast<int64_t>(loopLike.getRegionIterArgs().size())) {
    BlockArgument currentArg = loopLike.getRegionIterArgs()[argIdx];
    loopLike = hoistSubsetAtIterArg(rewriter, loopLike, currentArg);
    ++argIdx;
  }
  return loopLike;
}