//===- Translation.cpp - Translation registry -----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Definitions of the translation registry. // //===----------------------------------------------------------------------===// #include "mlir/Tools/mlir-translate/Translation.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Tools/ParseUtilities.h" #include "llvm/Support/SourceMgr.h" #include using namespace mlir; //===----------------------------------------------------------------------===// // Translation CommandLine Options //===----------------------------------------------------------------------===// struct TranslationOptions { llvm::cl::opt noImplicitModule{ "no-implicit-module", llvm::cl::desc("Disable the parsing of an implicit top-level module op"), llvm::cl::init(false)}; }; static llvm::ManagedStatic clOptions; void mlir::registerTranslationCLOptions() { *clOptions; } //===----------------------------------------------------------------------===// // Translation Registry //===----------------------------------------------------------------------===// /// Get the mutable static map between registered file-to-file MLIR /// translations. static llvm::StringMap &getTranslationRegistry() { static llvm::StringMap translationBundle; return translationBundle; } /// Register the given translation. static void registerTranslation(StringRef name, StringRef description, std::optional inputAlignment, const TranslateFunction &function) { auto ®istry = getTranslationRegistry(); if (registry.count(name)) llvm::report_fatal_error( "Attempting to overwrite an existing function"); assert(function && "Attempting to register an empty translate function"); registry[name] = Translation(function, description, inputAlignment); } TranslateRegistration::TranslateRegistration( StringRef name, StringRef description, const TranslateFunction &function) { registerTranslation(name, description, /*inputAlignment=*/std::nullopt, function); } //===----------------------------------------------------------------------===// // Translation to MLIR //===----------------------------------------------------------------------===// // Puts `function` into the to-MLIR translation registry unless there is already // a function registered for the same name. static void registerTranslateToMLIRFunction( StringRef name, StringRef description, const DialectRegistrationFunction &dialectRegistration, std::optional inputAlignment, const TranslateSourceMgrToMLIRFunction &function) { auto wrappedFn = [function, dialectRegistration]( const std::shared_ptr &sourceMgr, raw_ostream &output, MLIRContext *context) { DialectRegistry registry; dialectRegistration(registry); context->appendDialectRegistry(registry); OwningOpRef op = function(sourceMgr, context); if (!op || failed(verify(*op))) return failure(); op.get()->print(output); return success(); }; registerTranslation(name, description, inputAlignment, wrappedFn); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, const TranslateSourceMgrToMLIRFunction &function, const DialectRegistrationFunction &dialectRegistration, std::optional inputAlignment) { registerTranslateToMLIRFunction(name, description, dialectRegistration, inputAlignment, function); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, const TranslateRawSourceMgrToMLIRFunction &function, const DialectRegistrationFunction &dialectRegistration, std::optional inputAlignment) { registerTranslateToMLIRFunction( name, description, dialectRegistration, inputAlignment, [function](const std::shared_ptr &sourceMgr, MLIRContext *ctx) { return function(*sourceMgr, ctx); }); } /// Wraps `function` with a lambda that extracts a StringRef from a source /// manager and registers the wrapper lambda as a to-MLIR conversion. TranslateToMLIRRegistration::TranslateToMLIRRegistration( StringRef name, StringRef description, const TranslateStringRefToMLIRFunction &function, const DialectRegistrationFunction &dialectRegistration, std::optional inputAlignment) { registerTranslateToMLIRFunction( name, description, dialectRegistration, inputAlignment, [function](const std::shared_ptr &sourceMgr, MLIRContext *ctx) { const llvm::MemoryBuffer *buffer = sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); return function(buffer->getBuffer(), ctx); }); } //===----------------------------------------------------------------------===// // Translation from MLIR //===----------------------------------------------------------------------===// TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( StringRef name, StringRef description, const TranslateFromMLIRFunction &function, const DialectRegistrationFunction &dialectRegistration) { registerTranslation( name, description, /*inputAlignment=*/std::nullopt, [function, dialectRegistration](const std::shared_ptr &sourceMgr, raw_ostream &output, MLIRContext *context) { DialectRegistry registry; dialectRegistration(registry); context->appendDialectRegistry(registry); bool implicitModule = (!clOptions.isConstructed() || !clOptions->noImplicitModule); OwningOpRef op = parseSourceFileForTool(sourceMgr, context, implicitModule); if (!op || failed(verify(*op))) return failure(); return function(op.get(), output); }); } //===----------------------------------------------------------------------===// // Translation Parser //===----------------------------------------------------------------------===// TranslationParser::TranslationParser(llvm::cl::Option &opt) : llvm::cl::parser(opt) { for (const auto &kv : getTranslationRegistry()) addLiteralOption(kv.first(), &kv.second, kv.second.getDescription()); } void TranslationParser::printOptionInfo(const llvm::cl::Option &o, size_t globalWidth) const { TranslationParser *tp = const_cast(this); llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(), [](const TranslationParser::OptionInfo *lhs, const TranslationParser::OptionInfo *rhs) { return lhs->Name.compare(rhs->Name); }); llvm::cl::parser::printOptionInfo(o, globalWidth); }