//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" using namespace mlir; using namespace mlir::linalg; /// Return `true` if the `result` of an operation `genericOp` is dead. static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) { if (!result.use_empty()) return false; // If out operand not used in payload, we can drop it. OpOperand *outputOpOperand = genericOp.getDpsInitOperand(result.getResultNumber()); if (!genericOp.payloadUsesValueFromOperand(outputOpOperand)) return true; // The out operand that is part of a payload can be dropped if // these conditions are met: // - Result from out operand is dead. // - User of arg is yield. // - outArg data is not being used by other outArgs. // Check block arg and cycle from out operand has a single use. BlockArgument outputArg = genericOp.getRegionOutputArgs()[result.getResultNumber()]; if (!outputArg.hasOneUse()) return false; Operation *argUserOp = *outputArg.user_begin(); // Check argUser has no other use. if (!argUserOp->use_empty()) return false; // Check that argUser is a yield. auto yieldOp = dyn_cast(argUserOp); if (!yieldOp) return false; // Check outArg data is not being used by other outArgs. if (yieldOp.getOperand(result.getResultNumber()) != outputArg) return false; return true; } namespace { struct DeduplicateAndRemoveDeadOperandsAndResults : public OpRewritePattern { DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx, bool removeOutputs) : OpRewritePattern(ctx), removeOutputs(removeOutputs) {} LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Create a map from argument position in the original op to the argument // position in the new op. If the argument is dropped it wont have an entry. SmallVector droppedOpOperands; // Information needed to build the new op. SmallVector newInputOperands, newOutputOperands; SmallVector newIndexingMaps; // Gather information about duplicate input operands. llvm::SmallDenseMap origInsToNewInsPos = deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands, newIndexingMaps); // Gather information about the dropped outputs. llvm::SmallDenseMap origOutsToNewOutsPos = deduplicateOutputOperands(genericOp, droppedOpOperands, newOutputOperands, newIndexingMaps); // Check if there is any change to operands. if (newInputOperands.size() + newOutputOperands.size() == genericOp->getNumOperands()) return failure(); // Create the new op with the body being empty. Location loc = genericOp.getLoc(); SmallVector newResultTypes; for (Value v : newOutputOperands) if (isa(v.getType())) newResultTypes.push_back(v.getType()); auto newOp = rewriter.create( loc, newResultTypes, newInputOperands, newOutputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.getIteratorTypes(), genericOp.getDocAttr(), genericOp.getLibraryCallAttr(), [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) { return; }); // Copy over unknown attributes. They might be load bearing for some flow. ArrayRef odsAttrs = genericOp.getAttributeNames(); for (NamedAttribute kv : genericOp->getAttrs()) if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) newOp->setAttr(kv.getName(), kv.getValue()); // Fix up the payload of the canonicalized operation. populateOpPayload(genericOp, newOp, origInsToNewInsPos, origOutsToNewOutsPos, rewriter); // Replace all live uses of the op. SmallVector replacementsVals(genericOp->getNumResults(), nullptr); for (const auto &result : llvm::enumerate(genericOp.getResults())) { auto it = origOutsToNewOutsPos.find(result.index()); if (it == origOutsToNewOutsPos.end()) continue; replacementsVals[result.index()] = newOp.getResult(it->second); } rewriter.replaceOp(genericOp, replacementsVals); return success(); } private: /// If unset, outputs are not modified by this pattern. bool removeOutputs; // Deduplicate input operands, and return the // - Mapping from operand position in the original op, to operand position in // the canonicalized op. // - The preserved input operands list (by reference). llvm::SmallDenseMap deduplicateInputOperands(GenericOp genericOp, SmallVector &droppedOpOperands, SmallVector &newInputOperands, SmallVector &newIndexingMaps) const { llvm::SmallDenseMap origToNewPos; llvm::SmallDenseMap, unsigned> dedupedInputs; for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) { OpOperand *inputOpOperand = en.value(); // Check if operand is dead and if dropping the indexing map makes the // loops to shape computation invalid. if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) { // Add the current operands to the list of potentially droppable // operands. If it cannot be dropped, this needs to be popped back. droppedOpOperands.push_back(inputOpOperand); if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) continue; droppedOpOperands.pop_back(); } // Check if this operand is a duplicate. AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand); auto it = dedupedInputs.find( std::make_pair(inputOpOperand->get(), indexingMap)); if (it != dedupedInputs.end()) { origToNewPos[en.index()] = it->second; droppedOpOperands.push_back(inputOpOperand); continue; } // This is a preserved argument. origToNewPos[en.index()] = newInputOperands.size(); dedupedInputs[{inputOpOperand->get(), indexingMap}] = newInputOperands.size(); newInputOperands.push_back(inputOpOperand->get()); newIndexingMaps.push_back(indexingMap); } return origToNewPos; } // Deduplicate output operands, and return the // - Mapping from operand position in the original op, to operand position in // the canonicalized op. // - The preserved output operands list (by reference). llvm::SmallDenseMap deduplicateOutputOperands(GenericOp genericOp, SmallVector &droppedOpOperands, SmallVector &newOutputOperands, SmallVector &newIndexingMaps) const { llvm::SmallDenseMap origToNewPos; llvm::SmallDenseMap, unsigned> dedupedOutpts; // If the op doesn't have tensor semantics or outputs should not be removed, // keep all the outputs as preserved. if (!genericOp.hasPureTensorSemantics() || !removeOutputs) { for (const auto &en : llvm::enumerate(genericOp.getDpsInitsMutable())) { origToNewPos[en.index()] = newOutputOperands.size(); newOutputOperands.push_back(en.value().get()); newIndexingMaps.push_back( genericOp.getMatchingIndexingMap(&en.value())); } return origToNewPos; } // Output argument can be dropped if the result has // - no users, and // - it is not used in the payload, and // - the corresponding indexing maps are not needed for loop bound // computation. auto yieldOp = cast(genericOp.getBody()->getTerminator()); for (const auto &outputOpOperand : llvm::enumerate(genericOp.getDpsInitsMutable())) { OpResult result = genericOp.getTiedOpResult(&outputOpOperand.value()); AffineMap indexingMap = genericOp.getMatchingIndexingMap(&outputOpOperand.value()); auto key = std::make_tuple(outputOpOperand.value().get(), indexingMap, yieldOp->getOperand(outputOpOperand.index())); if (isResultValueDead(genericOp, result)) { // Check if the opoperand can be dropped without affecting loop // bound computation. Add the operand to the list of dropped op // operand for checking. If it cannot be dropped, need to pop the // value back. droppedOpOperands.push_back(&outputOpOperand.value()); if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) { continue; } droppedOpOperands.pop_back(); } if (!genericOp.payloadUsesValueFromOperand(&outputOpOperand.value())) { // The out operand can also be dropped if it is computed redundantly // by another result, the conditions for that are // - The same operand is used as the out operand // - The same indexing map is used // - The same yield value is used. auto it = dedupedOutpts.find(key); if (it != dedupedOutpts.end()) { origToNewPos[outputOpOperand.index()] = it->second; droppedOpOperands.push_back(&outputOpOperand.value()); continue; } } origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); dedupedOutpts[key] = newOutputOperands.size(); newOutputOperands.push_back(outputOpOperand.value().get()); newIndexingMaps.push_back( genericOp.getMatchingIndexingMap(&outputOpOperand.value())); } return origToNewPos; } // Populate the body of the canonicalized operation. void populateOpPayload( GenericOp genericOp, GenericOp newOp, const llvm::SmallDenseMap &origInsToNewInsPos, const llvm::SmallDenseMap &origOutsToNewOutsPos, PatternRewriter &rewriter) const { // Merge the body of the original op with the new op. Block *newOpBlock = &newOp.getRegion().front(); assert(newOpBlock->empty() && "expected new op to have an empty payload"); Block *origOpBlock = &genericOp.getRegion().front(); SmallVector replacements(origOpBlock->getNumArguments(), nullptr); // Replace all arguments in the original op, with arguments from the // canonicalized op. auto updateReplacements = [&](SmallVector &origOperands, SmallVector &newOperands, const llvm::SmallDenseMap &map) { for (const auto &origOperand : llvm::enumerate(origOperands)) { auto it = map.find(origOperand.index()); if (it == map.end()) continue; OpOperand *newOperand = newOperands[it->second]; replacements[origOperand.value()->getOperandNumber()] = newOpBlock->getArgument(newOperand->getOperandNumber()); } }; SmallVector origInputOperands = genericOp.getDpsInputOperands(); SmallVector newInputOperands = newOp.getDpsInputOperands(); updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); SmallVector origOutputOperands = llvm::to_vector(llvm::map_range(genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); SmallVector newOutputOperands = llvm::to_vector(llvm::map_range(newOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); updateReplacements(origOutputOperands, newOutputOperands, origOutsToNewOutsPos); // Drop the unused yield args. if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) { OpBuilder::InsertionGuard g(rewriter); YieldOp origYieldOp = cast(origOpBlock->getTerminator()); rewriter.setInsertionPoint(origYieldOp); SmallVector newYieldVals(newOp.getNumDpsInits(), nullptr); for (const auto &yieldOpOperands : llvm::enumerate(origYieldOp.getValues())) { auto it = origOutsToNewOutsPos.find(yieldOpOperands.index()); if (it == origOutsToNewOutsPos.end()) continue; newYieldVals[it->second] = yieldOpOperands.value(); } rewriter.replaceOpWithNewOp(origYieldOp, newYieldVals); } rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements); } }; /// Remove unused cycles. /// We can remove unused cycle within a payload of generic region /// if these conditions are met: /// - Result from out operand is dead. /// - Block arg from out operand has a single use in the %cycle /// instruction. /// - Cycle has a single use and it is in yield. struct RemoveUnusedCycleInGenericOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // If the op doesnt have tensor semantics, preserve the outputs as is. if (!genericOp.hasPureTensorSemantics()) return failure(); bool hasRemovedCycles = false; // Iterate over output operands and remove any unused cycles. for (const auto &outputOpOperand : llvm::enumerate(genericOp.getDpsInits())) { // Check that result from out operand is dead. Value result = genericOp.getResult(outputOpOperand.index()); if (!result.use_empty()) continue; // Check that outputArg has one use in cycle. BlockArgument outputArg = genericOp.getRegionOutputArgs()[outputOpOperand.index()]; if (!outputArg.hasOneUse()) continue; // Check cycle has at most one use. Operation *cycleOp = *outputArg.user_begin(); if (!cycleOp->hasOneUse()) continue; // Check that the cycleUser is a yield. Operation *cycleUserOp = *cycleOp->user_begin(); if (!isa(cycleUserOp)) continue; // Check that argIndex matches yieldIndex, else data is being used. if (cycleUserOp->getOperand(outputOpOperand.index()) != cycleOp->getResult(0)) continue; // Directly replace the cycle with the blockArg such that // Deduplicate pattern can eliminate it along with unused yield. rewriter.replaceOp(cycleOp, outputArg); rewriter.modifyOpInPlace(genericOp, [] {}); hasRemovedCycles = true; } if (hasRemovedCycles) { return success(); } return failure(); } }; /// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.: /// ``` /// linalg.generic ins(%a, %b, %a, %b) outs(%a) /// ^bb0(%in0, %in1, %in2, %in3, %out1) /// ``` /// Assuming that all %a and %b have the same index map: /// * All uses of %in0 and %in2 are replaced with %out1 /// * All uses of %in1 are replaced with %in3 /// This pattern can enable additional canonicalizations: In the above example, /// %in0, %in1 and %in3 have no uses anymore and their corresponding operands /// can be folded away. This pattern does not modify uses of output block args. struct FoldDuplicateInputBbArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Find replacement bbArgs for all input bbArg. DenseMap replacements; for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) { // Skip bbArgs that have no uses. if (genericOp.getBody()->getArgument(i).getUses().empty()) continue; // Find replacement bbArg. This can be an input or an output bbArg. for (int j = genericOp->getNumOperands() - 1; j > i; --j) { if (genericOp->getOperand(i) == genericOp->getOperand(j) && genericOp.getIndexingMapsArray()[i] == genericOp.getIndexingMapsArray()[j]) { replacements[i] = j; break; } } } // Stop here if no replacements were found. if (replacements.empty()) return failure(); // Rewrite the op. rewriter.modifyOpInPlace(genericOp, [&]() { for (auto [before, after] : replacements) { BlockArgument bbArg = genericOp.getBody()->getArgument(before); BlockArgument replacement = genericOp.getBody()->getArgument(after); rewriter.replaceAllUsesWith(bbArg, replacement); } }); return success(); } }; } // namespace void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns( RewritePatternSet &patterns) { patterns.insert( patterns.getContext(), /*removeOutputs=*/true); patterns.insert(patterns.getContext()); } void mlir::linalg::populateEraseUnnecessaryInputsPatterns( RewritePatternSet &patterns) { patterns.insert( patterns.getContext(), /*removeOutputs=*/false); patterns.insert(patterns.getContext()); }