//===- UseDefAnalysis.cpp - Analysis for Transitive UseDef chains ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements Analysis functions specific to slicing in Function. // //===----------------------------------------------------------------------===// #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" /// /// Implements Analysis functions specific to slicing in Function. /// using namespace mlir; static void getForwardSliceImpl(Operation *op, SetVector *forwardSlice, const SliceOptions::TransitiveFilter &filter = nullptr) { if (!op) return; // Evaluate whether we should keep this use. // This is useful in particular to implement scoping; i.e. return the // transitive forwardSlice in the current scope. if (filter && !filter(op)) return; for (Region ®ion : op->getRegions()) for (Block &block : region) for (Operation &blockOp : block) if (forwardSlice->count(&blockOp) == 0) getForwardSliceImpl(&blockOp, forwardSlice, filter); for (Value result : op->getResults()) { for (Operation *userOp : result.getUsers()) if (forwardSlice->count(userOp) == 0) getForwardSliceImpl(userOp, forwardSlice, filter); } forwardSlice->insert(op); } void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, const ForwardSliceOptions &options) { getForwardSliceImpl(op, forwardSlice, options.filter); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results. forwardSlice->remove(op); } // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an // in-place swap based thing (the real std::reverse, not the LLVM adapter). SmallVector v(forwardSlice->takeVector()); forwardSlice->insert(v.rbegin(), v.rend()); } void mlir::getForwardSlice(Value root, SetVector *forwardSlice, const SliceOptions &options) { for (Operation *user : root.getUsers()) getForwardSliceImpl(user, forwardSlice, options.filter); // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an // in-place swap based thing (the real std::reverse, not the LLVM adapter). SmallVector v(forwardSlice->takeVector()); forwardSlice->insert(v.rbegin(), v.rend()); } static void getBackwardSliceImpl(Operation *op, SetVector *backwardSlice, const BackwardSliceOptions &options) { if (!op || op->hasTrait()) return; // Evaluate whether we should keep this def. // This is useful in particular to implement scoping; i.e. return the // transitive backwardSlice in the current scope. if (options.filter && !options.filter(op)) return; for (const auto &en : llvm::enumerate(op->getOperands())) { auto operand = en.value(); if (auto *definingOp = operand.getDefiningOp()) { if (backwardSlice->count(definingOp) == 0) getBackwardSliceImpl(definingOp, backwardSlice, options); } else if (auto blockArg = dyn_cast(operand)) { if (options.omitBlockArguments) continue; Block *block = blockArg.getOwner(); Operation *parentOp = block->getParentOp(); // TODO: determine whether we want to recurse backward into the other // blocks of parentOp, which are not technically backward unless they flow // into us. For now, just bail. if (parentOp && backwardSlice->count(parentOp) == 0) { assert(parentOp->getNumRegions() == 1 && parentOp->getRegion(0).getBlocks().size() == 1); getBackwardSliceImpl(parentOp, backwardSlice, options); } } else { llvm_unreachable("No definingOp and not a block argument."); } } backwardSlice->insert(op); } void mlir::getBackwardSlice(Operation *op, SetVector *backwardSlice, const BackwardSliceOptions &options) { getBackwardSliceImpl(op, backwardSlice, options); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results. backwardSlice->remove(op); } } void mlir::getBackwardSlice(Value root, SetVector *backwardSlice, const BackwardSliceOptions &options) { if (Operation *definingOp = root.getDefiningOp()) { getBackwardSlice(definingOp, backwardSlice, options); return; } Operation *bbAargOwner = cast(root).getOwner()->getParentOp(); getBackwardSlice(bbAargOwner, backwardSlice, options); } SetVector mlir::getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions) { SetVector slice; slice.insert(op); unsigned currentIndex = 0; SetVector backwardSlice; SetVector forwardSlice; while (currentIndex != slice.size()) { auto *currentOp = (slice)[currentIndex]; // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions); slice.insert(backwardSlice.begin(), backwardSlice.end()); // Compute and insert the forwardSlice starting from currentOp. forwardSlice.clear(); getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions); slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; } return topologicalSort(slice); } namespace { /// DFS post-order implementation that maintains a global count to work across /// multiple invocations, to help implement topological sort on multi-root DAGs. /// We traverse all operations but only record the ones that appear in /// `toSort` for the final result. struct DFSState { DFSState(const SetVector &set) : toSort(set), seen() {} const SetVector &toSort; SmallVector topologicalCounts; DenseSet seen; }; } // namespace static void dfsPostorder(Operation *root, DFSState *state) { SmallVector queue(1, root); std::vector ops; while (!queue.empty()) { Operation *current = queue.pop_back_val(); ops.push_back(current); for (Operation *op : current->getUsers()) queue.push_back(op); for (Region ®ion : current->getRegions()) { for (Operation &op : region.getOps()) queue.push_back(&op); } } for (Operation *op : llvm::reverse(ops)) { if (state->seen.insert(op).second && state->toSort.count(op) > 0) state->topologicalCounts.push_back(op); } } SetVector mlir::topologicalSort(const SetVector &toSort) { if (toSort.empty()) { return toSort; } // Run from each root with global count and `seen` set. DFSState state(toSort); for (auto *s : toSort) { assert(toSort.count(s) == 1 && "NYI: multi-sets not supported"); dfsPostorder(s, &state); } // Reorder and return. SetVector res; for (auto it = state.topologicalCounts.rbegin(), eit = state.topologicalCounts.rend(); it != eit; ++it) { res.insert(*it); } return res; } /// Returns true if `value` (transitively) depends on iteration-carried values /// of the given `ancestorOp`. static bool dependsOnCarriedVals(Value value, ArrayRef iterCarriedArgs, Operation *ancestorOp) { // Compute the backward slice of the value. SetVector slice; BackwardSliceOptions sliceOptions; sliceOptions.filter = [&](Operation *op) { return !ancestorOp->isAncestor(op); }; getBackwardSlice(value, &slice, sliceOptions); // Check that none of the operands of the operations in the backward slice are // loop iteration arguments, and neither is the value itself. SmallPtrSet iterCarriedValSet(iterCarriedArgs.begin(), iterCarriedArgs.end()); if (iterCarriedValSet.contains(value)) return true; for (Operation *op : slice) for (Value operand : op->getOperands()) if (iterCarriedValSet.contains(operand)) return true; return false; } /// Utility to match a generic reduction given a list of iteration-carried /// arguments, `iterCarriedArgs` and the position of the potential reduction /// argument within the list, `redPos`. If a reduction is matched, returns the /// reduced value and the topologically-sorted list of combiner operations /// involved in the reduction. Otherwise, returns a null value. /// /// The matching algorithm relies on the following invariants, which are subject /// to change: /// 1. The first combiner operation must be a binary operation with the /// iteration-carried value and the reduced value as operands. /// 2. The iteration-carried value and combiner operations must be side /// effect-free, have single result and a single use. /// 3. Combiner operations must be immediately nested in the region op /// performing the reduction. /// 4. Reduction def-use chain must end in a terminator op that yields the /// next iteration/output values in the same order as the iteration-carried /// values in `iterCarriedArgs`. /// 5. `iterCarriedArgs` must contain all the iteration-carried/output values /// of the region op performing the reduction. /// /// This utility is generic enough to detect reductions involving multiple /// combiner operations (disabled for now) across multiple dialects, including /// Linalg, Affine and SCF. For the sake of genericity, it does not return /// specific enum values for the combiner operations since its goal is also /// matching reductions without pre-defined semantics in core MLIR. It's up to /// each client to make sense out of the list of combiner operations. It's also /// up to each client to check for additional invariants on the expected /// reductions not covered by this generic matching. Value mlir::matchReduction(ArrayRef iterCarriedArgs, unsigned redPos, SmallVectorImpl &combinerOps) { assert(redPos < iterCarriedArgs.size() && "'redPos' is out of bounds"); BlockArgument redCarriedVal = iterCarriedArgs[redPos]; if (!redCarriedVal.hasOneUse()) return nullptr; // For now, the first combiner op must be a binary op. Operation *combinerOp = *redCarriedVal.getUsers().begin(); if (combinerOp->getNumOperands() != 2) return nullptr; Value reducedVal = combinerOp->getOperand(0) == redCarriedVal ? combinerOp->getOperand(1) : combinerOp->getOperand(0); Operation *redRegionOp = iterCarriedArgs.front().getOwner()->getParent()->getParentOp(); if (dependsOnCarriedVals(reducedVal, iterCarriedArgs, redRegionOp)) return nullptr; // Traverse the def-use chain starting from the first combiner op until a // terminator is found. Gather all the combiner ops along the way in // topological order. while (!combinerOp->mightHaveTrait()) { if (!isMemoryEffectFree(combinerOp) || combinerOp->getNumResults() != 1 || !combinerOp->hasOneUse() || combinerOp->getParentOp() != redRegionOp) return nullptr; combinerOps.push_back(combinerOp); combinerOp = *combinerOp->getUsers().begin(); } // Limit matching to single combiner op until we can properly test reductions // involving multiple combiners. if (combinerOps.size() != 1) return nullptr; // Check that the yielded value is in the same position as in // `iterCarriedArgs`. Operation *terminatorOp = combinerOp; if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) return nullptr; return reducedVal; }