433 lines
18 KiB
C++
433 lines
18 KiB
C++
//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
/// Return `true` if the `result` of an operation `genericOp` is dead.
|
|
static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
|
|
if (!result.use_empty())
|
|
return false;
|
|
// If out operand not used in payload, we can drop it.
|
|
OpOperand *outputOpOperand =
|
|
genericOp.getDpsInitOperand(result.getResultNumber());
|
|
if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
|
|
return true;
|
|
|
|
// The out operand that is part of a payload can be dropped if
|
|
// these conditions are met:
|
|
// - Result from out operand is dead.
|
|
// - User of arg is yield.
|
|
// - outArg data is not being used by other outArgs.
|
|
|
|
// Check block arg and cycle from out operand has a single use.
|
|
BlockArgument outputArg =
|
|
genericOp.getRegionOutputArgs()[result.getResultNumber()];
|
|
if (!outputArg.hasOneUse())
|
|
return false;
|
|
Operation *argUserOp = *outputArg.user_begin();
|
|
|
|
// Check argUser has no other use.
|
|
if (!argUserOp->use_empty())
|
|
return false;
|
|
|
|
// Check that argUser is a yield.
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
|
|
if (!yieldOp)
|
|
return false;
|
|
|
|
// Check outArg data is not being used by other outArgs.
|
|
if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
namespace {
|
|
|
|
struct DeduplicateAndRemoveDeadOperandsAndResults
|
|
: public OpRewritePattern<GenericOp> {
|
|
DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
|
|
bool removeOutputs)
|
|
: OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Create a map from argument position in the original op to the argument
|
|
// position in the new op. If the argument is dropped it wont have an entry.
|
|
SmallVector<OpOperand *> droppedOpOperands;
|
|
|
|
// Information needed to build the new op.
|
|
SmallVector<Value> newInputOperands, newOutputOperands;
|
|
SmallVector<AffineMap> newIndexingMaps;
|
|
|
|
// Gather information about duplicate input operands.
|
|
llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
|
|
deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
|
|
newIndexingMaps);
|
|
|
|
// Gather information about the dropped outputs.
|
|
llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
|
|
deduplicateOutputOperands(genericOp, droppedOpOperands,
|
|
newOutputOperands, newIndexingMaps);
|
|
|
|
// Check if there is any change to operands.
|
|
if (newInputOperands.size() + newOutputOperands.size() ==
|
|
genericOp->getNumOperands())
|
|
return failure();
|
|
|
|
// Create the new op with the body being empty.
|
|
Location loc = genericOp.getLoc();
|
|
SmallVector<Type> newResultTypes;
|
|
for (Value v : newOutputOperands)
|
|
if (isa<TensorType>(v.getType()))
|
|
newResultTypes.push_back(v.getType());
|
|
auto newOp = rewriter.create<GenericOp>(
|
|
loc, newResultTypes, newInputOperands, newOutputOperands,
|
|
rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
|
genericOp.getIteratorTypes(), genericOp.getDocAttr(),
|
|
genericOp.getLibraryCallAttr(),
|
|
[](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
|
|
return;
|
|
});
|
|
// Copy over unknown attributes. They might be load bearing for some flow.
|
|
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
|
|
for (NamedAttribute kv : genericOp->getAttrs())
|
|
if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
|
|
newOp->setAttr(kv.getName(), kv.getValue());
|
|
|
|
// Fix up the payload of the canonicalized operation.
|
|
populateOpPayload(genericOp, newOp, origInsToNewInsPos,
|
|
origOutsToNewOutsPos, rewriter);
|
|
|
|
// Replace all live uses of the op.
|
|
SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
|
|
for (const auto &result : llvm::enumerate(genericOp.getResults())) {
|
|
auto it = origOutsToNewOutsPos.find(result.index());
|
|
if (it == origOutsToNewOutsPos.end())
|
|
continue;
|
|
replacementsVals[result.index()] = newOp.getResult(it->second);
|
|
}
|
|
rewriter.replaceOp(genericOp, replacementsVals);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
/// If unset, outputs are not modified by this pattern.
|
|
bool removeOutputs;
|
|
|
|
// Deduplicate input operands, and return the
|
|
// - Mapping from operand position in the original op, to operand position in
|
|
// the canonicalized op.
|
|
// - The preserved input operands list (by reference).
|
|
llvm::SmallDenseMap<unsigned, unsigned>
|
|
deduplicateInputOperands(GenericOp genericOp,
|
|
SmallVector<OpOperand *> &droppedOpOperands,
|
|
SmallVector<Value> &newInputOperands,
|
|
SmallVector<AffineMap> &newIndexingMaps) const {
|
|
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
|
|
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
|
|
for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
|
|
OpOperand *inputOpOperand = en.value();
|
|
// Check if operand is dead and if dropping the indexing map makes the
|
|
// loops to shape computation invalid.
|
|
if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
|
|
// Add the current operands to the list of potentially droppable
|
|
// operands. If it cannot be dropped, this needs to be popped back.
|
|
droppedOpOperands.push_back(inputOpOperand);
|
|
if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
|
|
continue;
|
|
droppedOpOperands.pop_back();
|
|
}
|
|
|
|
// Check if this operand is a duplicate.
|
|
AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
|
|
auto it = dedupedInputs.find(
|
|
std::make_pair(inputOpOperand->get(), indexingMap));
|
|
if (it != dedupedInputs.end()) {
|
|
origToNewPos[en.index()] = it->second;
|
|
droppedOpOperands.push_back(inputOpOperand);
|
|
continue;
|
|
}
|
|
|
|
// This is a preserved argument.
|
|
origToNewPos[en.index()] = newInputOperands.size();
|
|
dedupedInputs[{inputOpOperand->get(), indexingMap}] =
|
|
newInputOperands.size();
|
|
newInputOperands.push_back(inputOpOperand->get());
|
|
newIndexingMaps.push_back(indexingMap);
|
|
}
|
|
return origToNewPos;
|
|
}
|
|
|
|
// Deduplicate output operands, and return the
|
|
// - Mapping from operand position in the original op, to operand position in
|
|
// the canonicalized op.
|
|
// - The preserved output operands list (by reference).
|
|
llvm::SmallDenseMap<unsigned, unsigned>
|
|
deduplicateOutputOperands(GenericOp genericOp,
|
|
SmallVector<OpOperand *> &droppedOpOperands,
|
|
SmallVector<Value> &newOutputOperands,
|
|
SmallVector<AffineMap> &newIndexingMaps) const {
|
|
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
|
|
llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
|
|
dedupedOutpts;
|
|
// If the op doesn't have tensor semantics or outputs should not be removed,
|
|
// keep all the outputs as preserved.
|
|
if (!genericOp.hasPureTensorSemantics() || !removeOutputs) {
|
|
for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) {
|
|
origToNewPos[en.index()] = newOutputOperands.size();
|
|
newOutputOperands.push_back(en.value().get());
|
|
newIndexingMaps.push_back(
|
|
genericOp.getMatchingIndexingMap(&en.value()));
|
|
}
|
|
return origToNewPos;
|
|
}
|
|
// Output argument can be dropped if the result has
|
|
// - no users, and
|
|
// - it is not used in the payload, and
|
|
// - the corresponding indexing maps are not needed for loop bound
|
|
// computation.
|
|
auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
|
|
for (const auto &outputOpOperand :
|
|
llvm::enumerate(genericOp.getDpsInitsMutable())) {
|
|
OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value());
|
|
AffineMap indexingMap =
|
|
genericOp.getMatchingIndexingMap(&outputOpOperand.value());
|
|
auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap,
|
|
yieldOp->getOperand(outputOpOperand.index()));
|
|
if (isResultValueDead(genericOp, result)) {
|
|
// Check if the opoperand can be dropped without affecting loop
|
|
// bound computation. Add the operand to the list of dropped op
|
|
// operand for checking. If it cannot be dropped, need to pop the
|
|
// value back.
|
|
droppedOpOperands.push_back(&outputOpOperand.value());
|
|
if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
|
|
continue;
|
|
}
|
|
droppedOpOperands.pop_back();
|
|
}
|
|
|
|
if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) {
|
|
// The out operand can also be dropped if it is computed redundantly
|
|
// by another result, the conditions for that are
|
|
// - The same operand is used as the out operand
|
|
// - The same indexing map is used
|
|
// - The same yield value is used.
|
|
auto it = dedupedOutpts.find(key);
|
|
if (it != dedupedOutpts.end()) {
|
|
origToNewPos[outputOpOperand.index()] = it->second;
|
|
droppedOpOperands.push_back(&outputOpOperand.value());
|
|
continue;
|
|
}
|
|
}
|
|
|
|
origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
|
|
dedupedOutpts[key] = newOutputOperands.size();
|
|
newOutputOperands.push_back(outputOpOperand.value().get());
|
|
newIndexingMaps.push_back(
|
|
genericOp.getMatchingIndexingMap(&outputOpOperand.value()));
|
|
}
|
|
return origToNewPos;
|
|
}
|
|
|
|
// Populate the body of the canonicalized operation.
|
|
void populateOpPayload(
|
|
GenericOp genericOp, GenericOp newOp,
|
|
const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
|
|
const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
|
|
PatternRewriter &rewriter) const {
|
|
// Merge the body of the original op with the new op.
|
|
Block *newOpBlock = &newOp.getRegion().front();
|
|
assert(newOpBlock->empty() && "expected new op to have an empty payload");
|
|
Block *origOpBlock = &genericOp.getRegion().front();
|
|
SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
|
|
|
|
// Replace all arguments in the original op, with arguments from the
|
|
// canonicalized op.
|
|
auto updateReplacements =
|
|
[&](SmallVector<OpOperand *> &origOperands,
|
|
SmallVector<OpOperand *> &newOperands,
|
|
const llvm::SmallDenseMap<unsigned, unsigned> &map) {
|
|
for (const auto &origOperand : llvm::enumerate(origOperands)) {
|
|
auto it = map.find(origOperand.index());
|
|
if (it == map.end())
|
|
continue;
|
|
OpOperand *newOperand = newOperands[it->second];
|
|
replacements[origOperand.value()->getOperandNumber()] =
|
|
newOpBlock->getArgument(newOperand->getOperandNumber());
|
|
}
|
|
};
|
|
|
|
SmallVector<OpOperand *> origInputOperands =
|
|
genericOp.getDpsInputOperands();
|
|
SmallVector<OpOperand *> newInputOperands = newOp.getDpsInputOperands();
|
|
updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
|
|
|
|
SmallVector<OpOperand *> origOutputOperands =
|
|
llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(),
|
|
[](OpOperand &o) { return &o; }));
|
|
SmallVector<OpOperand *> newOutputOperands =
|
|
llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(),
|
|
[](OpOperand &o) { return &o; }));
|
|
updateReplacements(origOutputOperands, newOutputOperands,
|
|
origOutsToNewOutsPos);
|
|
|
|
// Drop the unused yield args.
|
|
if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
|
|
rewriter.setInsertionPoint(origYieldOp);
|
|
|
|
SmallVector<Value> newYieldVals(newOp.getNumDpsInits(), nullptr);
|
|
for (const auto &yieldOpOperands :
|
|
llvm::enumerate(origYieldOp.getValues())) {
|
|
auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
|
|
if (it == origOutsToNewOutsPos.end())
|
|
continue;
|
|
newYieldVals[it->second] = yieldOpOperands.value();
|
|
}
|
|
rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
|
|
}
|
|
|
|
rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
|
|
}
|
|
};
|
|
|
|
/// Remove unused cycles.
|
|
/// We can remove unused cycle within a payload of generic region
|
|
/// if these conditions are met:
|
|
/// - Result from out operand is dead.
|
|
/// - Block arg from out operand has a single use in the %cycle
|
|
/// instruction.
|
|
/// - Cycle has a single use and it is in yield.
|
|
struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
// If the op doesnt have tensor semantics, preserve the outputs as is.
|
|
if (!genericOp.hasPureTensorSemantics())
|
|
return failure();
|
|
|
|
bool hasRemovedCycles = false;
|
|
// Iterate over output operands and remove any unused cycles.
|
|
for (const auto &outputOpOperand :
|
|
llvm::enumerate(genericOp.getDpsInits())) {
|
|
|
|
// Check that result from out operand is dead.
|
|
Value result = genericOp.getResult(outputOpOperand.index());
|
|
if (!result.use_empty())
|
|
continue;
|
|
|
|
// Check that outputArg has one use in cycle.
|
|
BlockArgument outputArg =
|
|
genericOp.getRegionOutputArgs()[outputOpOperand.index()];
|
|
if (!outputArg.hasOneUse())
|
|
continue;
|
|
|
|
// Check cycle has at most one use.
|
|
Operation *cycleOp = *outputArg.user_begin();
|
|
if (!cycleOp->hasOneUse())
|
|
continue;
|
|
|
|
// Check that the cycleUser is a yield.
|
|
Operation *cycleUserOp = *cycleOp->user_begin();
|
|
if (!isa<linalg::YieldOp>(cycleUserOp))
|
|
continue;
|
|
|
|
// Check that argIndex matches yieldIndex, else data is being used.
|
|
if (cycleUserOp->getOperand(outputOpOperand.index()) !=
|
|
cycleOp->getResult(0))
|
|
continue;
|
|
|
|
// Directly replace the cycle with the blockArg such that
|
|
// Deduplicate pattern can eliminate it along with unused yield.
|
|
rewriter.replaceOp(cycleOp, outputArg);
|
|
rewriter.modifyOpInPlace(genericOp, [] {});
|
|
hasRemovedCycles = true;
|
|
}
|
|
|
|
if (hasRemovedCycles) {
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
|
|
/// ```
|
|
/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
|
|
/// ^bb0(%in0, %in1, %in2, %in3, %out1)
|
|
/// ```
|
|
/// Assuming that all %a and %b have the same index map:
|
|
/// * All uses of %in0 and %in2 are replaced with %out1
|
|
/// * All uses of %in1 are replaced with %in3
|
|
/// This pattern can enable additional canonicalizations: In the above example,
|
|
/// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
|
|
/// can be folded away. This pattern does not modify uses of output block args.
|
|
struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Find replacement bbArgs for all input bbArg.
|
|
DenseMap<int, int> replacements;
|
|
for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
|
|
// Skip bbArgs that have no uses.
|
|
if (genericOp.getBody()->getArgument(i).getUses().empty())
|
|
continue;
|
|
// Find replacement bbArg. This can be an input or an output bbArg.
|
|
for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
|
|
if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
|
|
genericOp.getIndexingMapsArray()[i] ==
|
|
genericOp.getIndexingMapsArray()[j]) {
|
|
replacements[i] = j;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stop here if no replacements were found.
|
|
if (replacements.empty())
|
|
return failure();
|
|
|
|
// Rewrite the op.
|
|
rewriter.modifyOpInPlace(genericOp, [&]() {
|
|
for (auto [before, after] : replacements) {
|
|
BlockArgument bbArg = genericOp.getBody()->getArgument(before);
|
|
BlockArgument replacement = genericOp.getBody()->getArgument(after);
|
|
rewriter.replaceAllUsesWith(bbArg, replacement);
|
|
}
|
|
});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
|
|
RewritePatternSet &patterns) {
|
|
patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
|
|
patterns.getContext(), /*removeOutputs=*/true);
|
|
patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
|
|
}
|
|
|
|
void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
|
|
RewritePatternSet &patterns) {
|
|
patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
|
|
patterns.getContext(), /*removeOutputs=*/false);
|
|
patterns.insert<FoldDuplicateInputBbArgs>(patterns.getContext());
|
|
}
|