//===- Preload.cpp - Test MlirOptMain parameterization ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/TypeID.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/raw_ostream.h" #include "gtest/gtest.h" using namespace mlir; namespace mlir { namespace test { std::unique_ptr createTestTransformDialectInterpreterPass(); } // namespace test } // namespace mlir const static llvm::StringLiteral library = R"MLIR( module attributes {transform.with_named_sequence} { transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) { transform.debug.emit_remark_at %arg0, "from external symbol" : !transform.any_op transform.yield } })MLIR"; const static llvm::StringLiteral input = R"MLIR( module attributes {transform.with_named_sequence} { transform.named_sequence private @__transform_main(%arg0: !transform.any_op {transform.readonly}) transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): include @__transform_main failures(propagate) (%arg0) : (!transform.any_op) -> () } })MLIR"; TEST(Preload, ContextPreloadConstructedLibrary) { registerPassManagerCLOptions(); MLIRContext context; auto *dialect = context.getOrLoadDialect(); DialectRegistry registry; mlir::transform::registerDebugExtension(registry); registry.applyExtensions(&context); ParserConfig parserConfig(&context); OwningOpRef inputModule = parseSourceString(input, parserConfig, ""); EXPECT_TRUE(inputModule) << "failed to parse input module"; OwningOpRef transformLibrary = parseSourceString(library, parserConfig, ""); EXPECT_TRUE(transformLibrary) << "failed to parse transform module"; LogicalResult diag = dialect->loadIntoLibraryModule(std::move(transformLibrary)); EXPECT_TRUE(succeeded(diag)); ModuleOp retrievedTransformLibrary = transform::detail::getPreloadedTransformModule(&context); EXPECT_TRUE(retrievedTransformLibrary) << "failed to retrieve transform module"; OwningOpRef clonedTransformModule( retrievedTransformLibrary->clone()); LogicalResult res = transform::detail::mergeSymbolsInto( inputModule->getOperation(), std::move(clonedTransformModule)); EXPECT_TRUE(succeeded(res)) << "failed to define declared symbols"; transform::TransformOpInterface entryPoint = transform::detail::findTransformEntryPoint(inputModule->getOperation(), retrievedTransformLibrary); EXPECT_TRUE(entryPoint) << "failed to find entry point"; transform::TransformOptions options; res = transform::applyTransformNamedSequence( inputModule->getOperation(), entryPoint, retrievedTransformLibrary, options); EXPECT_TRUE(succeeded(res)) << "failed to apply named sequence"; }