137 lines
4.8 KiB
C++
137 lines
4.8 KiB
C++
|
//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
#include "mlir/IR/SymbolTable.h"
|
||
|
#include "mlir/IR/BuiltinOps.h"
|
||
|
#include "mlir/IR/Verifier.h"
|
||
|
#include "mlir/Interfaces/CallInterfaces.h"
|
||
|
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||
|
#include "mlir/Parser/Parser.h"
|
||
|
|
||
|
#include "gtest/gtest.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
|
||
|
namespace test {
|
||
|
void registerTestDialect(DialectRegistry &);
|
||
|
} // namespace test
|
||
|
|
||
|
class ReplaceAllSymbolUsesTest : public ::testing::Test {
|
||
|
protected:
|
||
|
using ReplaceFnType = llvm::function_ref<LogicalResult(
|
||
|
SymbolTable, ModuleOp, Operation *, Operation *)>;
|
||
|
|
||
|
void SetUp() override {
|
||
|
::test::registerTestDialect(registry);
|
||
|
context = std::make_unique<MLIRContext>(registry);
|
||
|
}
|
||
|
|
||
|
void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
|
||
|
// Set up IR and find func ops.
|
||
|
OwningOpRef<ModuleOp> module =
|
||
|
parseSourceString<ModuleOp>(kInput, context.get());
|
||
|
SymbolTable symbolTable(module.get());
|
||
|
auto opIterator = module->getBody(0)->getOperations().begin();
|
||
|
auto fooOp = cast<FunctionOpInterface>(opIterator++);
|
||
|
auto barOp = cast<FunctionOpInterface>(opIterator++);
|
||
|
ASSERT_EQ(fooOp.getNameAttr(), "foo");
|
||
|
ASSERT_EQ(barOp.getNameAttr(), "bar");
|
||
|
|
||
|
// Call test function that does symbol replacement.
|
||
|
LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
|
||
|
ASSERT_TRUE(succeeded(res));
|
||
|
ASSERT_TRUE(succeeded(verify(module.get())));
|
||
|
|
||
|
// Check that it got renamed.
|
||
|
bool calleeFound = false;
|
||
|
fooOp->walk([&](CallOpInterface callOp) {
|
||
|
StringAttr callee = callOp.getCallableForCallee()
|
||
|
.dyn_cast<SymbolRefAttr>()
|
||
|
.getLeafReference();
|
||
|
EXPECT_EQ(callee, "baz");
|
||
|
calleeFound = true;
|
||
|
});
|
||
|
EXPECT_TRUE(calleeFound);
|
||
|
}
|
||
|
|
||
|
std::unique_ptr<MLIRContext> context;
|
||
|
|
||
|
private:
|
||
|
constexpr static llvm::StringLiteral kInput = R"MLIR(
|
||
|
module {
|
||
|
test.conversion_func_op private @foo() {
|
||
|
"test.conversion_call_op"() { callee=@bar } : () -> ()
|
||
|
"test.return"() : () -> ()
|
||
|
}
|
||
|
test.conversion_func_op private @bar()
|
||
|
}
|
||
|
)MLIR";
|
||
|
|
||
|
DialectRegistry registry;
|
||
|
};
|
||
|
|
||
|
namespace {
|
||
|
|
||
|
TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
|
||
|
// Symbol as `Operation *`, rename within module.
|
||
|
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
|
||
|
auto barOp) -> LogicalResult {
|
||
|
return symbolTable.replaceAllSymbolUses(
|
||
|
barOp, StringAttr::get(context.get(), "baz"), module);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
|
||
|
// Symbol as `StringAttr`, rename within module.
|
||
|
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
|
||
|
auto barOp) -> LogicalResult {
|
||
|
return symbolTable.replaceAllSymbolUses(
|
||
|
StringAttr::get(context.get(), "bar"),
|
||
|
StringAttr::get(context.get(), "baz"), module);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
|
||
|
// Symbol as `Operation *`, rename within module body.
|
||
|
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
|
||
|
auto barOp) -> LogicalResult {
|
||
|
return symbolTable.replaceAllSymbolUses(
|
||
|
barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
|
||
|
});
|
||
|
}
|
||
|
|
||
|
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
|
||
|
// Symbol as `StringAttr`, rename within module body.
|
||
|
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
|
||
|
auto barOp) -> LogicalResult {
|
||
|
return symbolTable.replaceAllSymbolUses(
|
||
|
StringAttr::get(context.get(), "bar"),
|
||
|
StringAttr::get(context.get(), "baz"), &module->getRegion(0));
|
||
|
});
|
||
|
}
|
||
|
|
||
|
TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
|
||
|
// Symbol as `Operation *`, rename within function.
|
||
|
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
|
||
|
auto barOp) -> LogicalResult {
|
||
|
return symbolTable.replaceAllSymbolUses(
|
||
|
barOp, StringAttr::get(context.get(), "baz"), fooOp);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
|
||
|
// Symbol as `StringAttr`, rename within function.
|
||
|
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
|
||
|
auto barOp) -> LogicalResult {
|
||
|
return symbolTable.replaceAllSymbolUses(
|
||
|
StringAttr::get(context.get(), "bar"),
|
||
|
StringAttr::get(context.get(), "baz"), fooOp);
|
||
|
});
|
||
|
}
|
||
|
|
||
|
} // namespace
|