bolt/deps/llvm-18.1.8/mlir/unittests/IR/SymbolTableTest.cpp
2025-02-14 19:21:04 +01:00

136 lines
4.8 KiB
C++

//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "gtest/gtest.h"
using namespace mlir;
namespace test {
// Forward declaration only; the definition is not part of this file. The
// fixture's SetUp() calls it to populate its DialectRegistry before creating
// the MLIRContext. NOTE(review): presumably this registers the dialect that
// provides the `test.*` ops used in the input IR -- confirm at the definition.
void registerTestDialect(DialectRegistry &);
} // namespace test
/// Fixture for the `replaceAllSymbolUses` tests.
///
/// It owns an MLIRContext built from a registry populated through
/// `::test::registerTestDialect`, and provides `testReplaceAllSymbolUses`,
/// which parses a small module containing the symbols `@foo` and `@bar`, runs a
/// caller-supplied replacement callback, and then checks that the call inside
/// `@foo` refers to `baz`.
class ReplaceAllSymbolUsesTest : public ::testing::Test {
protected:
  /// Signature of the replacement callback. It receives the symbol table built
  /// for the parsed module, the module itself, the `@foo` operation and the
  /// `@bar` operation (in that order) and returns the result of the
  /// replacement it performed.
  using ReplaceFnType = llvm::function_ref<LogicalResult(
      SymbolTable, ModuleOp, Operation *, Operation *)>;

  /// Populates the registry and creates a fresh context for every test.
  void SetUp() override {
    ::test::registerTestDialect(registry);
    context = std::make_unique<MLIRContext>(registry);
  }

  /// Parses `kInput`, invokes `replaceFn` on it and verifies the outcome.
  ///
  /// Fails the current test (without crashing) if the input cannot be parsed,
  /// if the first two operations of the module are not named `foo` and `bar`,
  /// if `replaceFn` reports failure, if the module no longer verifies, or if
  /// no call inside `foo` ends up referencing `baz`.
  void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
    // Set up IR and find func ops.
    OwningOpRef<ModuleOp> module =
        parseSourceString<ModuleOp>(kInput, context.get());
    // A failed parse yields a null module; stop here instead of dereferencing
    // it below.
    ASSERT_TRUE(!!module) << "failed to parse the test input IR";

    SymbolTable symbolTable(module.get());
    auto opIterator = module->getBody(0)->getOperations().begin();
    auto fooOp = cast<FunctionOpInterface>(opIterator++);
    auto barOp = cast<FunctionOpInterface>(opIterator++);
    ASSERT_EQ(fooOp.getNameAttr(), "foo");
    ASSERT_EQ(barOp.getNameAttr(), "bar");

    // Call test function that does symbol replacement.
    LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
    ASSERT_TRUE(succeeded(res));
    ASSERT_TRUE(succeeded(verify(module.get())));

    // Check that it got renamed.
    bool calleeFound = false;
    fooOp->walk([&](CallOpInterface callOp) {
      // The callable may hold something other than a symbol reference, in
      // which case dyn_cast returns a null attribute; report that as a test
      // failure rather than calling into a null attribute.
      SymbolRefAttr calleeRef =
          callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
      if (!calleeRef) {
        ADD_FAILURE() << "call op does not reference its callee by symbol";
        return;
      }
      StringAttr callee = calleeRef.getLeafReference();
      EXPECT_EQ(callee, "baz");
      calleeFound = true;
    });
    EXPECT_TRUE(calleeFound);
  }

  /// Context used for parsing and for creating attributes in the tests.
  std::unique_ptr<MLIRContext> context;

private:
  /// Input IR: function `@foo` contains a call op whose `callee` is `@bar`,
  /// followed by a declaration of `@bar`.
  constexpr static llvm::StringLiteral kInput = R"MLIR(
module {
test.conversion_func_op private @foo() {
"test.conversion_call_op"() { callee=@bar } : () -> ()
"test.return"() : () -> ()
}
test.conversion_func_op private @bar()
}
)MLIR";

  /// Registry handed to the context in SetUp().
  DialectRegistry registry;
};
namespace {
TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
  // The old symbol is given as its defining operation; uses are replaced
  // starting from the module operation.
  auto renameViaOp = [&](auto symbols, auto moduleOp, auto /*fooOp*/,
                         auto barOp) -> LogicalResult {
    StringAttr newName = StringAttr::get(context.get(), "baz");
    return symbols.replaceAllSymbolUses(barOp, newName, moduleOp);
  };
  testReplaceAllSymbolUses(renameViaOp);
}
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
  // The old symbol is given by name; uses are replaced starting from the
  // module operation.
  auto renameViaName = [&](auto symbols, auto moduleOp, auto /*fooOp*/,
                           auto /*barOp*/) -> LogicalResult {
    StringAttr oldName = StringAttr::get(context.get(), "bar");
    StringAttr newName = StringAttr::get(context.get(), "baz");
    return symbols.replaceAllSymbolUses(oldName, newName, moduleOp);
  };
  testReplaceAllSymbolUses(renameViaName);
}
TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
  // The old symbol is given as its defining operation; uses are replaced
  // within the module's first region.
  auto renameViaOp = [&](auto symbols, auto moduleOp, auto /*fooOp*/,
                         auto barOp) -> LogicalResult {
    Region *moduleBody = &moduleOp->getRegion(0);
    StringAttr newName = StringAttr::get(context.get(), "baz");
    return symbols.replaceAllSymbolUses(barOp, newName, moduleBody);
  };
  testReplaceAllSymbolUses(renameViaOp);
}
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
  // The old symbol is given by name; uses are replaced within the module's
  // first region.
  auto renameViaName = [&](auto symbols, auto moduleOp, auto /*fooOp*/,
                           auto /*barOp*/) -> LogicalResult {
    Region *moduleBody = &moduleOp->getRegion(0);
    StringAttr oldName = StringAttr::get(context.get(), "bar");
    StringAttr newName = StringAttr::get(context.get(), "baz");
    return symbols.replaceAllSymbolUses(oldName, newName, moduleBody);
  };
  testReplaceAllSymbolUses(renameViaName);
}
TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
  // The old symbol is given as its defining operation; uses are replaced
  // starting from the `foo` function operation only.
  auto renameViaOp = [&](auto symbols, auto /*moduleOp*/, auto fooOp,
                         auto barOp) -> LogicalResult {
    StringAttr newName = StringAttr::get(context.get(), "baz");
    return symbols.replaceAllSymbolUses(barOp, newName, fooOp);
  };
  testReplaceAllSymbolUses(renameViaOp);
}
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
  // The old symbol is given by name; uses are replaced starting from the
  // `foo` function operation only.
  auto renameViaName = [&](auto symbols, auto /*moduleOp*/, auto fooOp,
                           auto /*barOp*/) -> LogicalResult {
    StringAttr oldName = StringAttr::get(context.get(), "bar");
    StringAttr newName = StringAttr::get(context.get(), "baz");
    return symbols.replaceAllSymbolUses(oldName, newName, fooOp);
  };
  testReplaceAllSymbolUses(renameViaName);
}
} // namespace