//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestDialect.h" #include "mlir/Pass/Pass.h" using namespace mlir; static std::string getStageDescription(const WalkStage &stage) { if (stage.isBeforeAllRegions()) return "before all regions"; if (stage.isAfterAllRegions()) return "after all regions"; return "before region #" + std::to_string(stage.getNextRegion()); } namespace { /// This pass exercises generic visitor with void callbacks and prints the order /// and stage in which operations are visited. struct TestGenericIRVisitorPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGenericIRVisitorPass) StringRef getArgument() const final { return "test-generic-ir-visitors"; } StringRef getDescription() const final { return "Test generic IR visitors."; } void runOnOperation() override { Operation *outerOp = getOperation(); int stepNo = 0; outerOp->walk([&](Operation *op, const WalkStage &stage) { llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " << getStageDescription(stage) << "\n"; }); // Exercise static inference of operation type. outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " << getStageDescription(stage) << "\n"; }); } }; /// This pass exercises the generic visitor with non-void callbacks and prints /// the order and stage in which operations are visited. It will interrupt the /// walk based on attributes peesent in the IR. struct TestGenericIRVisitorInterruptPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestGenericIRVisitorInterruptPass) StringRef getArgument() const final { return "test-generic-ir-visitors-interrupt"; } StringRef getDescription() const final { return "Test generic IR visitors with interrupts."; } void runOnOperation() override { Operation *outerOp = getOperation(); int stepNo = 0; auto walker = [&](Operation *op, const WalkStage &stage) { if (auto interruptBeforeAall = op->getAttrOfType("interrupt_before_all")) if (interruptBeforeAall.getValue() && stage.isBeforeAllRegions()) return WalkResult::interrupt(); if (auto interruptAfterAll = op->getAttrOfType("interrupt_after_all")) if (interruptAfterAll.getValue() && stage.isAfterAllRegions()) return WalkResult::interrupt(); if (auto interruptAfterRegion = op->getAttrOfType("interrupt_after_region")) if (stage.isAfterRegion( static_cast(interruptAfterRegion.getInt()))) return WalkResult::interrupt(); if (auto skipBeforeAall = op->getAttrOfType("skip_before_all")) if (skipBeforeAall.getValue() && stage.isBeforeAllRegions()) return WalkResult::skip(); if (auto skipAfterAll = op->getAttrOfType("skip_after_all")) if (skipAfterAll.getValue() && stage.isAfterAllRegions()) return WalkResult::skip(); if (auto skipAfterRegion = op->getAttrOfType("skip_after_region")) if (stage.isAfterRegion(static_cast(skipAfterRegion.getInt()))) return WalkResult::skip(); llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' " << getStageDescription(stage) << "\n"; return WalkResult::advance(); }; // Interrupt the walk based on attributes on the operation. auto result = outerOp->walk(walker); if (result.wasInterrupted()) llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; // Exercise static inference of operation type. result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) { return walker(op, stage); }); if (result.wasInterrupted()) llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; } }; struct TestGenericIRBlockVisitorInterruptPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestGenericIRBlockVisitorInterruptPass) StringRef getArgument() const final { return "test-generic-ir-block-visitors-interrupt"; } StringRef getDescription() const final { return "Test generic IR visitors with interrupts, starting with Blocks."; } void runOnOperation() override { int stepNo = 0; auto walker = [&](Block *block) { for (Operation &op : *block) if (op.getAttrOfType("interrupt")) return WalkResult::interrupt(); llvm::outs() << "step " << stepNo++ << "\n"; return WalkResult::advance(); }; auto result = getOperation()->walk(walker); if (result.wasInterrupted()) llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; } }; struct TestGenericIRRegionVisitorInterruptPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestGenericIRRegionVisitorInterruptPass) StringRef getArgument() const final { return "test-generic-ir-region-visitors-interrupt"; } StringRef getDescription() const final { return "Test generic IR visitors with interrupts, starting with Regions."; } void runOnOperation() override { int stepNo = 0; auto walker = [&](Region *region) { for (Operation &op : region->getOps()) if (op.getAttrOfType("interrupt")) return WalkResult::interrupt(); llvm::outs() << "step " << stepNo++ << "\n"; return WalkResult::advance(); }; auto result = getOperation()->walk(walker); if (result.wasInterrupted()) llvm::outs() << "step " << stepNo++ << " walk was interrupted\n"; } }; } // namespace namespace mlir { namespace test { void registerTestGenericIRVisitorsPass() { PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); } } // namespace test } // namespace mlir