235 lines
7.6 KiB
C++
235 lines
7.6 KiB
C++
|
//===- PipelineGlobalOpsPass.cpp - Pipeline Global Ops Pass ---------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
|
||
|
|
||
|
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
|
||
|
#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
|
||
|
#include "mlir/IR/BuiltinOps.h"
|
||
|
#include "mlir/Pass/Pass.h"
|
||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||
|
|
||
|
namespace mlir {
|
||
|
namespace ml_program {
|
||
|
#define GEN_PASS_DEF_MLPROGRAMPIPELINEGLOBALS
|
||
|
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
|
||
|
|
||
|
namespace {
|
||
|
|
||
|
class MLProgramPipelineGlobals
|
||
|
: public impl::MLProgramPipelineGlobalsBase<MLProgramPipelineGlobals> {
|
||
|
public:
|
||
|
void runOnOperation() override;
|
||
|
|
||
|
private:
|
||
|
LogicalResult buildGlobalMap(ModuleOp op);
|
||
|
|
||
|
void ProcessBlock(Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
|
||
|
llvm::DenseSet<SymbolRefAttr> &symbolStore);
|
||
|
|
||
|
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> loadSymbolsMap;
|
||
|
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> storeSymbolsMap;
|
||
|
};
|
||
|
|
||
|
// Traverses upwards searchign for the operation mapped by the symbol.
|
||
|
static Operation *getFromSymbol(Operation *baseOp, SymbolRefAttr symbol) {
|
||
|
for (auto op = baseOp; op; op = op->getParentOp()) {
|
||
|
auto lookup = SymbolTable::lookupNearestSymbolFrom(op, symbol);
|
||
|
if (lookup)
|
||
|
return lookup;
|
||
|
}
|
||
|
return nullptr;
|
||
|
}
|
||
|
|
||
|
// Builds map from a symbol to MLProgram global symbols loaded or stored
|
||
|
// during processing.
|
||
|
LogicalResult MLProgramPipelineGlobals::buildGlobalMap(ModuleOp module) {
|
||
|
llvm::DenseMap<SymbolRefAttr, Operation *> callableMap;
|
||
|
auto res = module->walk([&](Operation *op) {
|
||
|
if (auto caller = mlir::dyn_cast<CallOpInterface>(op)) {
|
||
|
auto callable = caller.getCallableForCallee();
|
||
|
// For now we do not know how to handle Value based tracing, so fail.
|
||
|
if (mlir::isa<Value>(callable)) {
|
||
|
return WalkResult::interrupt();
|
||
|
}
|
||
|
|
||
|
auto symbol = mlir::dyn_cast<SymbolRefAttr>(callable);
|
||
|
auto func = getFromSymbol(op, symbol);
|
||
|
callableMap[symbol] = func;
|
||
|
}
|
||
|
return WalkResult::advance();
|
||
|
});
|
||
|
|
||
|
if (res.wasInterrupted()) {
|
||
|
return failure();
|
||
|
}
|
||
|
|
||
|
// First grab all symbols loaded or stored by each function. This
|
||
|
// will not handle calls initially.
|
||
|
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opLoadSymbols;
|
||
|
llvm::DenseMap<SymbolRefAttr, llvm::DenseSet<SymbolRefAttr>> opStoreSymbols;
|
||
|
for (auto callable : callableMap) {
|
||
|
llvm::DenseSet<SymbolRefAttr> loadSymbols;
|
||
|
llvm::DenseSet<SymbolRefAttr> storeSymbols;
|
||
|
|
||
|
callable.getSecond()->walk(
|
||
|
[&](GlobalLoadOp op) { loadSymbols.insert(op.getGlobal()); });
|
||
|
|
||
|
callable.getSecond()->walk(
|
||
|
[&](GlobalStoreOp op) { storeSymbols.insert(op.getGlobal()); });
|
||
|
|
||
|
opLoadSymbols[callable.getFirst()] = std::move(loadSymbols);
|
||
|
opStoreSymbols[callable.getFirst()] = std::move(storeSymbols);
|
||
|
}
|
||
|
|
||
|
// For each callable function we find each global loaded/stored within the
|
||
|
// function or a nested called function. This includes recursion checking to
|
||
|
// avoid infinitely recursing.
|
||
|
for (auto callable : callableMap) {
|
||
|
SymbolRefAttr thisSymbol = llvm::dyn_cast<SymbolRefAttr>(callable.first);
|
||
|
llvm::SmallVector<SymbolRefAttr> work = {thisSymbol};
|
||
|
llvm::DenseSet<SymbolRefAttr> visited = {thisSymbol};
|
||
|
llvm::DenseSet<SymbolRefAttr> loadSymbols;
|
||
|
llvm::DenseSet<SymbolRefAttr> storeSymbols;
|
||
|
|
||
|
for (size_t i = 0; i < work.size(); ++i) {
|
||
|
callableMap[work[i]]->walk([&](CallOpInterface call) {
|
||
|
auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
|
||
|
if (!visited.contains(symbol)) {
|
||
|
visited.insert(symbol);
|
||
|
work.push_back(symbol);
|
||
|
}
|
||
|
});
|
||
|
|
||
|
for (auto load : opLoadSymbols[work[i]])
|
||
|
loadSymbols.insert(load);
|
||
|
|
||
|
for (auto store : opStoreSymbols[work[i]])
|
||
|
storeSymbols.insert(store);
|
||
|
}
|
||
|
|
||
|
loadSymbolsMap[thisSymbol] = std::move(loadSymbols);
|
||
|
storeSymbolsMap[thisSymbol] = std::move(storeSymbols);
|
||
|
}
|
||
|
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
// Process each operation in the block deleting unneeded loads / stores,
|
||
|
// recursing on subblocks and checking function calls.
|
||
|
void MLProgramPipelineGlobals::ProcessBlock(
|
||
|
Block &block, llvm::DenseSet<SymbolRefAttr> &symbolLoad,
|
||
|
llvm::DenseSet<SymbolRefAttr> &symbolStore) {
|
||
|
|
||
|
llvm::DenseMap<SymbolRefAttr, Value> previousLoads;
|
||
|
llvm::DenseMap<SymbolRefAttr, Operation *> previousStores;
|
||
|
llvm::SmallVector<Operation *> toDelete;
|
||
|
for (auto &op : block) {
|
||
|
// If this is a global load, remap to a previous value if known
|
||
|
// and delete this load. Remember that this value is the currently
|
||
|
// known load.
|
||
|
if (auto load = mlir::dyn_cast<GlobalLoadOp>(op)) {
|
||
|
auto ref = load.getGlobal();
|
||
|
symbolLoad.insert(ref);
|
||
|
if (previousLoads.contains(ref)) {
|
||
|
toDelete.push_back(&op);
|
||
|
load.getResult().replaceAllUsesWith(previousLoads[ref]);
|
||
|
} else {
|
||
|
previousLoads[ref] = load.getResult();
|
||
|
}
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
// Delete a previous store if it exists and is not needed, update
|
||
|
// the most recent known value for this global ref.
|
||
|
if (auto store = mlir::dyn_cast<GlobalStoreOp>(op)) {
|
||
|
auto ref = store.getGlobal();
|
||
|
symbolStore.insert(ref);
|
||
|
if (previousStores.contains(ref)) {
|
||
|
toDelete.push_back(previousStores.find(ref)->getSecond());
|
||
|
}
|
||
|
|
||
|
previousLoads[ref] = store.getValue();
|
||
|
previousStores[ref] = &op;
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
// If a function is called, clear known values for loads/stores used by
|
||
|
// the function or its sub-functions.
|
||
|
if (auto call = mlir::dyn_cast<CallOpInterface>(op)) {
|
||
|
auto loadSymbols =
|
||
|
loadSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
|
||
|
auto storeSymbols =
|
||
|
storeSymbolsMap[dyn_cast<SymbolRefAttr>(call.getCallableForCallee())];
|
||
|
|
||
|
for (auto sym : loadSymbols) {
|
||
|
previousStores.erase(sym);
|
||
|
}
|
||
|
|
||
|
for (auto sym : storeSymbols) {
|
||
|
previousLoads.erase(sym);
|
||
|
previousStores.erase(sym);
|
||
|
}
|
||
|
continue;
|
||
|
}
|
||
|
|
||
|
// If the op has sub-regions, recurse inside. We make no guarantees whether
|
||
|
// the recursion occurs.
|
||
|
llvm::DenseSet<SymbolRefAttr> opSymbolLoad;
|
||
|
llvm::DenseSet<SymbolRefAttr> opSymbolStore;
|
||
|
for (auto ®ion : op.getRegions()) {
|
||
|
for (auto &block : region) {
|
||
|
ProcessBlock(block, opSymbolLoad, opSymbolStore);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Update current state from the subblock.
|
||
|
for (auto change : opSymbolLoad) {
|
||
|
symbolLoad.insert(change);
|
||
|
previousStores.erase(change);
|
||
|
}
|
||
|
|
||
|
for (auto change : opSymbolStore) {
|
||
|
symbolStore.insert(change);
|
||
|
previousLoads.erase(change);
|
||
|
previousStores.erase(change);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for (auto op : toDelete) {
|
||
|
op->erase();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
void MLProgramPipelineGlobals::runOnOperation() {
|
||
|
auto targetOp = getOperation();
|
||
|
if (failed(buildGlobalMap(targetOp))) {
|
||
|
return;
|
||
|
}
|
||
|
|
||
|
for (auto &funcOp : *targetOp.getBody()) {
|
||
|
for (auto ®ion : funcOp.getRegions()) {
|
||
|
for (auto &block : region.getBlocks()) {
|
||
|
llvm::DenseSet<SymbolRefAttr> symbolsLoaded;
|
||
|
llvm::DenseSet<SymbolRefAttr> symbolsStored;
|
||
|
ProcessBlock(block, symbolsLoaded, symbolsStored);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
} // namespace
|
||
|
|
||
|
std::unique_ptr<OperationPass<mlir::ModuleOp>>
|
||
|
createMLProgramPipelineGlobalsPass() {
|
||
|
return std::make_unique<MLProgramPipelineGlobals>();
|
||
|
}
|
||
|
|
||
|
} // namespace ml_program
|
||
|
} // namespace mlir
|