//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"

#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region holding op other than ForOp and IfOp
// was encountered in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (isa<AffineForOp>(op))
      forOps.push_back(cast<AffineForOp>(op));
    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
      hasNonAffineRegionOp = true;
    else if (isa<AffineReadOpInterface>(op))
      loadOpInsts.push_back(op);
    else if (isa<AffineWriteOpInterface>(op))
      storeOpInsts.push_back(op);
  });
}

// Returns the load op count for 'memref'.
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      ++loadOpCount;
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      ++storeOpCount;
  }
  return storeOpCount;
}

// Returns all store ops in 'storeOps' which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Returns all load ops in 'loadOps' which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
// has at least one load and store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in 'f'.
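// For example (an illustrative sketch), given a block with an alloc followed
// by a storing loop nest and a loading loop nest on the same memref:
//
//   %a = memref.alloc() : memref<8xf32>
//   affine.for %i = 0 to 8 { affine.store %cst, %a[%i] : memref<8xf32> }
//   affine.for %j = 0 to 8 { %v = affine.load %a[%j] : memref<8xf32> }
//
// init() creates one node per top-level op above, adds SSA def-use edges from
// the alloc node to both loop nest nodes (users of %a), and adds a memref
// dependence edge on %a from the storing nest to the loading nest.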
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, we look
      // for non-affine ops on the path from source to destination, at which
      // point we check which memrefs, if any, are used in the region.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., it may
      // not be an imperative op.
      LLVM_DEBUG(llvm::dbgs()
                 << "MDG init failed; unknown region-holding op found!\n");
      return false;
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}

// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Remove node 'id' (and its associated edges) from graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
  Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an affine
    // way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
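// Edges are recorded in both directions ('outEdges' of the source node and
// 'inEdges' of the destination node) so that either endpoint can enumerate
// its dependences; 'memrefEdgeCount' tracks how many edges carry each memref.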
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>.
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (outEdges.count(idAndIndex.first) == 0 ||
        idAndIndex.second == outEdges[idAndIndex.first].size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) {
  unsigned inEdgeCount = 0;
  if (inEdges.count(id) > 0)
    for (auto &inEdge : inEdges[id])
      if (inEdge.value == memref) {
        Node *srcNode = getNode(inEdge.id);
        // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
        // with a store operation.
        if (srcNode->getStoreOpCount(memref) > 0)
          ++inEdgeCount;
      }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
  unsigned outEdgeCount = 0;
  if (outEdges.count(id) > 0)
    for (auto &outEdge : outEdges[id])
      if (!memref || outEdge.value == memref)
        ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
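/// These are the nodes connected to 'id' by an in-edge carrying a non-memref
/// value, i.e., by an SSA def-use dependence.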
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) {
  for (MemRefDependenceGraph::Edge edge : inEdges[id])
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) {
  if (outEdges.count(srcId) == 0)
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: a defining op with a user in the dst "
                  "loop has dependence from the src loop\n");
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges[srcId])
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges[dstId])
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  //   *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //      operation insertion point (or return null pointer if no such
  //      insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }

  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
//    private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (privateMemRefs.count(inEdge.value) == 0)
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]' to node 'id':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads' and 'stores' to node at 'id'.
void MemRefDependenceGraph::addToNode(
    unsigned id, const SmallVectorImpl<Operation *> &loads,
    const SmallVectorImpl<Operation *> &stores) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges[id], callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges[id], callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Skip if 'edge.id' is not a loop nest.
    if (!isa<AffineForOp>(getNode(edge.id)->op))
      continue;
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << "  InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << "  OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  AffineForOp currAffineForOp;
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the original
// domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent the slice
// bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the src
/// and the dst loops for those dimensions have the same bounds. Returns false
/// if the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension
    // for now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}

/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
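/// Returns false if the slice is verified to be invalid, and std::nullopt if
/// its validity could not be determined.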
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following
  // conditions are verified to be true, slice is declared valid by the fast
  // check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest using which slice is
  // computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the
  // 'projectOut' utility below. Currently, aligning is not done assuming that
  // there will be no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return std::nullopt;
  }

  // Projecting out every dimension other than the 'ivs' to express slice's
  // domain completely in terms of source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the
  // result is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);
  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatAffineValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "Dim size bound can't be negative");
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == ShapedType::kDynamic)
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0);
      lbDivisor = 1;
    }
    numElements *= diffConstant;
    if (lbs) {
      lbs->push_back(lb);
      assert(lbDivisors && "both lbs and lbDivisor or none");
      lbDivisors->push_back(lbDivisor);
    }
    if (shape) {
      shape->push_back(diffConstant);
    }
  }
  return numElements;
}

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
//  For example, the memref region for this load operation at loopDepth = 1
//  will be as below:
//
//    affine.for %i = 0 to 32 {
//      affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//        load %A[%ii]
//      }
//    }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << "\ndepth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst.
  // This will change below once cst is fully constructed out.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value");
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      cst.addInductionVarOrTerminalSymbol(operand);
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (Value var : vars) {
    if ((isAffineInductionVar(var)) && !llvm::is_contained(enclosingIVs, var)) {
      cst.projectOut(var);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}

// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return std::nullopt;
  }

  // Indices to use for the DmaStart op.
  // Indices for the original memref being DMAed from/to.
  SmallVector<Value, 4> memIndices;
  // Indices for the faster buffer being DMAed into/from.
  SmallVector<Value, 4> bufIndices;

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the element of the memref has vector type,
/// takes into account the size of the vector as well.
// TODO: improve/complete this when we have target data.
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}

template <typename LoadOrStoreOp>
LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                                    bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either a AffineReadOpInterface or a "
                "AffineWriteOpInterface");

  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LLVM_DEBUG(llvm::dbgs() << "Memory region");
  LLVM_DEBUG(region.getConstraints()->dump());

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineValueConstraints ucst(*region.getConstraints());

    // Intersect memory region with constraint capturing out of bounds (both
    // out of upper and out of lower), and check if the constraint system is
    // feasible. If it is, there is at least one point out of bounds.
    SmallVector<int64_t, 4> ineq(rank + 1, 0);
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (dimSize == -1)
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addBound(BoundType::LB, r, dimSize);
    outOfBounds = !ucst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index.
    FlatAffineValueConstraints lcst(*region.getConstraints());
    std::fill(ineq.begin(), ineq.end(), 0);
    // d_i <= -1;
    lcst.addBound(BoundType::UB, r, -1);
    outOfBounds = !lcst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}

// Explicitly instantiate the template so that the compiler knows we need them!
template LogicalResult
mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
                                      bool emitError);
template LogicalResult
mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
                                      bool emitError);

// Returns in 'positions' the Block positions of 'op' in each ancestor
// Block from the Block containing operation, stopping at 'limitBlock'.
static void findInstPosition(Operation *op, Block *limitBlock,
                             SmallVectorImpl<unsigned> *positions) {
  Block *block = op->getBlock();
  while (block != limitBlock) {
    // FIXME: This algorithm is unnecessarily O(n) and should be improved to
    // not rely on linear scans.
    int instPosInBlock = std::distance(block->begin(), op->getIterator());
    positions->push_back(instPosInBlock);
    op = block->getParentOp();
    block = op->getBlock();
  }
  std::reverse(positions->begin(), positions->end());
}

// Returns the Operation in a possibly nested set of Blocks, where the
// position of the operation is represented by 'positions', which has a
// Block position for each level of nesting.
static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
                                    unsigned level, Block *block) {
  unsigned i = 0;
  for (auto &op : *block) {
    if (i != positions[level]) {
      ++i;
      continue;
    }
    if (level == positions.size() - 1)
      return &op;
    if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
      return getInstAtPosition(positions, level + 1,
                               childAffineForOp.getBody());

    for (auto &region : op.getRegions()) {
      for (auto &b : region)
        if (auto *ret = getInstAtPosition(positions, level + 1, &b))
          return ret;
    }
    return nullptr;
  }
  return nullptr;
}

// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
                                            FlatAffineValueConstraints *cst) {
  for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
    auto value = cst->getValue(i);
    if (ivs.count(value) == 0) {
      assert(isAffineForInductionVar(value));
      auto loop = getForInductionVarOwner(value);
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }
  return success();
}

/// Returns the innermost common loop depth for the set of operations in 'ops'.
// TODO: Move this to LoopUtils.
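// If 'surroundingLoops' is non-null, it is populated with the loops common to
// all of 'ops', outermost loop first.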
unsigned mlir::affine::getInnermostCommonLoopDepth(
    ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
  unsigned numOps = ops.size();
  assert(numOps > 0 && "Expected at least one operation");

  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getAffineForIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        return loopDepth;
    }
    if (surroundingLoops)
      surroundingLoops->push_back(loops[i - 1][d]);
    ++loopDepth;
  }
  return loopDepth;
}

/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
/// union was computed correctly, an appropriate failure otherwise.
SliceComputationResult
mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
                                ArrayRef<Operation *> opsB, unsigned loopDepth,
                                unsigned numCommonLoops, bool isBackwardSlice,
                                ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineValueConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (auto *i : opsA) {
    MemRefAccess srcAccess(i);
    for (auto *j : opsB) {
      MemRefAccess dstAccess(j);
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) ||
          (isBackwardSlice && loopDepth > getNestingDepth(j))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
        return SliceComputationResult::GenericFailure;
      }

      bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
                              isa<AffineReadOpInterface>(dstAccess.opInst);
      FlatAffineValueConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
        return SliceComputationResult::GenericFailure;
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      dependentOpPairs.emplace_back(i, j);

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::affine::getComputationSliceState(i, j, &dependenceConstraints,
                                             loopDepth, isBackwardSlice,
                                             &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous
        // step.
        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints\n");
          return SliceComputationResult::GenericFailure;
        }
        assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineValueConstraints tmpSliceCst;
      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints\n");
        return SliceComputationResult::GenericFailure;
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if
      // needed.
      if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {

        // Pre-constraint var alignment: record loop IVs used in each
        // constraint system.
        SmallPtrSet<Value, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
          sliceUnionIVs.insert(sliceUnionCst.getValue(k));
        SmallPtrSet<Value, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
          tmpSliceIVs.insert(tmpSliceCst.getValue(k));

        sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);

        // Post-constraint var alignment: add loop IV bounds missing after
        // var alignment to constraint systems. This can occur if one
        // constraint system uses a loop IV that is not used by the other. The
        // call to unionBoundingBox below expects constraints for each loop
        // IV, even if they are the unsliced full loop bounds added here.
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return SliceComputationResult::GenericFailure;
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return SliceComputationResult::GenericFailure;
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      if (sliceUnionCst.getNumLocalVars() > 0 ||
          tmpSliceCst.getNumLocalVars() > 0 ||
          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return SliceComputationResult::GenericFailure;
      }
    }
  }

  // Empty union.
  if (sliceUnionCst.getNumDimAndSymbolVars() == 0)
    return SliceComputationResult::GenericFailure;

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return SliceComputationResult::GenericFailure;
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();

  // Convert any dst loop IVs which are symbol variables to dim variables.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getValues(numSliceLoopIVs,
                          sliceUnionCst.getNumDimAndSymbolVars(),
                          &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth'.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Check if the slice computed is valid. Return success only if it is
  // verified that the slice is valid, otherwise return appropriate failure
  // status.
  std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
  if (!isSliceValid) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
    return SliceComputationResult::GenericFailure;
  }
  if (!*isSliceValid)
    return SliceComputationResult::IncorrectSliceFailure;

  return SliceComputationResult::Success;
}

// TODO: extend this to handle multiple result maps.
static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
                                                  AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
  if (!cExpr)
    return std::nullopt;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
// loop nest represented by the slice loop bounds in 'slice'. Returns true on
// success, false otherwise (if a non-constant trip count was encountered).
// TODO: Make this work with non-unit step loops.
bool mlir::affine::buildSliceTripCountMap(
    const ComputationSliceState &slice,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
  unsigned numSrcLoopIVs = slice.ivs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
    auto *op = forOp.getOperation();
    AffineMap lbMap = slice.lbs[i];
    AffineMap ubMap = slice.ubs[i];
    // If lower or upper bound maps are null or provide no results, it implies
    // that the source loop was not at all sliced, and the entire loop will be
    // a part of the slice.
    if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
        ubMap.getNumResults() == 0) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop
      // bounds.
      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
        (*tripCountMap)[op] =
            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
        continue;
      }
      std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (maybeConstTripCount.has_value()) {
        (*tripCountMap)[op] = *maybeConstTripCount;
        continue;
      }
      return false;
    }
    std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    // Slice bounds are created with a constant ub - lb difference.
    if (!tripCount.has_value())
      return false;
    (*tripCountMap)[op] = *tripCount;
  }
  return true;
}

// Return the number of iterations in the given slice.
uint64_t mlir::affine::getSliceIterationCount(
    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes
// slice bounds in 'sliceState' which represent the one loop nest's IVs in
// terms of the other loop nest's IVs, symbols and constants (using
// 'isBackwardsSlice').
void mlir::affine::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes
// slice bounds in 'sliceState' which represent the one loop nest's IVs in
// terms of the other loop nest's IVs, symbols and constants (using
// 'isBackwardSlice').
void mlir::affine::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
                                   &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolVars();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs)
      sliceBoundOperands.push_back(dependenceConstraints->getValue(i));
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set destination loop nest insertion point to block start at
  // 'dstLoopDepth'.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(depSourceOp) &&
      isa<AffineReadOpInterface>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  auto isInnermostInsertion = [&]() {
    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                            : loopDepth >= dstLoopIVs.size());
  };
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
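
  // A loop can opt out of slice bound propagation via the
  // 'slice_fusion_barrier' attribute. For example (illustrative IR):
  //
  //   affine.for %i = 0 to 128 {
  //     ...
  //   } {slice_fusion_barrier = true}
  //
  // The loop below clears sliced bounds from the first such loop (or the
  // first sequential loop) onwards.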
  // Clear all sliced loop bounds beginning at the first sequential loop or
  // the first loop with a slice fusion barrier attribute.
  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    // Skip resetting the bounds of a reduction loop inserted in the
    // destination loop nest that meets all of the following conditions:
    // 1. The slice has a single trip count.
    // 2. The loop bounds of the source and destination match.
    // 3. The slice is being inserted at the innermost insertion point.
    std::optional<bool> isMaximal = sliceState->isMaximal();
    if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
        isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}

/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO: extend the slicing utility to compute slices that aren't necessarily a
// one-to-one relation between the source and destination. The relation between
// the source and destination could be many-to-many in general.
// TODO: the slice computation is incorrect in the cases where the dependence
// from the source to the destination does not cover the entire destination
// index set. Subtract out the dependent destination iterations from the
// destination index set and check for emptiness --- this is one solution.
AffineForOp mlir::affine::insertBackwardComputationSlice(
    Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
    ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone src loop nest and insert it at the beginning of the operation block
  // of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getAffineForIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}
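
// For example (illustrative): with a source nest 'affine.for %i' writing %A
// and a destination nest 'affine.for %j' reading %A, inserting a backward
// slice at 'dstLoopDepth' = 1 clones the source nest to the top of the %j
// loop's body and then narrows the cloned loop's bounds using maps such as
// (d0) -> (d0) and (d0) -> (d0 + 1) from 'sliceState'.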
// Constructs a MemRefAccess, populating it with the memref, its indices, and
// the op from 'loadOrStoreOpInst'.
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
    memref = loadOp.getMemRef();
    opInst = loadOrStoreOpInst;
    llvm::append_range(indices, loadOp.getMapOperands());
  } else {
    assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
           "Affine read/write op expected");
    auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
    opInst = loadOrStoreOpInst;
    memref = storeOp.getMemRef();
    llvm::append_range(indices, storeOp.getMapOperands());
  }
}

unsigned MemRefAccess::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(opInst);
}

/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
unsigned mlir::affine::getNestingDepth(Operation *op) {
  Operation *currOp = op;
  unsigned depth = 0;
  while ((currOp = currOp->getParentOp())) {
    if (isa<AffineForOp>(currOp))
      depth++;
  }
  return depth;
}

/// Equal if both affine accesses are provably equivalent (at compile
/// time) when considering the memref, the affine maps and their respective
/// operands. The equality of access functions + operands is checked by
/// subtracting fully composed value maps, and then simplifying the difference
/// using the expression flattener.
/// TODO: this does not account for aliasing of memrefs.
bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
  if (memref != rhs.memref)
    return false;

  AffineValueMap diff, thisMap, rhsMap;
  getAccessMap(&thisMap);
  rhs.getAccessMap(&rhsMap);
  AffineValueMap::difference(thisMap, rhsMap, &diff);
  return llvm::all_of(diff.getAffineMap().getResults(),
                      [](AffineExpr e) { return e == 0; });
}

void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
  Operation *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' and
  // 'affine.parallel' operations while skipping over 'affine.if' operations.
  while (currOp) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      ivs.push_back(currAffineForOp.getInductionVar());
    else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
      llvm::append_range(ivs, parOp.getIVs());
    currOp = currOp->getParentOp();
  }
  std::reverse(ivs.begin(), ivs.end());
}

/// Returns the number of surrounding loops common to both ops, where each op
/// lists its surrounding loops from outermost to innermost.
unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
                                                    Operation &b) {
  SmallVector<Value, 4> loopsA, loopsB;
  getAffineIVs(a, loopsA);
  getAffineIVs(b, loopsB);
  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i] != loopsB[i])
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}
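
// For example (illustrative): for two ops A and B nested as
//   affine.for %i ... {
//     affine.for %j ... { <op A> }
//     affine.for %k ... { <op B> }
//   }
// getNumCommonSurroundingLoops(A, B) returns 1, since only the %i loop is
// shared.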
static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                      Block::iterator start,
                                                      Block::iterator end,
                                                      int memorySpace) {
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this 'affine.for' operation to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
      // Neither load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(&*block.begin())))) {
      return opInst->emitError("error obtaining memory region\n");
    }

    auto it = regions.find(region->memref);
    if (it == regions.end()) {
      regions[region->memref] = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return std::nullopt;

  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    std::optional<int64_t> size = region.second->getRegionSize();
    if (!size.has_value())
      return std::nullopt;
    totalSizeInBytes += *size;
  }
  return totalSizeInBytes;
}

std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
                                                             int memorySpace) {
  auto *forInst = forOp.getOperation();
  return ::getMemoryFootprintBytes(
      *forInst->getBlock(), Block::iterator(forInst),
      std::next(Block::iterator(forInst)), memorySpace);
}

/// Returns whether a loop is parallel and contains a reduction loop.
bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
  SmallVector<LoopReduction> reductions;
  if (!isLoopParallel(forOp, &reductions))
    return false;
  return !reductions.empty();
}

/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void mlir::affine::getSequentialLoops(
    AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
  forOp->walk([&](Operation *op) {
    if (auto innerFor = dyn_cast<AffineForOp>(op))
      if (!isLoopParallel(innerFor))
        sequentialLoops->insert(innerFor.getInductionVar());
  });
}

IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
  FlatAffineValueConstraints fac(set);
  if (fac.isEmpty())
    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
                                   set.getContext());
  fac.removeTrivialRedundancy();

  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
  assert(simplifiedSet && "guaranteed to succeed while roundtripping");
  return simplifiedSet;
}

static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
                                 SmallVector<Value> &target) {
  target =
      llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
        return val.has_value() ? *val : Value();
      }));
}

/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
/// constraints drawn from an affine map. Before adding the constraint, the
/// dimensions/symbols of the affine map are aligned with `constraints`.
/// `operands` are the SSA Value operands used with the affine map.
/// Note: This function adds a new symbol column to the `constraints` for each
/// dimension/symbol that exists in the affine map but not in `constraints`.
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
                                      BoundType type, unsigned pos,
                                      AffineMap map, ValueRange operands) {
  SmallVector<Value> dims, syms, newSyms;
  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);
  AffineMap alignedMap =
      alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
    constraints.appendSymbolVar(newSyms[i]);
  return constraints.addBound(type, pos, alignedMap);
}
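
// For example (illustrative): adding a bound drawn from the map
// (d0)[s0] -> (d0 + s0) where the Value bound to s0 is not yet tracked by
// `constraints` appends one new symbol column for it before the bound is
// added.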
/// Add `val` to each result of `map`.
static AffineMap addConstToResults(AffineMap map, int64_t val) {
  SmallVector<AffineExpr> newResults;
  for (AffineExpr r : map.getResults())
    newResults.push_back(r + val);
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
                        map.getContext());
}

// Attempt to simplify the given min/max operation by proving that its value is
// bounded by the same lower and upper bound.
//
// Bounds are computed by FlatAffineValueConstraints. Invariants required for
// finding/proving bounds should be supplied via `constraints`.
//
// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
//    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
//    `op` but are not part of `constraints` are added as extra symbols.
// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
//    * If `isMin`: r_i >= opBound
//    * If `isMax`: r_i <= opBound
//    If this is the case, ub(op) == lb(op).
// 4. Replace `op` with `opBound`.
//
// In summary, the following constraints are added throughout this function.
// Note: `invar` are dimensions added by the caller to express the invariants.
// (Showing only the case where `isMin`.)
//
//  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (various eq./ineq. constraining `invar`, added by the caller)
//    ... |     0 |       0 |   0 |             0 |   ... |               ...
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (various ineq. constraining `op` in terms of `op` operands (`invar` and
//    extra `op` operands "extra syms" that are not in `invar`)).
//    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
//    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (for each `op` map result r_i: set r_i to corresponding map result,
//    prove that r_i >= minOpUb via contradiction)
//    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
//      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
//
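// For example (illustrative): inside 'affine.for %i = 0 to 128 step 16',
// given caller-supplied constraints encoding 0 <= %i <= 112, the op
// 'affine.min affine_map<(d0) -> (128 - d0, 16)>(%i)' is provably equal to
// 16: both map results are bounded below by the op's upper bound, so
// lb(op) == ub(op) == 16 and the op can be replaced by that bound.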
FailureOr<AffineValueMap>
mlir::affine::simplifyConstrainedMinMaxOp(Operation *op,
                                          FlatAffineValueConstraints constraints) {
  bool isMin = isa<AffineMinOp>(op);
  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
  MLIRContext *ctx = op->getContext();
  Builder builder(ctx);
  AffineMap map =
      isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
  ValueRange operands = op->getOperands();
  unsigned numResults = map.getNumResults();

  // Add a few extra dimensions.
  unsigned dimOp = constraints.appendDimVar();      // `op`
  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);

  // Add an inequality for each result expr_i of map:
  // isMin: op <= expr_i, !isMin: op >= expr_i
  auto boundType = isMin ? BoundType::UB : BoundType::LB;
  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
  if (failed(
          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
    return failure();

  // Try to compute a lower/upper bound for op, expressed in terms of the other
  // `dims` and extra symbols.
  SmallVector<AffineMap> opLb(1), opUb(1);
  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
  // a TODO of `getSliceBounds` and not handled here.
  if (!sliceBound || sliceBound.getNumResults() != 1)
    return failure(); // No or multiple bounds found.
  // Recover the inclusive UB in the case of an `affine.min`.
  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;

  // Add an equality: Set dimOpBound to computed bound.
  // Add back dimension for op. (Was removed by `getSliceBounds`.)
  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
  if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
    return failure();

  // If the constraint system is empty, there is an inconsistency. (E.g., this
  // can happen if loop lb > ub.)
  if (constraints.isEmpty())
    return failure();

  // In the case of `isMin` (`!isMin` is inverted):
  // Prove that each result of `map` has a lower bound that is equal to (or
  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
  // can be replaced with the bound. I.e., prove that for each result
  // expr_i (represented by dimension r_i):
  //
  // r_i >= opBound
  //
  // To prove this inequality, add its negation to the constraint set and prove
  // that the constraint set is empty.
  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
    FlatAffineValueConstraints newConstr(constraints);

    // Add an equality: r_i = expr_i
    // Note: These equalities could have been added earlier and used to express
    // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
    // computes minOpUb in terms of r_i dims, which is not desired.
    if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
                                map.getSubMap({i - resultDimStart}), operands)))
      return failure();

    // If `isMin`:  Add inequality: r_i < opBound
    //              equiv.: opBound - r_i - 1 >= 0
    // If `!isMin`: Add inequality: r_i > opBound
    //              equiv.: -opBound + r_i - 1 >= 0
    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
    ineq[dimOpBound] = isMin ? 1 : -1;
    ineq[i] = isMin ? -1 : 1;
    ineq[newConstr.getNumCols() - 1] = -1;
    newConstr.addInequality(ineq);
    if (!newConstr.isEmpty())
      return failure();
  }

  // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
  AffineMap newMap = alignedBoundMap;
  SmallVector<Value> newOperands;
  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
  // If dims/symbols have known constant values, use those in order to simplify
  // the affine map further.
  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
    // Skip unused operands and operands that are already constants.
    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
      continue;
    if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
      AffineExpr expr =
          i < newMap.getNumDims()
              ? builder.getAffineDimExpr(i)
              : builder.getAffineSymbolExpr(i - newMap.getNumDims());
      newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
                              newMap.getNumDims(), newMap.getNumSymbols());
    }
  }
  affine::canonicalizeMapAndOperands(&newMap, &newOperands);
  return AffineValueMap(newMap, newOperands);
}
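
// Illustrative usage sketch (names hypothetical): a pass would first populate
// a FlatAffineValueConstraints `cstr` with the invariants of the enclosing
// loops (e.g., IV ranges), then call
//   FailureOr<AffineValueMap> simplified =
//       affine::simplifyConstrainedMinMaxOp(minOp, cstr);
// and, on success, rewrite the min/max op's uses with the returned map and
// operands.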