//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

#define DEBUG_TYPE "analysis-utils"

using namespace mlir;
using namespace affine;
using namespace presburger;

using llvm::SmallDenseMap;

using Node = MemRefDependenceGraph::Node;

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not a region-holding op other than ForOp and
// IfOp was encountered in the loop nest.
void LoopNestStateCollector::collect(Operation *opToWalk) {
  opToWalk->walk([&](Operation *op) {
    if (isa<AffineForOp>(op))
      forOps.push_back(cast<AffineForOp>(op));
    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
      hasNonAffineRegionOp = true;
    else if (isa<AffineReadOpInterface>(op))
      loadOpInsts.push_back(op);
    else if (isa<AffineWriteOpInterface>(op))
      storeOpInsts.push_back(op);
  });
}
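
// As an illustrative sketch, collecting over the nest below places both
// affine.for ops in 'forOps', the affine.load in 'loadOpInsts', and the
// affine.store in 'storeOpInsts', while 'hasNonAffineRegionOp' stays false:
//
//   affine.for %i = 0 to 10 {
//     affine.for %j = 0 to 10 {
//       %v = affine.load %A[%i, %j] : memref<10x10xf32>
//       affine.store %v, %B[%j, %i] : memref<10x10xf32>
//     }
//   }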

// Returns the load op count for 'memref'.
unsigned Node::getLoadOpCount(Value memref) const {
  unsigned loadOpCount = 0;
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      ++loadOpCount;
  }
  return loadOpCount;
}

// Returns the store op count for 'memref'.
unsigned Node::getStoreOpCount(Value memref) const {
  unsigned storeOpCount = 0;
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      ++storeOpCount;
  }
  return storeOpCount;
}

// Collects in 'storeOps' all store ops of this node which access 'memref'.
void Node::getStoreOpsForMemref(Value memref,
                                SmallVectorImpl<Operation *> *storeOps) const {
  for (Operation *storeOp : stores) {
    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
      storeOps->push_back(storeOp);
  }
}

// Collects in 'loadOps' all load ops of this node which access 'memref'.
void Node::getLoadOpsForMemref(Value memref,
                               SmallVectorImpl<Operation *> *loadOps) const {
  for (Operation *loadOp : loads) {
    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
      loadOps->push_back(loadOp);
  }
}

// Collects in 'loadAndStoreMemrefSet' all memrefs for which this node
// has at least one load and one store operation.
void Node::getLoadAndStoreMemrefSet(
    DenseSet<Value> *loadAndStoreMemrefSet) const {
  llvm::SmallDenseSet<Value, 2> loadMemrefs;
  for (Operation *loadOp : loads) {
    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
  }
  for (Operation *storeOp : stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
    if (loadMemrefs.count(memref) > 0)
      loadAndStoreMemrefSet->insert(memref);
  }
}

// Initializes the data dependence graph by walking operations in `block`.
// Assigns each node in the graph a node id based on program order in `block`.
bool MemRefDependenceGraph::init() {
  LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
  // Map from a memref to the set of ids of the nodes that have ops accessing
  // the memref.
  DenseMap<Value, SetVector<unsigned>> memrefAccesses;

  DenseMap<Operation *, unsigned> forToNodeMap;
  for (Operation &op : block) {
    if (dyn_cast<AffineForOp>(op)) {
      // Create graph node 'id' to represent top-level 'forOp' and record
      // all loads and store accesses it contains.
      LoopNestStateCollector collector;
      collector.collect(&op);
      // Return false if a region-holding op other than 'affine.for' and
      // 'affine.if' was found (not currently supported).
      if (collector.hasNonAffineRegionOp)
        return false;
      Node node(nextNodeId++, &op);
      for (auto *opInst : collector.loadOpInsts) {
        node.loads.push_back(opInst);
        auto memref = cast<AffineReadOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      for (auto *opInst : collector.storeOpInsts) {
        node.stores.push_back(opInst);
        auto memref = cast<AffineWriteOpInterface>(opInst).getMemRef();
        memrefAccesses[memref].insert(node.id);
      }
      forToNodeMap[&op] = node.id;
      nodes.insert({node.id, node});
    } else if (dyn_cast<AffineReadOpInterface>(op)) {
      // Create graph node for top-level load op.
      Node node(nextNodeId++, &op);
      node.loads.push_back(&op);
      auto memref = cast<AffineReadOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (dyn_cast<AffineWriteOpInterface>(op)) {
      // Create graph node for top-level store op.
      Node node(nextNodeId++, &op);
      node.stores.push_back(&op);
      auto memref = cast<AffineWriteOpInterface>(op).getMemRef();
      memrefAccesses[memref].insert(node.id);
      nodes.insert({node.id, node});
    } else if (op.getNumResults() > 0 && !op.use_empty()) {
      // Create graph node for top-level producer of SSA values, which
      // could be used by loop nest nodes.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (!isMemoryEffectFree(&op) &&
               (op.getNumRegions() == 0 || isa<RegionBranchOpInterface>(op))) {
      // Create graph node for top-level op unless it is known to be
      // memory-effect free. This covers all unknown/unregistered ops,
      // non-affine ops with memory effects, and region-holding ops with a
      // well-defined control flow. During the fusion validity checks, we look
      // for non-affine ops on the path from source to destination, at which
      // point we check which memrefs, if any, are used in the region.
      Node node(nextNodeId++, &op);
      nodes.insert({node.id, node});
    } else if (op.getNumRegions() != 0) {
      // Return false if non-handled/unknown region-holding ops are found. We
      // won't know what such ops do or what their regions mean; e.g., it may
      // not be an imperative op.
      LLVM_DEBUG(llvm::dbgs()
                 << "MDG init failed; unknown region-holding op found!\n");
      return false;
    }
  }

  for (auto &idAndNode : nodes) {
    LLVM_DEBUG(llvm::dbgs() << "Create node " << idAndNode.first << " for:\n"
                            << *(idAndNode.second.op) << "\n");
    (void)idAndNode;
  }

  // Add dependence edges between nodes which produce SSA values and their
  // users. Load ops can be considered as the ones producing SSA values.
  for (auto &idAndNode : nodes) {
    const Node &node = idAndNode.second;
    // Stores don't define SSA values, skip them.
    if (!node.stores.empty())
      continue;
    Operation *opInst = node.op;
    for (Value value : opInst->getResults()) {
      for (Operation *user : value.getUsers()) {
        // Ignore users outside of the block.
        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
            &block)
          continue;
        SmallVector<AffineForOp, 4> loops;
        getAffineForIVs(*user, &loops);
        // Find the surrounding affine.for nested immediately within the
        // block.
        auto *it = llvm::find_if(loops, [&](AffineForOp loop) {
          return loop->getBlock() == &block;
        });
        if (it == loops.end())
          continue;
        assert(forToNodeMap.count(*it) > 0 && "missing mapping");
        unsigned userLoopNestId = forToNodeMap[*it];
        addEdge(node.id, userLoopNestId, value);
      }
    }
  }

  // Walk memref access lists and add graph edges between dependent nodes.
  for (auto &memrefAndList : memrefAccesses) {
    unsigned n = memrefAndList.second.size();
    for (unsigned i = 0; i < n; ++i) {
      unsigned srcId = memrefAndList.second[i];
      bool srcHasStore =
          getNode(srcId)->getStoreOpCount(memrefAndList.first) > 0;
      for (unsigned j = i + 1; j < n; ++j) {
        unsigned dstId = memrefAndList.second[j];
        bool dstHasStore =
            getNode(dstId)->getStoreOpCount(memrefAndList.first) > 0;
        if (srcHasStore || dstHasStore)
          addEdge(srcId, dstId, memrefAndList.first);
      }
    }
  }
  return true;
}
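
// Illustrative sketch (SSA values other than %A elided) of the graph built
// by init() for a top-level block such as:
//
//   %A = memref.alloc() : memref<8xf32>   // node 0: SSA value producer
//   affine.for %i = 0 to 8 {              // node 1: stores to %A
//     affine.store %v, %A[%i] : memref<8xf32>
//   }
//   affine.for %j = 0 to 8 {              // node 2: loads from %A
//     %u = affine.load %A[%j] : memref<8xf32>
//   }
//
// Def-use edges 0 -> 1 and 0 -> 2 are added on the SSA value %A, and a
// memref dependence edge 1 -> 2 is added on %A since node 1 stores to it.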

// Returns the graph node for 'id'.
Node *MemRefDependenceGraph::getNode(unsigned id) {
  auto it = nodes.find(id);
  assert(it != nodes.end());
  return &it->second;
}

// Returns the graph node for 'forOp'.
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
  for (auto &idAndNode : nodes)
    if (idAndNode.second.op == forOp)
      return &idAndNode.second;
  return nullptr;
}

// Adds a node with 'op' to the graph and returns its unique identifier.
unsigned MemRefDependenceGraph::addNode(Operation *op) {
  Node node(nextNodeId++, op);
  nodes.insert({node.id, node});
  return node.id;
}

// Remove node 'id' (and its associated edges) from graph.
void MemRefDependenceGraph::removeNode(unsigned id) {
  // Remove each edge in 'inEdges[id]'.
  if (inEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[id];
    for (auto &inEdge : oldInEdges) {
      removeEdge(inEdge.id, id, inEdge.value);
    }
  }
  // Remove each edge in 'outEdges[id]'.
  if (outEdges.count(id) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
    for (auto &outEdge : oldOutEdges) {
      removeEdge(id, outEdge.id, outEdge.value);
    }
  }
  // Erase remaining node state.
  inEdges.erase(id);
  outEdges.erase(id);
  nodes.erase(id);
}

// Returns true if node 'id' writes to any memref which escapes (or is an
// argument to) the block. Returns false otherwise.
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
  Node *node = getNode(id);
  for (auto *storeOpInst : node->stores) {
    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
    auto *op = memref.getDefiningOp();
    // Return true if 'memref' is a block argument.
    if (!op)
      return true;
    // Return true if any use of 'memref' does not dereference it in an
    // affine way.
    for (auto *user : memref.getUsers())
      if (!isa<AffineMapAccessInterface>(*user))
        return true;
  }
  return false;
}

// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
// is for 'value' if non-null, or for any value otherwise. Returns false
// otherwise.
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
    return false;
  }
  bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
    return edge.id == dstId && (!value || edge.value == value);
  });
  bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
    return edge.id == srcId && (!value || edge.value == value);
  });
  return hasOutEdge && hasInEdge;
}

// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
                                    Value value) {
  if (!hasEdge(srcId, dstId, value)) {
    outEdges[srcId].push_back({dstId, value});
    inEdges[dstId].push_back({srcId, value});
    if (isa<MemRefType>(value.getType()))
      memrefEdgeCount[value]++;
  }
}

// Removes an edge from node 'srcId' to node 'dstId' for 'value'.
void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
                                       Value value) {
  assert(inEdges.count(dstId) > 0);
  assert(outEdges.count(srcId) > 0);
  if (isa<MemRefType>(value.getType())) {
    assert(memrefEdgeCount.count(value) > 0);
    memrefEdgeCount[value]--;
  }
  // Remove 'srcId' from 'inEdges[dstId]'.
  for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
    if ((*it).id == srcId && (*it).value == value) {
      inEdges[dstId].erase(it);
      break;
    }
  }
  // Remove 'dstId' from 'outEdges[srcId]'.
  for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) {
    if ((*it).id == dstId && (*it).value == value) {
      outEdges[srcId].erase(it);
      break;
    }
  }
}

// Returns true if there is a path in the dependence graph from node 'srcId'
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
// operations that the edges connect are expected to be from the same block.
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
  // Worklist state is: <node-id, next-output-edge-index-to-visit>
  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
  worklist.push_back({srcId, 0});
  Operation *dstOp = getNode(dstId)->op;
  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
  while (!worklist.empty()) {
    auto &idAndIndex = worklist.back();
    // Return true if we have reached 'dstId'.
    if (idAndIndex.first == dstId)
      return true;
    // Pop and continue if node has no out edges, or if all out edges have
    // already been visited.
    if (outEdges.count(idAndIndex.first) == 0 ||
        idAndIndex.second == outEdges[idAndIndex.first].size()) {
      worklist.pop_back();
      continue;
    }
    // Get graph edge to traverse.
    Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
    // Increment next output edge index for 'idAndIndex'.
    ++idAndIndex.second;
    // Add node at 'edge.id' to the worklist. We don't need to consider
    // nodes that are "after" dstId in the containing block; one can't have a
    // path to `dstId` from any of those nodes.
    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
    if (!afterDst && edge.id != idAndIndex.first)
      worklist.push_back({edge.id, 0});
  }
  return false;
}
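
// For example (hypothetical node ids): with edges 0 -> 1 and 1 -> 2,
// hasDependencePath(0, 2) returns true after the worklist visits 0, 1, and
// 2, while hasDependencePath(2, 0) returns false; nodes whose ops come after
// node 2's op in the block are never pushed, which prunes the search.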

// Returns the input edge count for node 'id' and 'memref' from src nodes
// which access 'memref' with a store operation.
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
                                                          Value memref) {
  unsigned inEdgeCount = 0;
  if (inEdges.count(id) > 0)
    for (auto &inEdge : inEdges[id])
      if (inEdge.value == memref) {
        Node *srcNode = getNode(inEdge.id);
        // Only count in edges from 'srcNode' if 'srcNode' stores to 'memref'.
        if (srcNode->getStoreOpCount(memref) > 0)
          ++inEdgeCount;
      }
  return inEdgeCount;
}

// Returns the output edge count for node 'id' and 'memref' (if non-null),
// otherwise returns the total output edge count from node 'id'.
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
  unsigned outEdgeCount = 0;
  if (outEdges.count(id) > 0)
    for (auto &outEdge : outEdges[id])
      if (!memref || outEdge.value == memref)
        ++outEdgeCount;
  return outEdgeCount;
}

/// Return all nodes which define SSA values used in node 'id'.
void MemRefDependenceGraph::gatherDefiningNodes(
    unsigned id, DenseSet<unsigned> &definingNodes) {
  for (MemRefDependenceGraph::Edge edge : inEdges[id])
    // By definition of edge, if the edge value is a non-memref value,
    // then the dependence is between a graph node which defines an SSA value
    // and another graph node which uses the SSA value.
    if (!isa<MemRefType>(edge.value.getType()))
      definingNodes.insert(edge.id);
}

// Computes and returns an insertion point operation, before which the
// fused <srcId, dstId> loop nest can be inserted while preserving
// dependences. Returns nullptr if no such insertion point is found.
Operation *
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
                                                      unsigned dstId) {
  if (outEdges.count(srcId) == 0)
    return getNode(dstId)->op;

  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
  DenseSet<unsigned> definingNodes;
  gatherDefiningNodes(dstId, definingNodes);
  if (llvm::any_of(definingNodes,
                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
    LLVM_DEBUG(llvm::dbgs()
               << "Can't fuse: a defining op with a user in the dst "
                  "loop has dependence from the src loop\n");
    return nullptr;
  }

  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
  SmallPtrSet<Operation *, 2> srcDepInsts;
  for (auto &outEdge : outEdges[srcId])
    if (outEdge.id != dstId)
      srcDepInsts.insert(getNode(outEdge.id)->op);

  // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
  SmallPtrSet<Operation *, 2> dstDepInsts;
  for (auto &inEdge : inEdges[dstId])
    if (inEdge.id != srcId)
      dstDepInsts.insert(getNode(inEdge.id)->op);

  Operation *srcNodeInst = getNode(srcId)->op;
  Operation *dstNodeInst = getNode(dstId)->op;

  // Computing insertion point:
  // *) Walk all operation positions in Block operation list in the
  //    range (src, dst). For each operation 'op' visited in this search:
  //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
  //      dependence edge from 'srcNode'.
  //   *) Store in 'lastDstDepPos' the last position where 'op' has a
  //      dependence edge to 'dstNode'.
  // *) Compare 'firstSrcDepPos' and 'lastDstDepPos' to determine the
  //    operation insertion point (or return null pointer if no such
  //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
  SmallVector<Operation *, 2> depInsts;
  std::optional<unsigned> firstSrcDepPos;
  std::optional<unsigned> lastDstDepPos;
  unsigned pos = 0;
  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
       it != Block::iterator(dstNodeInst); ++it) {
    Operation *op = &(*it);
    if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
      firstSrcDepPos = pos;
    if (dstDepInsts.count(op) > 0)
      lastDstDepPos = pos;
    depInsts.push_back(op);
    ++pos;
  }

  if (firstSrcDepPos.has_value()) {
    if (lastDstDepPos.has_value()) {
      if (*firstSrcDepPos <= *lastDstDepPos) {
        // No valid insertion point exists which preserves dependences.
        return nullptr;
      }
    }
    // Return the insertion point at 'firstSrcDepPos'.
    return depInsts[*firstSrcDepPos];
  }
  // No dependence targets in range (or only dst deps in range), return
  // 'dstNodeInst' insertion point.
  return dstNodeInst;
}
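
// Hypothetical illustration: if the block contains, in order, the src loop,
// op A (which depends on src), op B (on which dst depends), and the dst
// loop, then firstSrcDepPos = 0 (A) and lastDstDepPos = 1 (B); since
// firstSrcDepPos <= lastDstDepPos, no insertion point preserves both
// dependences and nullptr is returned. With A and B swapped, the fused nest
// can be inserted just before A, the op at firstSrcDepPos.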

// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
// taking into account that:
// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
// *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
//    private memref.
void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
                                        const DenseSet<Value> &privateMemRefs,
                                        bool removeSrcId) {
  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
  if (inEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
    for (auto &inEdge : oldInEdges) {
      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
      if (privateMemRefs.count(inEdge.value) == 0)
        addEdge(inEdge.id, dstId, inEdge.value);
    }
  }
  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
  if (outEdges.count(srcId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
    for (auto &outEdge : oldOutEdges) {
      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
      if (outEdge.id == dstId)
        removeEdge(srcId, outEdge.id, outEdge.value);
      else if (removeSrcId) {
        addEdge(dstId, outEdge.id, outEdge.value);
        removeEdge(srcId, outEdge.id, outEdge.value);
      }
    }
  }
  // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
  // replaced by a private memref). These edges could come from nodes
  // other than 'srcId' which were removed in the previous step.
  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
    for (auto &inEdge : oldInEdges)
      if (privateMemRefs.count(inEdge.value) > 0)
        removeEdge(inEdge.id, dstId, inEdge.value);
  }
}

// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
// of sibling node 'sibId' into node 'dstId'.
void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
  // For each edge in 'inEdges[sibId]':
  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
  if (inEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
    for (auto &inEdge : oldInEdges) {
      addEdge(inEdge.id, dstId, inEdge.value);
      removeEdge(inEdge.id, sibId, inEdge.value);
    }
  }

  // For each edge in 'outEdges[sibId]':
  // *) Add new edge from 'dstId' to 'outEdge.id'.
  // *) Remove edge from 'sibId' to 'outEdge.id'.
  if (outEdges.count(sibId) > 0) {
    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
    for (auto &outEdge : oldOutEdges) {
      addEdge(dstId, outEdge.id, outEdge.value);
      removeEdge(sibId, outEdge.id, outEdge.value);
    }
  }
}

// Adds ops in 'loads' and 'stores' to node at 'id'.
void MemRefDependenceGraph::addToNode(
    unsigned id, const SmallVectorImpl<Operation *> &loads,
    const SmallVectorImpl<Operation *> &stores) {
  Node *node = getNode(id);
  llvm::append_range(node->loads, loads);
  llvm::append_range(node->stores, stores);
}

void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
  Node *node = getNode(id);
  node->loads.clear();
  node->stores.clear();
}

// Calls 'callback' for each input edge incident to node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefInputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (inEdges.count(id) > 0)
    forEachMemRefEdge(inEdges[id], callback);
}

// Calls 'callback' for each output edge from node 'id' which carries a
// memref dependence.
void MemRefDependenceGraph::forEachMemRefOutputEdge(
    unsigned id, const std::function<void(Edge)> &callback) {
  if (outEdges.count(id) > 0)
    forEachMemRefEdge(outEdges[id], callback);
}

// Calls 'callback' for each edge in 'edges' which carries a memref
// dependence.
void MemRefDependenceGraph::forEachMemRefEdge(
    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
  for (const auto &edge : edges) {
    // Skip if 'edge' is not a memref dependence edge.
    if (!isa<MemRefType>(edge.value.getType()))
      continue;
    assert(nodes.count(edge.id) > 0);
    // Skip if 'edge.id' is not a loop nest.
    if (!isa<AffineForOp>(getNode(edge.id)->op))
      continue;
    // Visit current input edge 'edge'.
    callback(edge);
  }
}

void MemRefDependenceGraph::print(raw_ostream &os) const {
  os << "\nMemRefDependenceGraph\n";
  os << "\nNodes:\n";
  for (const auto &idAndNode : nodes) {
    os << "Node: " << idAndNode.first << "\n";
    auto it = inEdges.find(idAndNode.first);
    if (it != inEdges.end()) {
      for (const auto &e : it->second)
        os << "  InEdge: " << e.id << " " << e.value << "\n";
    }
    it = outEdges.find(idAndNode.first);
    if (it != outEdges.end()) {
      for (const auto &e : it->second)
        os << "  OutEdge: " << e.id << " " << e.value << "\n";
    }
  }
}

void mlir::affine::getAffineForIVs(Operation &op,
                                   SmallVectorImpl<AffineForOp> *loops) {
  auto *currOp = op.getParentOp();
  // Traverse up the hierarchy collecting all 'affine.for' operations while
  // skipping over 'affine.if' operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (auto currAffineForOp = dyn_cast<AffineForOp>(currOp))
      loops->push_back(currAffineForOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(loops->begin(), loops->end());
}
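
// For instance, for an op nested as below, 'loops' is populated with the %i
// and %j loops (outermost first), skipping the affine.if and stopping at the
// closest enclosing AffineScope op (illustrative):
//
//   affine.for %i = 0 to 10 {
//     affine.if #set(%i) {
//       affine.for %j = 0 to 10 {
//         <op>
//       }
//     }
//   }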

void mlir::affine::getEnclosingAffineOps(Operation &op,
                                         SmallVectorImpl<Operation *> *ops) {
  ops->clear();
  Operation *currOp = op.getParentOp();

  // Traverse up the hierarchy collecting all `affine.for`, `affine.if`, and
  // `affine.parallel` operations.
  while (currOp && !currOp->hasTrait<OpTrait::AffineScope>()) {
    if (isa<AffineIfOp, AffineForOp, AffineParallelOp>(currOp))
      ops->push_back(currOp);
    currOp = currOp->getParentOp();
  }
  std::reverse(ops->begin(), ops->end());
}

// Populates 'cst' with FlatAffineValueConstraints which represent the
// original domain of the loop bounds that define 'ivs'.
LogicalResult ComputationSliceState::getSourceAsConstraints(
    FlatAffineValueConstraints &cst) const {
  assert(!ivs.empty() && "Cannot have a slice without its IVs");
  cst = FlatAffineValueConstraints(/*numDims=*/ivs.size(), /*numSymbols=*/0,
                                   /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(cst.addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}

// Populates 'cst' with FlatAffineValueConstraints which represent slice
// bounds.
LogicalResult
ComputationSliceState::getAsConstraints(FlatAffineValueConstraints *cst) const {
  assert(!lbOperands.empty());
  // Adds src 'ivs' as dimension variables in 'cst'.
  unsigned numDims = ivs.size();
  // Adds operands (dst ivs and symbols) as symbols in 'cst'.
  unsigned numSymbols = lbOperands[0].size();

  SmallVector<Value, 4> values(ivs);
  // Append 'ivs' then 'operands' to 'values'.
  values.append(lbOperands[0].begin(), lbOperands[0].end());
  *cst = FlatAffineValueConstraints(numDims, numSymbols, 0, values);

  // Add loop bound constraints for values which are loop IVs of the
  // destination of fusion, and equality constraints for symbols which are
  // constants.
  for (unsigned i = numDims, end = values.size(); i < end; ++i) {
    Value value = values[i];
    assert(cst->containsVar(value) && "value expected to be present");
    if (isValidSymbol(value)) {
      // Check if the symbol is a constant.
      if (std::optional<int64_t> cOp = getConstantIntValue(value))
        cst->addBound(BoundType::EQ, value, cOp.value());
    } else if (auto loop = getForInductionVarOwner(value)) {
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }

  // Add slice bounds on 'ivs' using maps 'lbs'/'ubs' with 'lbOperands[0]'.
  LogicalResult ret = cst->addSliceBounds(ivs, lbs, ubs, lbOperands[0]);
  assert(succeeded(ret) &&
         "should not fail as we never have semi-affine slice maps");
  (void)ret;
  return success();
}

// Clears state bounds and operand state.
void ComputationSliceState::clearBounds() {
  lbs.clear();
  ubs.clear();
  lbOperands.clear();
  ubOperands.clear();
}

void ComputationSliceState::dump() const {
  llvm::errs() << "\tIVs:\n";
  for (Value iv : ivs)
    llvm::errs() << "\t\t" << iv << "\n";

  llvm::errs() << "\tLBs:\n";
  for (auto en : llvm::enumerate(lbs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value lbOp : lbOperands[en.index()])
      llvm::errs() << "\t\t\t" << lbOp << "\n";
  }

  llvm::errs() << "\tUBs:\n";
  for (auto en : llvm::enumerate(ubs)) {
    llvm::errs() << "\t\t" << en.value() << "\n";
    llvm::errs() << "\t\tOperands:\n";
    for (Value ubOp : ubOperands[en.index()])
      llvm::errs() << "\t\t\t" << ubOp << "\n";
  }
}

/// Fast check to determine if the computation slice is maximal. Returns true
/// if each slice dimension maps to an existing dst dimension and both the
/// src and the dst loops for those dimensions have the same bounds. Returns
/// false if the src and the dst loops don't have the same bounds. Returns
/// std::nullopt if none of the above can be proven.
std::optional<bool> ComputationSliceState::isSliceMaximalFastCheck() const {
  assert(lbs.size() == ubs.size() && !lbs.empty() && !ivs.empty() &&
         "Unexpected number of lbs, ubs and ivs in slice");

  for (unsigned i = 0, end = lbs.size(); i < end; ++i) {
    AffineMap lbMap = lbs[i];
    AffineMap ubMap = ubs[i];

    // Check if this slice is just an equality along this dimension.
    if (!lbMap || !ubMap || lbMap.getNumResults() != 1 ||
        ubMap.getNumResults() != 1 ||
        lbMap.getResult(0) + 1 != ubMap.getResult(0) ||
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        isa<AffineConstantExpr>(lbMap.getResult(0)))
      return std::nullopt;

    // Limited support: we expect the lb result to be just a loop dimension
    // for now.
    AffineDimExpr result = dyn_cast<AffineDimExpr>(lbMap.getResult(0));
    if (!result)
      return std::nullopt;

    // Retrieve dst loop bounds.
    AffineForOp dstLoop =
        getForInductionVarOwner(lbOperands[i][result.getPosition()]);
    if (!dstLoop)
      return std::nullopt;
    AffineMap dstLbMap = dstLoop.getLowerBoundMap();
    AffineMap dstUbMap = dstLoop.getUpperBoundMap();

    // Retrieve src loop bounds.
    AffineForOp srcLoop = getForInductionVarOwner(ivs[i]);
    assert(srcLoop && "Expected affine for");
    AffineMap srcLbMap = srcLoop.getLowerBoundMap();
    AffineMap srcUbMap = srcLoop.getUpperBoundMap();

    // Limited support: we expect simple src and dst loops with a single
    // constant component per bound for now.
    if (srcLbMap.getNumResults() != 1 || srcUbMap.getNumResults() != 1 ||
        dstLbMap.getNumResults() != 1 || dstUbMap.getNumResults() != 1)
      return std::nullopt;

    AffineExpr srcLbResult = srcLbMap.getResult(0);
    AffineExpr dstLbResult = dstLbMap.getResult(0);
    AffineExpr srcUbResult = srcUbMap.getResult(0);
    AffineExpr dstUbResult = dstUbMap.getResult(0);
    if (!isa<AffineConstantExpr>(srcLbResult) ||
        !isa<AffineConstantExpr>(srcUbResult) ||
        !isa<AffineConstantExpr>(dstLbResult) ||
        !isa<AffineConstantExpr>(dstUbResult))
      return std::nullopt;

    // Check if src and dst loop bounds are the same. If not, we can guarantee
    // that the slice is not maximal.
    if (srcLbResult != dstLbResult || srcUbResult != dstUbResult ||
        srcLoop.getStep() != dstLoop.getStep())
      return false;
  }

  return true;
}
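
// Illustrative sketch: for a slice dimension with lb = (d0) -> (d0) and
// ub = (d0) -> (d0 + 1) on a dst IV %j, the slice runs exactly one src
// iteration per dst iteration; if the src loop and %j's loop both have
// constant bounds [0, 100) and the same step, this returns true. If their
// constant bounds differ, it returns false; anything more complex yields
// std::nullopt.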

/// Returns true if it is deterministically verified that the original
/// iteration space of the slice is contained within the new iteration space
/// that is created after fusing 'this' slice into its destination.
std::optional<bool> ComputationSliceState::isSliceValid() const {
  // Fast check to determine if the slice is valid. If the following
  // conditions are verified to be true, the slice is declared valid by the
  // fast check:
  // 1. Each slice loop is a single iteration loop bound in terms of a single
  //    destination loop IV.
  // 2. Loop bounds of the destination loop IV (from above) and those of the
  //    source loop IV are exactly the same.
  // If the fast check is inconclusive or false, we proceed with a more
  // expensive analysis.
  // TODO: Store the result of the fast check, as it might be used again in
  // `canRemoveSrcNodeAfterFusion`.
  std::optional<bool> isValidFastCheck = isSliceMaximalFastCheck();
  if (isValidFastCheck && *isValidFastCheck)
    return true;

  // Create constraints for the source loop nest from which the slice is
  // computed.
  FlatAffineValueConstraints srcConstraints;
  // TODO: Store the source's domain to avoid computation at each depth.
  if (failed(getSourceAsConstraints(srcConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n");
    return std::nullopt;
  }
  // As the set difference utility currently cannot handle symbols in its
  // operands, validity of the slice cannot be determined.
  if (srcConstraints.getNumSymbolVars() > 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n");
    return std::nullopt;
  }
  // TODO: Handle local vars in the source domains while using the
  // 'projectOut' utility below. Currently, aligning is not done assuming that
  // there will be no local vars in the source domain.
  if (srcConstraints.getNumLocalVars() != 0) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n");
    return std::nullopt;
  }

  // Create constraints for the slice loop nest that would be created if the
  // fusion succeeds.
  FlatAffineValueConstraints sliceConstraints;
  if (failed(getAsConstraints(&sliceConstraints))) {
    LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n");
    return std::nullopt;
  }

  // Project out every dimension other than the 'ivs' to express the slice's
  // domain completely in terms of the source's IVs.
  sliceConstraints.projectOut(ivs.size(),
                              sliceConstraints.getNumVars() - ivs.size());

  LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n");
  LLVM_DEBUG(srcConstraints.dump());
  LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds "
                             "(expressed in terms of its source's IVs):\n");
  LLVM_DEBUG(sliceConstraints.dump());

  // TODO: Store 'srcSet' to avoid recalculating for each depth.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = sliceSet.subtract(srcSet);

  if (!diffSet.isIntegerEmpty()) {
    LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n");
    return false;
  }
  return true;
}
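
// Validity is thus a containment check (slice contained in source): e.g., a
// slice constrained to 0 <= %i <= 7 inside a source domain 0 <= %i <= 9
// subtracts to an empty set and is valid, whereas a slice reaching %i = 10
// would leave a non-empty difference (illustrative).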

/// Returns true if the computation slice encloses all the iterations of the
/// sliced loop nest. Returns false if it does not. Returns std::nullopt if it
/// cannot determine if the slice is maximal or not.
std::optional<bool> ComputationSliceState::isMaximal() const {
  // Fast check to determine if the computation slice is maximal. If the
  // result is inconclusive, we proceed with a more expensive analysis.
  std::optional<bool> isMaximalFastCheck = isSliceMaximalFastCheck();
  if (isMaximalFastCheck)
    return isMaximalFastCheck;

  // Create constraints for the src loop nest being sliced.
  FlatAffineValueConstraints srcConstraints(/*numDims=*/ivs.size(),
                                            /*numSymbols=*/0,
                                            /*numLocals=*/0, ivs);
  for (Value iv : ivs) {
    AffineForOp loop = getForInductionVarOwner(iv);
    assert(loop && "Expected affine for");
    if (failed(srcConstraints.addAffineForOpDomain(loop)))
      return std::nullopt;
  }

  // Create constraints for the slice using the dst loop nest information. We
  // retrieve existing dst loops from the lbOperands.
  SmallVector<Value> consumerIVs;
  for (Value lbOp : lbOperands[0])
    if (getForInductionVarOwner(lbOp))
      consumerIVs.push_back(lbOp);

  // Add empty IV Values for those new loops that are not equalities and,
  // therefore, are not yet materialized in the IR.
  for (int i = consumerIVs.size(), end = ivs.size(); i < end; ++i)
    consumerIVs.push_back(Value());

  FlatAffineValueConstraints sliceConstraints(/*numDims=*/consumerIVs.size(),
                                              /*numSymbols=*/0,
                                              /*numLocals=*/0, consumerIVs);

  if (failed(sliceConstraints.addDomainFromSliceMaps(lbs, ubs, lbOperands[0])))
    return std::nullopt;

  if (srcConstraints.getNumDimVars() != sliceConstraints.getNumDimVars())
    // Constraint dims are different. The integer set difference can't be
    // computed so we don't know if the slice is maximal.
    return std::nullopt;

  // Compute the difference between the src loop nest and the slice integer
  // sets.
  PresburgerSet srcSet(srcConstraints);
  PresburgerSet sliceSet(sliceConstraints);
  PresburgerSet diffSet = srcSet.subtract(sliceSet);
  return diffSet.isIntegerEmpty();
}

unsigned MemRefRegion::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

std::optional<int64_t> MemRefRegion::getConstantBoundingSizeAndShape(
    SmallVectorImpl<int64_t> *shape, std::vector<SmallVector<int64_t, 4>> *lbs,
    SmallVectorImpl<int64_t> *lbDivisors) const {
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();
  if (shape)
    shape->reserve(rank);

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  // Use a copy of the region constraints that has upper/lower bounds for each
  // memref dimension with static size added to guard against potential
  // over-approximation from projection or union bounding box. We may not add
  // this on the region itself since they might just be redundant constraints
  // that will need non-trivial means to eliminate.
  FlatAffineValueConstraints cstWithShapeBounds(cst);
  for (unsigned r = 0; r < rank; r++) {
    cstWithShapeBounds.addBound(BoundType::LB, r, 0);
    int64_t dimSize = memRefType.getDimSize(r);
    if (ShapedType::isDynamic(dimSize))
      continue;
    cstWithShapeBounds.addBound(BoundType::UB, r, dimSize - 1);
  }

  // Find a constant upper bound on the extent of this memref region along
  // each dimension.
  int64_t numElements = 1;
  int64_t diffConstant;
  int64_t lbDivisor;
  for (unsigned d = 0; d < rank; d++) {
    SmallVector<int64_t, 4> lb;
    std::optional<int64_t> diff =
        cstWithShapeBounds.getConstantBoundOnDimSize64(d, &lb, &lbDivisor);
    if (diff.has_value()) {
      diffConstant = *diff;
      assert(diffConstant >= 0 && "Dim size bound can't be negative");
      assert(lbDivisor > 0);
    } else {
      // If no constant bound is found, then it can always be bound by the
      // memref's dim size if the latter has a constant size along this dim.
      auto dimSize = memRefType.getDimSize(d);
      if (dimSize == ShapedType::kDynamic)
        return std::nullopt;
      diffConstant = dimSize;
      // Lower bound becomes 0.
      lb.resize(cstWithShapeBounds.getNumSymbolVars() + 1, 0);
      lbDivisor = 1;
    }
    numElements *= diffConstant;
    if (lbs) {
      lbs->push_back(lb);
      assert(lbDivisors && "both lbs and lbDivisor or none");
      lbDivisors->push_back(lbDivisor);
    }
    if (shape) {
      shape->push_back(diffConstant);
    }
  }
  return numElements;
}
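
// For instance, a region {%i <= d0 <= %i + 7} on a memref<100xf32> has a
// constant extent of 8 along d0 irrespective of %i, so 'shape' receives [8]
// and 8 is returned as the bounding element count (illustrative).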

void MemRefRegion::getLowerAndUpperBound(unsigned pos, AffineMap &lbMap,
                                         AffineMap &ubMap) const {
  assert(pos < cst.getNumDimVars() && "invalid position");
  auto memRefType = cast<MemRefType>(memref.getType());
  unsigned rank = memRefType.getRank();

  assert(rank == cst.getNumDimVars() && "inconsistent memref region");

  auto boundPairs = cst.getLowerAndUpperBound(
      pos, /*offset=*/0, /*num=*/rank, cst.getNumDimAndSymbolVars(),
      /*localExprs=*/{}, memRefType.getContext());
  lbMap = boundPairs.first;
  ubMap = boundPairs.second;
  assert(lbMap && "lower bound for a region must exist");
  assert(ubMap && "upper bound for a region must exist");
  assert(lbMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
  assert(ubMap.getNumInputs() == cst.getNumDimAndSymbolVars() - rank);
}

LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) {
  assert(memref == other.memref);
  return cst.unionBoundingBox(*other.getConstraints());
}

/// Computes the memory region accessed by this memref with the region
/// represented as constraints symbolic/parametric in 'loopDepth' loops
/// surrounding opInst and any additional Function symbols.
// For example, the memref region for this load operation at loopDepth = 1
// will be as below:
//
//   affine.for %i = 0 to 32 {
//     affine.for %ii = %i to (d0) -> (d0 + 8) (%i) {
//       load %A[%ii]
//     }
//   }
//
// region:  {memref = %A, write = false, {%i <= m0 <= %i + 7} }
// The last field is a 2-d FlatAffineValueConstraints symbolic in %i.
//
// TODO: extend this to any other memref dereferencing ops
// (dma_start, dma_wait).
LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
                                    const ComputationSliceState *sliceState,
                                    bool addMemRefDimBounds) {
  assert((isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) &&
         "affine read/write op expected");

  MemRefAccess access(op);
  memref = access.memref;
  write = access.isStore();

  unsigned rank = access.getRank();

  LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op
                          << "\ndepth: " << loopDepth << "\n";);

  // 0-d memrefs.
  if (rank == 0) {
    SmallVector<Value, 4> ivs;
    getAffineIVs(*op, ivs);
    assert(loopDepth <= ivs.size() && "invalid 'loopDepth'");
    // The first 'loopDepth' IVs are symbols for this region.
    ivs.resize(loopDepth);
    // A 0-d memref has a 0-d region.
    cst = FlatAffineValueConstraints(rank, loopDepth, /*numLocals=*/0, ivs);
    return success();
  }

  // Build the constraints for this region.
  AffineValueMap accessValueMap;
  access.getAccessMap(&accessValueMap);
  AffineMap accessMap = accessValueMap.getAffineMap();

  unsigned numDims = accessMap.getNumDims();
  unsigned numSymbols = accessMap.getNumSymbols();
  unsigned numOperands = accessValueMap.getNumOperands();
  // Merge operands with slice operands.
  SmallVector<Value, 4> operands;
  operands.resize(numOperands);
  for (unsigned i = 0; i < numOperands; ++i)
    operands[i] = accessValueMap.getOperand(i);

  if (sliceState != nullptr) {
    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
    // Append slice operands to 'operands' as symbols.
    for (auto extraOperand : sliceState->lbOperands[0]) {
      if (!llvm::is_contained(operands, extraOperand)) {
        operands.push_back(extraOperand);
        numSymbols++;
      }
    }
  }
  // We'll first associate the dims and symbols of the access map to the dims
  // and symbols resp. of cst. This will change below once cst is
  // fully constructed.
  cst = FlatAffineValueConstraints(numDims, numSymbols, 0, operands);

  // Add equality constraints.
  // Add inequalities for loop lower/upper bounds.
  for (unsigned i = 0; i < numDims + numSymbols; ++i) {
    auto operand = operands[i];
    if (auto affineFor = getForInductionVarOwner(operand)) {
      // Note that cst can now have more dimensions than accessMap if the
      // bounds expressions involve outer loops or other symbols.
      // TODO: rewrite this to use getInstIndexSet; this way
      // conditionals will be handled when the latter supports it.
      if (failed(cst.addAffineForOpDomain(affineFor)))
        return failure();
    } else if (auto parallelOp = getAffineParallelInductionVarOwner(operand)) {
      if (failed(cst.addAffineParallelOpDomain(parallelOp)))
        return failure();
    } else if (isValidSymbol(operand)) {
      // Check if the symbol is a constant.
      Value symbol = operand;
      if (auto constVal = getConstantIntValue(symbol))
        cst.addBound(BoundType::EQ, symbol, constVal.value());
    } else {
      LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value\n");
      return failure();
    }
  }

  // Add lower/upper bounds on loop IVs using bounds from 'sliceState'.
  if (sliceState != nullptr) {
    // Add dim and symbol slice operands.
    for (auto operand : sliceState->lbOperands[0]) {
      cst.addInductionVarOrTerminalSymbol(operand);
    }
    // Add upper/lower bounds from 'sliceState' to 'cst'.
    LogicalResult ret =
        cst.addSliceBounds(sliceState->ivs, sliceState->lbs, sliceState->ubs,
                           sliceState->lbOperands[0]);
    assert(succeeded(ret) &&
           "should not fail as we never have semi-affine slice maps");
    (void)ret;
  }

  // Add access function equalities to connect loop IVs to data dimensions.
  if (failed(cst.composeMap(&accessValueMap))) {
    op->emitError("getMemRefRegion: compose affine map failed");
    LLVM_DEBUG(accessValueMap.getAffineMap().dump());
    return failure();
  }

  // Set all variables appearing after the first 'rank' variables as
  // symbolic variables - so that the ones corresponding to the memref
  // dimensions are the dimensional variables for the memref region.
  cst.setDimSymbolSeparation(cst.getNumDimAndSymbolVars() - rank);

  // Eliminate any loop IVs other than the outermost 'loopDepth' IVs, on which
  // this memref region is symbolic.
  SmallVector<Value, 4> enclosingIVs;
  getAffineIVs(*op, enclosingIVs);
  assert(loopDepth <= enclosingIVs.size() && "invalid loop depth");
  enclosingIVs.resize(loopDepth);
  SmallVector<Value, 4> vars;
  cst.getValues(cst.getNumDimVars(), cst.getNumDimAndSymbolVars(), &vars);
  for (Value var : vars) {
    if ((isAffineInductionVar(var)) && !llvm::is_contained(enclosingIVs, var)) {
      cst.projectOut(var);
    }
  }

  // Project out any local variables (these would have been added for any
  // mod/divs).
  cst.projectOut(cst.getNumDimAndSymbolVars(), cst.getNumLocalVars());

  // Constant fold any symbolic variables.
  cst.constantFoldVarRange(/*pos=*/cst.getNumDimVars(),
                           /*num=*/cst.getNumSymbolVars());

  assert(cst.getNumDimVars() == rank && "unexpected MemRefRegion format");

  // Add upper/lower bounds for each memref dimension with static size
  // to guard against potential over-approximation from projection.
  // TODO: Support dynamic memref dimensions.
  if (addMemRefDimBounds) {
    auto memRefType = cast<MemRefType>(memref.getType());
    for (unsigned r = 0; r < rank; r++) {
      cst.addBound(BoundType::LB, /*pos=*/r, /*value=*/0);
      if (memRefType.isDynamicDim(r))
        continue;
      cst.addBound(BoundType::UB, /*pos=*/r, memRefType.getDimSize(r) - 1);
    }
  }
  cst.removeTrivialRedundancy();

  LLVM_DEBUG(llvm::dbgs() << "Memory region:\n");
  LLVM_DEBUG(cst.dump());
  return success();
}

std::optional<int64_t>
mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else if (auto vectorType = dyn_cast<VectorType>(elementType)) {
    if (vectorType.getElementType().isIntOrFloat())
      sizeInBits =
          vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
    else
      return std::nullopt;
  } else {
    return std::nullopt;
  }
  return llvm::divideCeil(sizeInBits, 8);
}
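
// For example, a memref with f32 elements yields 4 bytes, while one with
// vector<4xf32> elements yields 4 x 4 = 16 bytes.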

// Returns the size of the region.
std::optional<int64_t> MemRefRegion::getRegionSize() {
  auto memRefType = cast<MemRefType>(memref.getType());

  if (!memRefType.getLayout().isIdentity()) {
    LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
    return std::nullopt;
  }

  // Compute the extents of the buffer.
  std::optional<int64_t> numElements = getConstantBoundingSizeAndShape();
  if (!numElements) {
    LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n");
    return std::nullopt;
  }
  auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!eltSize)
    return std::nullopt;
  return *eltSize * *numElements;
}

/// Returns the size of memref data in bytes if it's statically shaped,
/// std::nullopt otherwise. If the element of the memref has vector type,
/// takes into account the size of the vector as well.
// TODO: improve/complete this when we have target data.
std::optional<uint64_t>
mlir::affine::getIntOrFloatMemRefSizeInBytes(MemRefType memRefType) {
  if (!memRefType.hasStaticShape())
    return std::nullopt;
  auto elementType = memRefType.getElementType();
  if (!elementType.isIntOrFloat() && !isa<VectorType>(elementType))
    return std::nullopt;

  auto sizeInBytes = getMemRefIntOrFloatEltSizeInBytes(memRefType);
  if (!sizeInBytes)
    return std::nullopt;
  for (unsigned i = 0, e = memRefType.getRank(); i < e; i++) {
    sizeInBytes = *sizeInBytes * memRefType.getDimSize(i);
  }
  return sizeInBytes;
}
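
// For example, memref<4x8xf32> is statically shaped with a 4-byte element
// type, so this returns 4 * 8 * 4 = 128 bytes.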

template <typename LoadOrStoreOp>
LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp,
                                                    bool emitError) {
  static_assert(llvm::is_one_of<LoadOrStoreOp, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "argument should be either an AffineReadOpInterface or an "
                "AffineWriteOpInterface");

  Operation *op = loadOrStoreOp.getOperation();
  MemRefRegion region(op->getLoc());
  if (failed(region.compute(op, /*loopDepth=*/0, /*sliceState=*/nullptr,
                            /*addMemRefDimBounds=*/false)))
    return success();

  LLVM_DEBUG(llvm::dbgs() << "Memory region");
  LLVM_DEBUG(region.getConstraints()->dump());

  bool outOfBounds = false;
  unsigned rank = loadOrStoreOp.getMemRefType().getRank();

  // For each dimension, check for out of bounds.
  for (unsigned r = 0; r < rank; r++) {
    FlatAffineValueConstraints ucst(*region.getConstraints());

    // Intersect the memory region with a constraint capturing out of bounds
    // (both out of upper and out of lower), and check if the constraint
    // system is feasible. If it is, there is at least one point out of
    // bounds.
    int64_t dimSize = loadOrStoreOp.getMemRefType().getDimSize(r);
    // TODO: handle dynamic dim sizes.
    if (ShapedType::isDynamic(dimSize))
      continue;

    // Check for overflow: d_i >= memref dim size.
    ucst.addBound(BoundType::LB, r, dimSize);
    outOfBounds = !ucst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of upper bound access along dimension #" << (r + 1);
    }

    // Check for a negative index: d_i <= -1.
    FlatAffineValueConstraints lcst(*region.getConstraints());
    lcst.addBound(BoundType::UB, r, -1);
    outOfBounds = !lcst.isEmpty();
    if (outOfBounds && emitError) {
      loadOrStoreOp.emitOpError()
          << "memref out of lower bound access along dimension #" << (r + 1);
    }
  }
  return failure(outOfBounds);
}
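
// Usage sketch: for an access "affine.load %A[%i]" where %i ranges over
// [0, 128) but %A is memref<100xf32>, the region intersected with d0 >= 100
// is non-empty, so an out-of-upper-bound error is reported for dimension #1
// (illustrative).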
|
||
|
|
||
|
// Explicitly instantiate the template so that the compiler knows we need them!
|
||
|
template LogicalResult
|
||
|
mlir::affine::boundCheckLoadOrStoreOp(AffineReadOpInterface loadOp,
|
||
|
bool emitError);
|
||
|
template LogicalResult
|
||
|
mlir::affine::boundCheckLoadOrStoreOp(AffineWriteOpInterface storeOp,
|
||
|
bool emitError);
|
||
|

// Returns in 'positions' the Block positions of 'op' in each ancestor
// Block from the Block containing operation, stopping at 'limitBlock'.
static void findInstPosition(Operation *op, Block *limitBlock,
                             SmallVectorImpl<unsigned> *positions) {
  Block *block = op->getBlock();
  while (block != limitBlock) {
    // FIXME: This algorithm is unnecessarily O(n) and should be improved to
    // not rely on linear scans.
    int instPosInBlock = std::distance(block->begin(), op->getIterator());
    positions->push_back(instPosInBlock);
    op = block->getParentOp();
    block = op->getBlock();
  }
  std::reverse(positions->begin(), positions->end());
}

// Returns the Operation in a possibly nested set of Blocks, where the
// position of the operation is represented by 'positions', which has a
// Block position for each level of nesting.
static Operation *getInstAtPosition(ArrayRef<unsigned> positions,
                                    unsigned level, Block *block) {
  unsigned i = 0;
  for (auto &op : *block) {
    if (i != positions[level]) {
      ++i;
      continue;
    }
    if (level == positions.size() - 1)
      return &op;
    if (auto childAffineForOp = dyn_cast<AffineForOp>(op))
      return getInstAtPosition(positions, level + 1,
                               childAffineForOp.getBody());

    for (auto &region : op.getRegions()) {
      for (auto &b : region)
        if (auto *ret = getInstAtPosition(positions, level + 1, &b))
          return ret;
    }
    return nullptr;
  }
  return nullptr;
}

// Adds loop IV bounds to 'cst' for loop IVs not found in 'ivs'.
static LogicalResult addMissingLoopIVBounds(SmallPtrSet<Value, 8> &ivs,
                                            FlatAffineValueConstraints *cst) {
  for (unsigned i = 0, e = cst->getNumDimVars(); i < e; ++i) {
    auto value = cst->getValue(i);
    if (ivs.count(value) == 0) {
      assert(isAffineForInductionVar(value));
      auto loop = getForInductionVarOwner(value);
      if (failed(cst->addAffineForOpDomain(loop)))
        return failure();
    }
  }
  return success();
}

/// Returns the innermost common loop depth for the set of operations in 'ops'.
// TODO: Move this to LoopUtils.
unsigned mlir::affine::getInnermostCommonLoopDepth(
    ArrayRef<Operation *> ops, SmallVectorImpl<AffineForOp> *surroundingLoops) {
  unsigned numOps = ops.size();
  assert(numOps > 0 && "Expected at least one operation");

  std::vector<SmallVector<AffineForOp, 4>> loops(numOps);
  unsigned loopDepthLimit = std::numeric_limits<unsigned>::max();
  for (unsigned i = 0; i < numOps; ++i) {
    getAffineForIVs(*ops[i], &loops[i]);
    loopDepthLimit =
        std::min(loopDepthLimit, static_cast<unsigned>(loops[i].size()));
  }

  unsigned loopDepth = 0;
  for (unsigned d = 0; d < loopDepthLimit; ++d) {
    unsigned i;
    for (i = 1; i < numOps; ++i) {
      if (loops[i - 1][d] != loops[i][d])
        return loopDepth;
    }
    if (surroundingLoops)
      surroundingLoops->push_back(loops[i - 1][d]);
    ++loopDepth;
  }
  return loopDepth;
}

/// Computes in 'sliceUnion' the union of all slice bounds computed at
/// 'loopDepth' between all dependent pairs of ops in 'opsA' and 'opsB', and
/// then verifies if it is valid. Returns 'SliceComputationResult::Success' if
/// union was computed correctly, an appropriate failure otherwise.
SliceComputationResult
mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA,
                                ArrayRef<Operation *> opsB, unsigned loopDepth,
                                unsigned numCommonLoops, bool isBackwardSlice,
                                ComputationSliceState *sliceUnion) {
  // Compute the union of slice bounds between all pairs in 'opsA' and
  // 'opsB' in 'sliceUnionCst'.
  FlatAffineValueConstraints sliceUnionCst;
  assert(sliceUnionCst.getNumDimAndSymbolVars() == 0);
  std::vector<std::pair<Operation *, Operation *>> dependentOpPairs;
  for (auto *i : opsA) {
    MemRefAccess srcAccess(i);
    for (auto *j : opsB) {
      MemRefAccess dstAccess(j);
      if (srcAccess.memref != dstAccess.memref)
        continue;
      // Check if 'loopDepth' exceeds nesting depth of src/dst ops.
      if ((!isBackwardSlice && loopDepth > getNestingDepth(i)) ||
          (isBackwardSlice && loopDepth > getNestingDepth(j))) {
        LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n");
        return SliceComputationResult::GenericFailure;
      }

      bool readReadAccesses = isa<AffineReadOpInterface>(srcAccess.opInst) &&
                              isa<AffineReadOpInterface>(dstAccess.opInst);
      FlatAffineValueConstraints dependenceConstraints;
      // Check dependence between 'srcAccess' and 'dstAccess'.
      DependenceResult result = checkMemrefAccessDependence(
          srcAccess, dstAccess, /*loopDepth=*/numCommonLoops + 1,
          &dependenceConstraints, /*dependenceComponents=*/nullptr,
          /*allowRAR=*/readReadAccesses);
      if (result.value == DependenceResult::Failure) {
        LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n");
        return SliceComputationResult::GenericFailure;
      }
      if (result.value == DependenceResult::NoDependence)
        continue;
      dependentOpPairs.emplace_back(i, j);

      // Compute slice bounds for 'srcAccess' and 'dstAccess'.
      ComputationSliceState tmpSliceState;
      mlir::affine::getComputationSliceState(i, j, &dependenceConstraints,
                                             loopDepth, isBackwardSlice,
                                             &tmpSliceState);

      if (sliceUnionCst.getNumDimAndSymbolVars() == 0) {
        // Initialize 'sliceUnionCst' with the bounds computed in previous step.
        if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) {
          LLVM_DEBUG(llvm::dbgs()
                     << "Unable to compute slice bound constraints\n");
          return SliceComputationResult::GenericFailure;
        }
        assert(sliceUnionCst.getNumDimAndSymbolVars() > 0);
        continue;
      }

      // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'.
      FlatAffineValueConstraints tmpSliceCst;
      if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute slice bound constraints\n");
        return SliceComputationResult::GenericFailure;
      }

      // Align coordinate spaces of 'sliceUnionCst' and 'tmpSliceCst' if needed.
      if (!sliceUnionCst.areVarsAlignedWithOther(tmpSliceCst)) {

        // Pre-constraint var alignment: record loop IVs used in each
        // constraint system.
        SmallPtrSet<Value, 8> sliceUnionIVs;
        for (unsigned k = 0, l = sliceUnionCst.getNumDimVars(); k < l; ++k)
          sliceUnionIVs.insert(sliceUnionCst.getValue(k));
        SmallPtrSet<Value, 8> tmpSliceIVs;
        for (unsigned k = 0, l = tmpSliceCst.getNumDimVars(); k < l; ++k)
          tmpSliceIVs.insert(tmpSliceCst.getValue(k));

        sliceUnionCst.mergeAndAlignVarsWithOther(/*offset=*/0, &tmpSliceCst);

        // Post-constraint var alignment: add loop IV bounds missing after
        // var alignment to constraint systems. This can occur if one constraint
        // system uses a loop IV that is not used by the other. The call
        // to unionBoundingBox below expects constraints for each loop IV, even
        // if they are the unsliced full loop bounds added here.
        if (failed(addMissingLoopIVBounds(sliceUnionIVs, &sliceUnionCst)))
          return SliceComputationResult::GenericFailure;
        if (failed(addMissingLoopIVBounds(tmpSliceIVs, &tmpSliceCst)))
          return SliceComputationResult::GenericFailure;
      }
      // Compute union bounding box of 'sliceUnionCst' and 'tmpSliceCst'.
      if (sliceUnionCst.getNumLocalVars() > 0 ||
          tmpSliceCst.getNumLocalVars() > 0 ||
          failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) {
        LLVM_DEBUG(llvm::dbgs()
                   << "Unable to compute union bounding box of slice bounds\n");
        return SliceComputationResult::GenericFailure;
      }
    }
  }

  // Empty union.
  if (sliceUnionCst.getNumDimAndSymbolVars() == 0)
    return SliceComputationResult::GenericFailure;

  // Gather loops surrounding ops from loop nest where slice will be inserted.
  SmallVector<Operation *, 4> ops;
  for (auto &dep : dependentOpPairs) {
    ops.push_back(isBackwardSlice ? dep.second : dep.first);
  }
  SmallVector<AffineForOp, 4> surroundingLoops;
  unsigned innermostCommonLoopDepth =
      getInnermostCommonLoopDepth(ops, &surroundingLoops);
  if (loopDepth > innermostCommonLoopDepth) {
    LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n");
    return SliceComputationResult::GenericFailure;
  }

  // Store 'numSliceLoopIVs' before converting dst loop IVs to dims.
  unsigned numSliceLoopIVs = sliceUnionCst.getNumDimVars();

  // Convert any dst loop IVs which are symbol variables to dim variables.
  sliceUnionCst.convertLoopIVSymbolsToDims();
  sliceUnion->clearBounds();
  sliceUnion->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceUnion->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get slice bounds from slice union constraints 'sliceUnionCst'.
  sliceUnionCst.getSliceBounds(/*offset=*/0, numSliceLoopIVs,
                               opsA[0]->getContext(), &sliceUnion->lbs,
                               &sliceUnion->ubs);

  // Add slice bound operands of union.
  SmallVector<Value, 4> sliceBoundOperands;
  sliceUnionCst.getValues(numSliceLoopIVs,
                          sliceUnionCst.getNumDimAndSymbolVars(),
                          &sliceBoundOperands);

  // Copy src loop IVs from 'sliceUnionCst' to 'sliceUnion'.
  sliceUnion->ivs.clear();
  sliceUnionCst.getValues(0, numSliceLoopIVs, &sliceUnion->ivs);

  // Set loop nest insertion point to block start at 'loopDepth'.
  sliceUnion->insertPoint =
      isBackwardSlice
          ? surroundingLoops[loopDepth - 1].getBody()->begin()
          : std::prev(surroundingLoops[loopDepth - 1].getBody()->end());

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceUnion->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceUnion->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Check if the slice computed is valid. Return success only if it is
  // verified that the slice is valid, otherwise return appropriate failure
  // status.
  std::optional<bool> isSliceValid = sliceUnion->isSliceValid();
  if (!isSliceValid) {
    LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n");
    return SliceComputationResult::GenericFailure;
  }
  if (!*isSliceValid)
    return SliceComputationResult::IncorrectSliceFailure;

  return SliceComputationResult::Success;
}

// TODO: extend this to handle multiple result maps.
static std::optional<uint64_t> getConstDifference(AffineMap lbMap,
                                                  AffineMap ubMap) {
  assert(lbMap.getNumResults() == 1 && "expected single result bound map");
  assert(ubMap.getNumResults() == 1 && "expected single result bound map");
  assert(lbMap.getNumDims() == ubMap.getNumDims());
  assert(lbMap.getNumSymbols() == ubMap.getNumSymbols());
  AffineExpr lbExpr(lbMap.getResult(0));
  AffineExpr ubExpr(ubMap.getResult(0));
  auto loopSpanExpr = simplifyAffineExpr(ubExpr - lbExpr, lbMap.getNumDims(),
                                         lbMap.getNumSymbols());
  auto cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
  if (!cExpr)
    return std::nullopt;
  return cExpr.getValue();
}

// Builds a map 'tripCountMap' from AffineForOp to constant trip count for the
// loop nest represented by the slice loop bounds in 'slice'. Returns true on
// success, false otherwise (if a non-constant trip count was encountered).
// TODO: Make this work with non-unit step loops.
bool mlir::affine::buildSliceTripCountMap(
    const ComputationSliceState &slice,
    llvm::SmallDenseMap<Operation *, uint64_t, 8> *tripCountMap) {
  unsigned numSrcLoopIVs = slice.ivs.size();
  // Populate map from AffineForOp -> trip count.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]);
    auto *op = forOp.getOperation();
    AffineMap lbMap = slice.lbs[i];
    AffineMap ubMap = slice.ubs[i];
    // If lower or upper bound maps are null or provide no results, it implies
    // that the source loop was not at all sliced, and the entire loop will be
    // a part of the slice.
    if (!lbMap || lbMap.getNumResults() == 0 || !ubMap ||
        ubMap.getNumResults() == 0) {
      // The iteration of src loop IV 'i' was not sliced. Use full loop bounds.
      if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) {
        (*tripCountMap)[op] =
            forOp.getConstantUpperBound() - forOp.getConstantLowerBound();
        continue;
      }
      std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (maybeConstTripCount.has_value()) {
        (*tripCountMap)[op] = *maybeConstTripCount;
        continue;
      }
      return false;
    }
    std::optional<uint64_t> tripCount = getConstDifference(lbMap, ubMap);
    // Slice bounds are created with a constant ub - lb difference.
    if (!tripCount.has_value())
      return false;
    (*tripCountMap)[op] = *tripCount;
  }
  return true;
}

// Return the number of iterations in the given slice.
uint64_t mlir::affine::getSliceIterationCount(
    const llvm::SmallDenseMap<Operation *, uint64_t, 8> &sliceTripCountMap) {
  uint64_t iterCount = 1;
  for (const auto &count : sliceTripCountMap) {
    iterCount *= count.second;
  }
  return iterCount;
}

const char *const kSliceFusionBarrierAttrName = "slice_fusion_barrier";
// Computes slice bounds by projecting out any loop IVs from
// 'dependenceConstraints' at depth greater than 'loopDepth', and computes slice
// bounds in 'sliceState' which represent the one loop nest's IVs in terms of
// the other loop nest's IVs, symbols and constants (using 'isBackwardSlice').
void mlir::affine::getComputationSliceState(
    Operation *depSourceOp, Operation *depSinkOp,
    FlatAffineValueConstraints *dependenceConstraints, unsigned loopDepth,
    bool isBackwardSlice, ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*depSourceOp, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(*depSinkOp, &dstLoopIVs);
  unsigned numDstLoopIVs = dstLoopIVs.size();

  assert((!isBackwardSlice && loopDepth <= numSrcLoopIVs) ||
         (isBackwardSlice && loopDepth <= numDstLoopIVs));

  // Project out dimensions other than those up to 'loopDepth'.
  unsigned pos = isBackwardSlice ? numSrcLoopIVs + loopDepth : loopDepth;
  unsigned num =
      isBackwardSlice ? numDstLoopIVs - loopDepth : numSrcLoopIVs - loopDepth;
  dependenceConstraints->projectOut(pos, num);

  // Add slice loop IV values to 'sliceState'.
  unsigned offset = isBackwardSlice ? 0 : loopDepth;
  unsigned numSliceLoopIVs = isBackwardSlice ? numSrcLoopIVs : numDstLoopIVs;
  dependenceConstraints->getValues(offset, offset + numSliceLoopIVs,
                                   &sliceState->ivs);

  // Set up lower/upper bound affine maps for the slice.
  sliceState->lbs.resize(numSliceLoopIVs, AffineMap());
  sliceState->ubs.resize(numSliceLoopIVs, AffineMap());

  // Get bounds for slice IVs in terms of other IVs, symbols, and constants.
  dependenceConstraints->getSliceBounds(offset, numSliceLoopIVs,
                                        depSourceOp->getContext(),
                                        &sliceState->lbs, &sliceState->ubs);

  // Set up bound operands for the slice's lower and upper bounds.
  SmallVector<Value, 4> sliceBoundOperands;
  unsigned numDimsAndSymbols = dependenceConstraints->getNumDimAndSymbolVars();
  for (unsigned i = 0; i < numDimsAndSymbols; ++i) {
    if (i < offset || i >= offset + numSliceLoopIVs) {
      sliceBoundOperands.push_back(dependenceConstraints->getValue(i));
    }
  }

  // Give each bound its own copy of 'sliceBoundOperands' for subsequent
  // canonicalization.
  sliceState->lbOperands.resize(numSliceLoopIVs, sliceBoundOperands);
  sliceState->ubOperands.resize(numSliceLoopIVs, sliceBoundOperands);

  // Set destination loop nest insertion point to block start at 'dstLoopDepth'.
  sliceState->insertPoint =
      isBackwardSlice ? dstLoopIVs[loopDepth - 1].getBody()->begin()
                      : std::prev(srcLoopIVs[loopDepth - 1].getBody()->end());

  llvm::SmallDenseSet<Value, 8> sequentialLoops;
  if (isa<AffineReadOpInterface>(depSourceOp) &&
      isa<AffineReadOpInterface>(depSinkOp)) {
    // For read-read access pairs, clear any slice bounds on sequential loops.
    // Get sequential loops in loop nest rooted at 'srcLoopIVs[0]'.
    getSequentialLoops(isBackwardSlice ? srcLoopIVs[0] : dstLoopIVs[0],
                       &sequentialLoops);
  }
  auto getSliceLoop = [&](unsigned i) {
    return isBackwardSlice ? srcLoopIVs[i] : dstLoopIVs[i];
  };
  auto isInnermostInsertion = [&]() {
    return (isBackwardSlice ? loopDepth >= srcLoopIVs.size()
                            : loopDepth >= dstLoopIVs.size());
  };
  llvm::SmallDenseMap<Operation *, uint64_t, 8> sliceTripCountMap;
  auto srcIsUnitSlice = [&]() {
    return (buildSliceTripCountMap(*sliceState, &sliceTripCountMap) &&
            (getSliceIterationCount(sliceTripCountMap) == 1));
  };
  // Clear all sliced loop bounds beginning at the first sequential loop, or
  // the first loop with a slice fusion barrier attribute.
  for (unsigned i = 0; i < numSliceLoopIVs; ++i) {
    Value iv = getSliceLoop(i).getInductionVar();
    if (sequentialLoops.count(iv) == 0 &&
        getSliceLoop(i)->getAttr(kSliceFusionBarrierAttrName) == nullptr)
      continue;
    // Skip resetting the bounds of a reduction loop inserted in the
    // destination loop that meets the following conditions:
    // 1. Slice is single trip count.
    // 2. Loop bounds of the source and destination match.
    // 3. Is being inserted at the innermost insertion point.
    std::optional<bool> isMaximal = sliceState->isMaximal();
    if (isLoopParallelAndContainsReduction(getSliceLoop(i)) &&
        isInnermostInsertion() && srcIsUnitSlice() && isMaximal && *isMaximal)
      continue;
    for (unsigned j = i; j < numSliceLoopIVs; ++j) {
      sliceState->lbs[j] = AffineMap();
      sliceState->ubs[j] = AffineMap();
    }
    break;
  }
}

/// Creates a computation slice of the loop nest surrounding 'srcOpInst',
/// updates the slice loop bounds with any non-null bound maps specified in
/// 'sliceState', and inserts this slice into the loop nest surrounding
/// 'dstOpInst' at loop depth 'dstLoopDepth'.
// TODO: extend the slicing utility to compute slices that
// aren't necessarily a one-to-one relation b/w the source and destination. The
// relation between the source and destination could be many-to-many in general.
// TODO: the slice computation is incorrect in the cases
// where the dependence from the source to the destination does not cover the
// entire destination index set. Subtract out the dependent destination
// iterations from destination index set and check for emptiness --- this is
// one solution.
AffineForOp mlir::affine::insertBackwardComputationSlice(
    Operation *srcOpInst, Operation *dstOpInst, unsigned dstLoopDepth,
    ComputationSliceState *sliceState) {
  // Get loop nest surrounding src operation.
  SmallVector<AffineForOp, 4> srcLoopIVs;
  getAffineForIVs(*srcOpInst, &srcLoopIVs);
  unsigned numSrcLoopIVs = srcLoopIVs.size();

  // Get loop nest surrounding dst operation.
  SmallVector<AffineForOp, 4> dstLoopIVs;
  getAffineForIVs(*dstOpInst, &dstLoopIVs);
  unsigned dstLoopIVsSize = dstLoopIVs.size();
  if (dstLoopDepth > dstLoopIVsSize) {
    dstOpInst->emitError("invalid destination loop depth");
    return AffineForOp();
  }

  // Find the op block positions of 'srcOpInst' within 'srcLoopIVs'.
  SmallVector<unsigned, 4> positions;
  // TODO: This code is incorrect since srcLoopIVs can be 0-d.
  findInstPosition(srcOpInst, srcLoopIVs[0]->getBlock(), &positions);

  // Clone the src loop nest and insert it at the beginning of the operation
  // block of the loop at 'dstLoopDepth' in 'dstLoopIVs'.
  auto dstAffineForOp = dstLoopIVs[dstLoopDepth - 1];
  OpBuilder b(dstAffineForOp.getBody(), dstAffineForOp.getBody()->begin());
  auto sliceLoopNest =
      cast<AffineForOp>(b.clone(*srcLoopIVs[0].getOperation()));

  Operation *sliceInst =
      getInstAtPosition(positions, /*level=*/0, sliceLoopNest.getBody());
  // Get loop nest surrounding 'sliceInst'.
  SmallVector<AffineForOp, 4> sliceSurroundingLoops;
  getAffineForIVs(*sliceInst, &sliceSurroundingLoops);

  // Sanity check.
  unsigned sliceSurroundingLoopsSize = sliceSurroundingLoops.size();
  (void)sliceSurroundingLoopsSize;
  assert(dstLoopDepth + numSrcLoopIVs >= sliceSurroundingLoopsSize);
  unsigned sliceLoopLimit = dstLoopDepth + numSrcLoopIVs;
  (void)sliceLoopLimit;
  assert(sliceLoopLimit >= sliceSurroundingLoopsSize);

  // Update loop bounds for loops in 'sliceLoopNest'.
  for (unsigned i = 0; i < numSrcLoopIVs; ++i) {
    auto forOp = sliceSurroundingLoops[dstLoopDepth + i];
    if (AffineMap lbMap = sliceState->lbs[i])
      forOp.setLowerBound(sliceState->lbOperands[i], lbMap);
    if (AffineMap ubMap = sliceState->ubs[i])
      forOp.setUpperBound(sliceState->ubOperands[i], ubMap);
  }
  return sliceLoopNest;
}

// Constructs a MemRefAccess, populating it with the memref, its indices, and
// the op instance from 'loadOrStoreOpInst'.
MemRefAccess::MemRefAccess(Operation *loadOrStoreOpInst) {
  if (auto loadOp = dyn_cast<AffineReadOpInterface>(loadOrStoreOpInst)) {
    memref = loadOp.getMemRef();
    opInst = loadOrStoreOpInst;
    llvm::append_range(indices, loadOp.getMapOperands());
  } else {
    assert(isa<AffineWriteOpInterface>(loadOrStoreOpInst) &&
           "Affine read/write op expected");
    auto storeOp = cast<AffineWriteOpInterface>(loadOrStoreOpInst);
    opInst = loadOrStoreOpInst;
    memref = storeOp.getMemRef();
    llvm::append_range(indices, storeOp.getMapOperands());
  }
}

unsigned MemRefAccess::getRank() const {
  return cast<MemRefType>(memref.getType()).getRank();
}

bool MemRefAccess::isStore() const {
  return isa<AffineWriteOpInterface>(opInst);
}

/// Returns the nesting depth of this statement, i.e., the number of loops
/// surrounding this statement.
unsigned mlir::affine::getNestingDepth(Operation *op) {
  Operation *currOp = op;
  unsigned depth = 0;
  while ((currOp = currOp->getParentOp())) {
    if (isa<AffineForOp>(currOp))
      depth++;
  }
  return depth;
}

/// Equal if both affine accesses are provably equivalent (at compile
/// time) when considering the memref, the affine maps and their respective
/// operands. The equality of access functions + operands is checked by
/// subtracting fully composed value maps, and then simplifying the difference
/// using the expression flattener.
/// TODO: this does not account for aliasing of memrefs.
bool MemRefAccess::operator==(const MemRefAccess &rhs) const {
  if (memref != rhs.memref)
    return false;

  AffineValueMap diff, thisMap, rhsMap;
  getAccessMap(&thisMap);
  rhs.getAccessMap(&rhsMap);
  AffineValueMap::difference(thisMap, rhsMap, &diff);
  return llvm::all_of(diff.getAffineMap().getResults(),
                      [](AffineExpr e) { return e == 0; });
}

void mlir::affine::getAffineIVs(Operation &op, SmallVectorImpl<Value> &ivs) {
  auto *currOp = op.getParentOp();
  AffineForOp currAffineForOp;
  // Traverse up the hierarchy collecting all 'affine.for' and
  // 'affine.parallel' operations while skipping over 'affine.if' operations.
  while (currOp) {
    if (AffineForOp currAffineForOp = dyn_cast<AffineForOp>(currOp))
      ivs.push_back(currAffineForOp.getInductionVar());
    else if (auto parOp = dyn_cast<AffineParallelOp>(currOp))
      llvm::append_range(ivs, parOp.getIVs());
    currOp = currOp->getParentOp();
  }
  std::reverse(ivs.begin(), ivs.end());
}

/// Returns the number of surrounding loops common to 'loopsA' and 'loopsB',
/// where each lists loops from outer-most to inner-most in loop nest.
unsigned mlir::affine::getNumCommonSurroundingLoops(Operation &a,
                                                    Operation &b) {
  SmallVector<Value, 4> loopsA, loopsB;
  getAffineIVs(a, loopsA);
  getAffineIVs(b, loopsB);

  unsigned minNumLoops = std::min(loopsA.size(), loopsB.size());
  unsigned numCommonLoops = 0;
  for (unsigned i = 0; i < minNumLoops; ++i) {
    if (loopsA[i] != loopsB[i])
      break;
    ++numCommonLoops;
  }
  return numCommonLoops;
}

static std::optional<int64_t> getMemoryFootprintBytes(Block &block,
                                                      Block::iterator start,
                                                      Block::iterator end,
                                                      int memorySpace) {
  SmallDenseMap<Value, std::unique_ptr<MemRefRegion>, 4> regions;

  // Walk this 'affine.for' operation to gather all memory regions.
  auto result = block.walk(start, end, [&](Operation *opInst) -> WalkResult {
    if (!isa<AffineReadOpInterface, AffineWriteOpInterface>(opInst)) {
      // Neither a load nor a store op.
      return WalkResult::advance();
    }

    // Compute the memref region symbolic in any IVs enclosing this block.
    auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
    if (failed(
            region->compute(opInst,
                            /*loopDepth=*/getNestingDepth(&*block.begin())))) {
      return opInst->emitError("error obtaining memory region\n");
    }

    auto it = regions.find(region->memref);
    if (it == regions.end()) {
      regions[region->memref] = std::move(region);
    } else if (failed(it->second->unionBoundingBox(*region))) {
      return opInst->emitWarning(
          "getMemoryFootprintBytes: unable to perform a union on a memory "
          "region");
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return std::nullopt;

  int64_t totalSizeInBytes = 0;
  for (const auto &region : regions) {
    std::optional<int64_t> size = region.second->getRegionSize();
    if (!size.has_value())
      return std::nullopt;
    totalSizeInBytes += *size;
  }
  return totalSizeInBytes;
}

std::optional<int64_t> mlir::affine::getMemoryFootprintBytes(AffineForOp forOp,
                                                             int memorySpace) {
  auto *forInst = forOp.getOperation();
  return ::getMemoryFootprintBytes(
      *forInst->getBlock(), Block::iterator(forInst),
      std::next(Block::iterator(forInst)), memorySpace);
}

/// Returns whether a loop is parallel and contains a reduction loop.
bool mlir::affine::isLoopParallelAndContainsReduction(AffineForOp forOp) {
  SmallVector<LoopReduction> reductions;
  if (!isLoopParallel(forOp, &reductions))
    return false;
  return !reductions.empty();
}

/// Returns in 'sequentialLoops' all sequential loops in loop nest rooted
/// at 'forOp'.
void mlir::affine::getSequentialLoops(
    AffineForOp forOp, llvm::SmallDenseSet<Value, 8> *sequentialLoops) {
  forOp->walk([&](Operation *op) {
    if (auto innerFor = dyn_cast<AffineForOp>(op))
      if (!isLoopParallel(innerFor))
        sequentialLoops->insert(innerFor.getInductionVar());
  });
}

IntegerSet mlir::affine::simplifyIntegerSet(IntegerSet set) {
  FlatAffineValueConstraints fac(set);
  if (fac.isEmpty())
    return IntegerSet::getEmptySet(set.getNumDims(), set.getNumSymbols(),
                                   set.getContext());
  fac.removeTrivialRedundancy();

  auto simplifiedSet = fac.getAsIntegerSet(set.getContext());
  assert(simplifiedSet && "guaranteed to succeed while roundtripping");
  return simplifiedSet;
}

static void unpackOptionalValues(ArrayRef<std::optional<Value>> source,
                                 SmallVector<Value> &target) {
  target =
      llvm::to_vector<4>(llvm::map_range(source, [](std::optional<Value> val) {
        return val.has_value() ? *val : Value();
      }));
}

/// Bound an identifier `pos` in a given FlatAffineValueConstraints with
/// constraints drawn from an affine map. Before adding the constraint, the
/// dimensions/symbols of the affine map are aligned with `constraints`.
/// `operands` are the SSA Value operands used with the affine map.
/// Note: This function adds a new symbol column to the `constraints` for each
/// dimension/symbol that exists in the affine map but not in `constraints`.
static LogicalResult alignAndAddBound(FlatAffineValueConstraints &constraints,
                                      BoundType type, unsigned pos,
                                      AffineMap map, ValueRange operands) {
  SmallVector<Value> dims, syms, newSyms;
  unpackOptionalValues(constraints.getMaybeValues(VarKind::SetDim), dims);
  unpackOptionalValues(constraints.getMaybeValues(VarKind::Symbol), syms);

  AffineMap alignedMap =
      alignAffineMapWithValues(map, operands, dims, syms, &newSyms);
  for (unsigned i = syms.size(); i < newSyms.size(); ++i)
    constraints.appendSymbolVar(newSyms[i]);
  return constraints.addBound(type, pos, alignedMap);
}

/// Add `val` to each result of `map`.
static AffineMap addConstToResults(AffineMap map, int64_t val) {
  SmallVector<AffineExpr> newResults;
  for (AffineExpr r : map.getResults())
    newResults.push_back(r + val);
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
                        map.getContext());
}

// Attempt to simplify the given min/max operation by proving that its value is
// bounded by the same lower and upper bound.
//
// Bounds are computed by FlatAffineValueConstraints. Invariants required for
// finding/proving bounds should be supplied via `constraints`.
//
// 1. Add dimensions for `op` and `opBound` (lower or upper bound of `op`).
// 2. Compute an upper bound of `op` (in case of `isMin`) or a lower bound (in
//    case of `!isMin`) and bind it to `opBound`. SSA values that are used in
//    `op` but are not part of `constraints`, are added as extra symbols.
// 3. For each result of `op`: Add result as a dimension `r_i`. Prove that:
//    * If `isMin`:  r_i >= opBound
//    * If `!isMin`: r_i <= opBound
//    If this is the case, ub(op) == lb(op).
// 4. Replace `op` with `opBound`.
//
// In summary, the following constraints are added throughout this function.
// Note: `invar` are dimensions added by the caller to express the invariants.
// (Showing only the case where `isMin`.)
//
//  invar |    op | opBound | r_i | extra syms... | const |           eq/ineq
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (various eq./ineq. constraining `invar`, added by the caller)
//    ... |     0 |       0 |   0 |             0 |   ... |               ...
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (various ineq. constraining `op` in terms of `op` operands (`invar` and
//    extra `op` operands "extra syms" that are not in `invar`)).
//    ... |    -1 |       0 |   0 |           ... |   ... |              >= 0
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (set `opBound` to `op` upper bound in terms of `invar` and "extra syms")
//    ... |     0 |      -1 |   0 |           ... |   ... |               = 0
//  ------+-------+---------+-----+---------------+-------+-------------------
//   (for each `op` map result r_i: set r_i to corresponding map result,
//    prove that r_i >= minOpUb via contradiction)
//    ... |     0 |       0 |  -1 |           ... |   ... |               = 0
//      0 |     0 |       1 |  -1 |             0 |    -1 |              >= 0
//
FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
    Operation *op, FlatAffineValueConstraints constraints) {
  bool isMin = isa<AffineMinOp>(op);
  assert((isMin || isa<AffineMaxOp>(op)) && "expect AffineMin/MaxOp");
  MLIRContext *ctx = op->getContext();
  Builder builder(ctx);
  AffineMap map =
      isMin ? cast<AffineMinOp>(op).getMap() : cast<AffineMaxOp>(op).getMap();
  ValueRange operands = op->getOperands();
  unsigned numResults = map.getNumResults();

  // Add a few extra dimensions.
  unsigned dimOp = constraints.appendDimVar();      // `op`
  unsigned dimOpBound = constraints.appendDimVar(); // `op` lower/upper bound
  unsigned resultDimStart = constraints.appendDimVar(/*num=*/numResults);

  // Add an inequality for each result expr_i of map:
  // isMin: op <= expr_i, !isMin: op >= expr_i
  auto boundType = isMin ? BoundType::UB : BoundType::LB;
  // Upper bounds are exclusive, so add 1. (`affine.min` ops are inclusive.)
  AffineMap mapLbUb = isMin ? addConstToResults(map, 1) : map;
  if (failed(
          alignAndAddBound(constraints, boundType, dimOp, mapLbUb, operands)))
    return failure();

  // Try to compute a lower/upper bound for op, expressed in terms of the other
  // `dims` and extra symbols.
  SmallVector<AffineMap> opLb(1), opUb(1);
  constraints.getSliceBounds(dimOp, 1, ctx, &opLb, &opUb);
  AffineMap sliceBound = isMin ? opUb[0] : opLb[0];
  // TODO: `getSliceBounds` may return multiple bounds at the moment. This is
  // a TODO of `getSliceBounds` and not handled here.
  if (!sliceBound || sliceBound.getNumResults() != 1)
    return failure(); // No or multiple bounds found.
  // Recover the inclusive UB in the case of an `affine.min`.
  AffineMap boundMap = isMin ? addConstToResults(sliceBound, -1) : sliceBound;

  // Add an equality: Set dimOpBound to computed bound.
  // Add back dimension for op. (Was removed by `getSliceBounds`.)
  AffineMap alignedBoundMap = boundMap.shiftDims(/*shift=*/1, /*offset=*/dimOp);
  if (failed(constraints.addBound(BoundType::EQ, dimOpBound, alignedBoundMap)))
    return failure();

  // If the constraint system is empty, there is an inconsistency. (E.g., this
  // can happen if loop lb > ub.)
  if (constraints.isEmpty())
    return failure();

  // In the case of `isMin` (the reasoning is inverted for `!isMin`):
  // Prove that each result of `map` has a lower bound that is equal to (or
  // greater than) the upper bound of `op` (`dimOpBound`). In that case, `op`
  // can be replaced with the bound. I.e., prove that for each result
  // expr_i (represented by dimension r_i):
  //
  //   r_i >= opBound
  //
  // To prove this inequality, add its negation to the constraint set and prove
  // that the constraint set is empty.
  for (unsigned i = resultDimStart; i < resultDimStart + numResults; ++i) {
    FlatAffineValueConstraints newConstr(constraints);

    // Add an equality: r_i = expr_i
    // Note: These equalities could have been added earlier and used to express
    // minOp <= expr_i. However, then we run the risk that `getSliceBounds`
    // computes minOpUb in terms of r_i dims, which is not desired.
    if (failed(alignAndAddBound(newConstr, BoundType::EQ, i,
                                map.getSubMap({i - resultDimStart}), operands)))
      return failure();

    // If `isMin`:  Add inequality: r_i < opBound
    //              equiv.: opBound - r_i - 1 >= 0
    // If `!isMin`: Add inequality: r_i > opBound
    //              equiv.: -opBound + r_i - 1 >= 0
    SmallVector<int64_t> ineq(newConstr.getNumCols(), 0);
    ineq[dimOpBound] = isMin ? 1 : -1;
    ineq[i] = isMin ? -1 : 1;
    ineq[newConstr.getNumCols() - 1] = -1;
    newConstr.addInequality(ineq);
    if (!newConstr.isEmpty())
      return failure();
  }

  // Lower and upper bound of `op` are equal. Replace `minOp` with its bound.
  AffineMap newMap = alignedBoundMap;
  SmallVector<Value> newOperands;
  unpackOptionalValues(constraints.getMaybeValues(), newOperands);
  // If dims/symbols have known constant values, use those in order to simplify
  // the affine map further.
  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
    // Skip unused operands and operands that are already constants.
    if (!newOperands[i] || getConstantIntValue(newOperands[i]))
      continue;
    if (auto bound = constraints.getConstantBound64(BoundType::EQ, i)) {
      AffineExpr expr =
          i < newMap.getNumDims()
              ? builder.getAffineDimExpr(i)
              : builder.getAffineSymbolExpr(i - newMap.getNumDims());
      newMap = newMap.replace(expr, builder.getAffineConstantExpr(*bound),
                              newMap.getNumDims(), newMap.getNumSymbols());
    }
  }
  affine::canonicalizeMapAndOperands(&newMap, &newOperands);
  return AffineValueMap(newMap, newOperands);
}