//===- PassManagerTest.cpp - PassManager unit tests -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Pass/PassManager.h" #include "mlir/Debug/BreakpointManagers/TagBreakpointManager.h" #include "mlir/Debug/ExecutionContext.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" #include "gtest/gtest.h" #include using namespace mlir; using namespace mlir::detail; namespace { /// Analysis that operates on any operation. struct GenericAnalysis { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GenericAnalysis) GenericAnalysis(Operation *op) : isFunc(isa(op)) {} const bool isFunc; }; /// Analysis that operates on a specific operation. struct OpSpecificAnalysis { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpSpecificAnalysis) OpSpecificAnalysis(func::FuncOp op) : isSecret(op.getName() == "secret") {} const bool isSecret; }; /// Simple pass to annotate a func::FuncOp with the results of analysis. struct AnnotateFunctionPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AnnotateFunctionPass) void runOnOperation() override { func::FuncOp op = getOperation(); Builder builder(op->getParentOfType()); auto &ga = getAnalysis(); auto &sa = getAnalysis(); op->setAttr("isFunc", builder.getBoolAttr(ga.isFunc)); op->setAttr("isSecret", builder.getBoolAttr(sa.isSecret)); } }; TEST(PassManagerTest, OpSpecificAnalysis) { MLIRContext context; context.loadDialect(); Builder builder(&context); // Create a module with 2 functions. OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); for (StringRef name : {"secret", "not_secret"}) { auto func = func::FuncOp::create( builder.getUnknownLoc(), name, builder.getFunctionType(std::nullopt, std::nullopt)); func.setPrivate(); module->push_back(func); } // Instantiate and run our pass. auto pm = PassManager::on(&context); pm.addNestedPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); // Verify that each function got annotated with expected attributes. for (func::FuncOp func : module->getOps()) { ASSERT_TRUE(isa(func->getDiscardableAttr("isFunc"))); EXPECT_TRUE(cast(func->getDiscardableAttr("isFunc")).getValue()); bool isSecret = func.getName() == "secret"; ASSERT_TRUE(isa(func->getDiscardableAttr("isSecret"))); EXPECT_EQ(cast(func->getDiscardableAttr("isSecret")).getValue(), isSecret); } } /// Simple pass to annotate a func::FuncOp with a single attribute `didProcess`. struct AddAttrFunctionPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddAttrFunctionPass) void runOnOperation() override { func::FuncOp op = getOperation(); Builder builder(op->getParentOfType()); if (op->hasAttr("didProcess")) op->setAttr("didProcessAgain", builder.getUnitAttr()); // We always want to set this one. op->setAttr("didProcess", builder.getUnitAttr()); } }; /// Simple pass to annotate a func::FuncOp with a single attribute /// `didProcess2`. struct AddSecondAttrFunctionPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSecondAttrFunctionPass) void runOnOperation() override { func::FuncOp op = getOperation(); Builder builder(op->getParentOfType()); op->setAttr("didProcess2", builder.getUnitAttr()); } }; TEST(PassManagerTest, ExecutionAction) { MLIRContext context; context.loadDialect(); Builder builder(&context); // Create a module with 2 functions. OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); auto f = func::FuncOp::create(builder.getUnknownLoc(), "process_me_once", builder.getFunctionType(std::nullopt, std::nullopt)); f.setPrivate(); module->push_back(f); // Instantiate our passes. auto pm = PassManager::on(&context); auto pass = std::make_unique(); auto *passPtr = pass.get(); pm.addNestedPass(std::move(pass)); pm.addNestedPass(std::make_unique()); // Duplicate the first pass to ensure that we *only* run the *first* pass, not // all instances of this pass kind. Notice that this pass (and the test as a // whole) are built to ensure that we can run just a single pass out of a // pipeline that may contain duplicates. pm.addNestedPass(std::make_unique()); // Use the action manager to only hit the first pass, not the second one. auto onBreakpoint = [&](const tracing::ActionActiveStack *backtrace) -> tracing::ExecutionContext::Control { // Not a PassExecutionAction, apply the action. auto *passExec = dyn_cast(&backtrace->getAction()); if (!passExec) return tracing::ExecutionContext::Next; // If this isn't a function, apply the action. if (!isa(passExec->getOp())) return tracing::ExecutionContext::Next; // Only apply the first function pass. Not all instances of the first pass, // only the first pass. if (passExec->getPass().getThreadingSiblingOrThis() == passPtr) return tracing::ExecutionContext::Next; // Do not apply any other passes in the pass manager. return tracing::ExecutionContext::Skip; }; // Set up our breakpoint manager. tracing::TagBreakpointManager simpleManager; tracing::ExecutionContext executionCtx(onBreakpoint); executionCtx.addBreakpointManager(&simpleManager); simpleManager.addBreakpoint(PassExecutionAction::tag); // Register the execution context in the MLIRContext. context.registerActionHandler(executionCtx); // Run the pass manager, expecting our handler to be called. LogicalResult result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); // Verify that each function got annotated with `didProcess` and *not* // `didProcess2`. for (func::FuncOp func : module->getOps()) { ASSERT_TRUE(func->getDiscardableAttr("didProcess")); ASSERT_FALSE(func->getDiscardableAttr("didProcess2")); ASSERT_FALSE(func->getDiscardableAttr("didProcessAgain")); } } namespace { struct InvalidPass : Pass { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InvalidPass) InvalidPass() : Pass(TypeID::get(), StringRef("invalid_op")) {} StringRef getName() const override { return "Invalid Pass"; } void runOnOperation() override {} bool canScheduleOn(RegisteredOperationName opName) const override { return true; } /// A clone method to create a copy of this pass. std::unique_ptr clonePass() const override { return std::make_unique( *static_cast(this)); } }; } // namespace TEST(PassManagerTest, InvalidPass) { MLIRContext context; context.allowUnregisteredDialects(); // Create a module OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); // Add a single "invalid_op" operation OpBuilder builder(&module->getBodyRegion()); OperationState state(UnknownLoc::get(&context), "invalid_op"); builder.insert(Operation::create(state)); // Register a diagnostic handler to capture the diagnostic so that we can // check it later. std::unique_ptr diagnostic; context.getDiagEngine().registerHandler([&](Diagnostic &diag) { diagnostic = std::make_unique(std::move(diag)); }); // Instantiate and run our pass. auto pm = PassManager::on(&context); pm.nest("invalid_op").addPass(std::make_unique()); LogicalResult result = pm.run(module.get()); EXPECT_TRUE(failed(result)); ASSERT_TRUE(diagnostic.get() != nullptr); EXPECT_EQ( diagnostic->str(), "'invalid_op' op trying to schedule a pass on an unregistered operation"); // Check that clearing the pass manager effectively removed the pass. pm.clear(); result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); // Check that adding the pass at the top-level triggers a fatal error. ASSERT_DEATH(pm.addPass(std::make_unique()), "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a " "PassManager intended to run on 'builtin.module', did you " "intend to nest?"); } /// Simple pass to annotate a func::FuncOp with the results of analysis. struct InitializeCheckingPass : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InitializeCheckingPass) LogicalResult initialize(MLIRContext *ctx) final { initialized = true; return success(); } bool initialized = false; void runOnOperation() override { if (!initialized) { getOperation()->emitError() << "Pass isn't initialized!"; signalPassFailure(); } } }; TEST(PassManagerTest, PassInitialization) { MLIRContext context; context.allowUnregisteredDialects(); // Create a module OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); // Instantiate and run our pass. auto pm = PassManager::on(&context); pm.addPass(std::make_unique()); EXPECT_TRUE(succeeded(pm.run(module.get()))); // Adding a second copy of the pass, we should also initialize it! pm.addPass(std::make_unique()); EXPECT_TRUE(succeeded(pm.run(module.get()))); } } // namespace