//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/SymbolTable.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" #include "gtest/gtest.h" using namespace mlir; namespace test { void registerTestDialect(DialectRegistry &); } // namespace test class ReplaceAllSymbolUsesTest : public ::testing::Test { protected: using ReplaceFnType = llvm::function_ref; void SetUp() override { ::test::registerTestDialect(registry); context = std::make_unique(registry); } void testReplaceAllSymbolUses(ReplaceFnType replaceFn) { // Set up IR and find func ops. OwningOpRef module = parseSourceString(kInput, context.get()); SymbolTable symbolTable(module.get()); auto opIterator = module->getBody(0)->getOperations().begin(); auto fooOp = cast(opIterator++); auto barOp = cast(opIterator++); ASSERT_EQ(fooOp.getNameAttr(), "foo"); ASSERT_EQ(barOp.getNameAttr(), "bar"); // Call test function that does symbol replacement. LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp); ASSERT_TRUE(succeeded(res)); ASSERT_TRUE(succeeded(verify(module.get()))); // Check that it got renamed. bool calleeFound = false; fooOp->walk([&](CallOpInterface callOp) { StringAttr callee = callOp.getCallableForCallee() .dyn_cast() .getLeafReference(); EXPECT_EQ(callee, "baz"); calleeFound = true; }); EXPECT_TRUE(calleeFound); } std::unique_ptr context; private: constexpr static llvm::StringLiteral kInput = R"MLIR( module { test.conversion_func_op private @foo() { "test.conversion_call_op"() { callee=@bar } : () -> () "test.return"() : () -> () } test.conversion_func_op private @bar() } )MLIR"; DialectRegistry registry; }; namespace { TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { // Symbol as `Operation *`, rename within module. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), module); }); } TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { // Symbol as `StringAttr`, rename within module. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), StringAttr::get(context.get(), "baz"), module); }); } TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { // Symbol as `Operation *`, rename within module body. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0)); }); } TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { // Symbol as `StringAttr`, rename within module body. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), StringAttr::get(context.get(), "baz"), &module->getRegion(0)); }); } TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { // Symbol as `Operation *`, rename within function. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( barOp, StringAttr::get(context.get(), "baz"), fooOp); }); } TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { // Symbol as `StringAttr`, rename within function. testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, auto barOp) -> LogicalResult { return symbolTable.replaceAllSymbolUses( StringAttr::get(context.get(), "bar"), StringAttr::get(context.get(), "baz"), fooOp); }); } } // namespace