//===- TestIRVisitors.cpp - Pass to test the IR visitors ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Iterators.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" using namespace mlir; static void printRegion(Region *region) { llvm::outs() << "region " << region->getRegionNumber() << " from operation '" << region->getParentOp()->getName() << "'"; } static void printBlock(Block *block) { llvm::outs() << "block "; block->printAsOperand(llvm::outs(), /*printType=*/false); llvm::outs() << " from "; printRegion(block->getParent()); } static void printOperation(Operation *op) { llvm::outs() << "op '" << op->getName() << "'"; } /// Tests pure callbacks. static void testPureCallbacks(Operation *op) { auto opPure = [](Operation *op) { llvm::outs() << "Visiting "; printOperation(op); llvm::outs() << "\n"; }; auto blockPure = [](Block *block) { llvm::outs() << "Visiting "; printBlock(block); llvm::outs() << "\n"; }; auto regionPure = [](Region *region) { llvm::outs() << "Visiting "; printRegion(region); llvm::outs() << "\n"; }; llvm::outs() << "Op pre-order visits" << "\n"; op->walk(opPure); llvm::outs() << "Block pre-order visits" << "\n"; op->walk(blockPure); llvm::outs() << "Region pre-order visits" << "\n"; op->walk(regionPure); llvm::outs() << "Op post-order visits" << "\n"; op->walk(opPure); llvm::outs() << "Block post-order visits" << "\n"; op->walk(blockPure); llvm::outs() << "Region post-order visits" << "\n"; op->walk(regionPure); llvm::outs() << "Op reverse post-order visits" << "\n"; op->walk(opPure); llvm::outs() << "Block reverse post-order visits" << "\n"; op->walk(blockPure); llvm::outs() << "Region reverse post-order visits" << "\n"; op->walk(regionPure); // This test case tests "NoGraphRegions = true", so start the walk with // functions. op->walk([&](FunctionOpInterface funcOp) { llvm::outs() << "Op forward dominance post-order visits" << "\n"; funcOp->walk>(opPure); llvm::outs() << "Block forward dominance post-order visits" << "\n"; funcOp->walk>(blockPure); llvm::outs() << "Region forward dominance post-order visits" << "\n"; funcOp->walk>(regionPure); llvm::outs() << "Op reverse dominance post-order visits" << "\n"; funcOp->walk>(opPure); llvm::outs() << "Block reverse dominance post-order visits" << "\n"; funcOp->walk>(blockPure); llvm::outs() << "Region reverse dominance post-order visits" << "\n"; funcOp->walk>(regionPure); }); } /// Tests erasure callbacks that skip the walk. static void testSkipErasureCallbacks(Operation *op) { auto skipOpErasure = [](Operation *op) { // Do not erase module and module children operations. Otherwise, there // wouldn't be too much to test in pre-order. if (isa(op) || isa(op->getParentOp())) return WalkResult::advance(); llvm::outs() << "Erasing "; printOperation(op); llvm::outs() << "\n"; op->dropAllUses(); op->erase(); return WalkResult::skip(); }; auto skipBlockErasure = [](Block *block) { // Do not erase module and module children blocks. Otherwise there wouldn't // be too much to test in pre-order. Operation *parentOp = block->getParentOp(); if (isa(parentOp) || isa(parentOp->getParentOp())) return WalkResult::advance(); if (block->use_empty()) { llvm::outs() << "Erasing "; printBlock(block); llvm::outs() << "\n"; block->erase(); return WalkResult::skip(); } else { llvm::outs() << "Cannot erase "; printBlock(block); llvm::outs() << ", still has uses\n"; return WalkResult::advance(); } }; llvm::outs() << "Op pre-order erasures (skip)" << "\n"; Operation *cloned = op->clone(); cloned->walk(skipOpErasure); cloned->erase(); llvm::outs() << "Block pre-order erasures (skip)" << "\n"; cloned = op->clone(); cloned->walk(skipBlockErasure); cloned->erase(); llvm::outs() << "Op post-order erasures (skip)" << "\n"; cloned = op->clone(); cloned->walk(skipOpErasure); cloned->erase(); llvm::outs() << "Block post-order erasures (skip)" << "\n"; cloned = op->clone(); cloned->walk(skipBlockErasure); cloned->erase(); } /// Tests callbacks that erase the op or block but don't return 'Skip'. This /// callbacks are only valid in post-order. static void testNoSkipErasureCallbacks(Operation *op) { auto noSkipOpErasure = [](Operation *op) { llvm::outs() << "Erasing "; printOperation(op); llvm::outs() << "\n"; op->dropAllUses(); op->erase(); }; auto noSkipBlockErasure = [](Block *block) { if (block->use_empty()) { llvm::outs() << "Erasing "; printBlock(block); llvm::outs() << "\n"; block->erase(); } else { llvm::outs() << "Cannot erase "; printBlock(block); llvm::outs() << ", still has uses\n"; } }; llvm::outs() << "Op post-order erasures (no skip)" << "\n"; Operation *cloned = op->clone(); cloned->walk(noSkipOpErasure); llvm::outs() << "Block post-order erasures (no skip)" << "\n"; cloned = op->clone(); cloned->walk(noSkipBlockErasure); cloned->erase(); } /// Invoke region/block walks on regions/blocks. static void testBlockAndRegionWalkers(Operation *op) { auto blockPure = [](Block *block) { llvm::outs() << "Visiting "; printBlock(block); llvm::outs() << "\n"; }; auto regionPure = [](Region *region) { llvm::outs() << "Visiting "; printRegion(region); llvm::outs() << "\n"; }; llvm::outs() << "Invoke block pre-order visits on blocks\n"; op->walk([&](Operation *op) { if (!op->hasAttr("walk_blocks")) return; for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { block.walk(blockPure); } } }); llvm::outs() << "Invoke block post-order visits on blocks\n"; op->walk([&](Operation *op) { if (!op->hasAttr("walk_blocks")) return; for (Region ®ion : op->getRegions()) { for (Block &block : region.getBlocks()) { block.walk(blockPure); } } }); llvm::outs() << "Invoke region pre-order visits on region\n"; op->walk([&](Operation *op) { if (!op->hasAttr("walk_regions")) return; for (Region ®ion : op->getRegions()) { region.walk(regionPure); } }); llvm::outs() << "Invoke region post-order visits on region\n"; op->walk([&](Operation *op) { if (!op->hasAttr("walk_regions")) return; for (Region ®ion : op->getRegions()) { region.walk(regionPure); } }); } namespace { /// This pass exercises the different configurations of the IR visitors. struct TestIRVisitorsPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIRVisitorsPass) StringRef getArgument() const final { return "test-ir-visitors"; } StringRef getDescription() const final { return "Test various visitors."; } void runOnOperation() override { Operation *op = getOperation(); testPureCallbacks(op); testBlockAndRegionWalkers(op); testSkipErasureCallbacks(op); testNoSkipErasureCallbacks(op); } }; } // namespace namespace mlir { namespace test { void registerTestIRVisitorsPass() { PassRegistration(); } } // namespace test } // namespace mlir