95 lines
3.6 KiB
C++
95 lines
3.6 KiB
C++
//===- Preload.cpp - Test MlirOptMain parameterization ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/IR/Utils.h"
|
|
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/DialectRegistry.h"
|
|
#include "mlir/IR/Verifier.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/TypeID.h"
|
|
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
|
|
#include "llvm/Support/MemoryBuffer.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
std::unique_ptr<Pass> createTestTransformDialectInterpreterPass();
|
|
} // namespace test
|
|
} // namespace mlir
|
|
|
|
const static llvm::StringLiteral library = R"MLIR(
|
|
module attributes {transform.with_named_sequence} {
|
|
transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
|
transform.debug.emit_remark_at %arg0, "from external symbol" : !transform.any_op
|
|
transform.yield
|
|
}
|
|
})MLIR";
|
|
|
|
const static llvm::StringLiteral input = R"MLIR(
|
|
module attributes {transform.with_named_sequence} {
|
|
transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly})
|
|
|
|
transform.sequence failures(propagate) {
|
|
^bb0(%arg0: !transform.any_op):
|
|
include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> ()
|
|
}
|
|
})MLIR";
|
|
|
|
TEST(Preload, ContextPreloadConstructedLibrary) {
|
|
registerPassManagerCLOptions();
|
|
|
|
MLIRContext context;
|
|
auto *dialect = context.getOrLoadDialect<transform::TransformDialect>();
|
|
DialectRegistry registry;
|
|
mlir::transform::registerDebugExtension(registry);
|
|
registry.applyExtensions(&context);
|
|
ParserConfig parserConfig(&context);
|
|
|
|
OwningOpRef<ModuleOp> inputModule =
|
|
parseSourceString<ModuleOp>(input, parserConfig, "<input>");
|
|
EXPECT_TRUE(inputModule) << "failed to parse input module";
|
|
|
|
OwningOpRef<ModuleOp> transformLibrary =
|
|
parseSourceString<ModuleOp>(library, parserConfig, "<transform-library>");
|
|
EXPECT_TRUE(transformLibrary) << "failed to parse transform module";
|
|
LogicalResult diag =
|
|
dialect->loadIntoLibraryModule(std::move(transformLibrary));
|
|
EXPECT_TRUE(succeeded(diag));
|
|
|
|
ModuleOp retrievedTransformLibrary =
|
|
transform::detail::getPreloadedTransformModule(&context);
|
|
EXPECT_TRUE(retrievedTransformLibrary)
|
|
<< "failed to retrieve transform module";
|
|
|
|
OwningOpRef<Operation *> clonedTransformModule(
|
|
retrievedTransformLibrary->clone());
|
|
|
|
LogicalResult res = transform::detail::mergeSymbolsInto(
|
|
inputModule->getOperation(), std::move(clonedTransformModule));
|
|
EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols";
|
|
|
|
transform::TransformOpInterface entryPoint =
|
|
transform::detail::findTransformEntryPoint(inputModule->getOperation(),
|
|
retrievedTransformLibrary);
|
|
EXPECT_TRUE(entryPoint) << "failed to find entry point";
|
|
|
|
transform::TransformOptions options;
|
|
res = transform::applyTransformNamedSequence(
|
|
inputModule->getOperation(), entryPoint, retrievedTransformLibrary,
|
|
options);
|
|
EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence";
|
|
}
|