//===- Context.cpp --------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/ODS/Context.h" #include "mlir/Tools/PDLL/ODS/Constraint.h" #include "mlir/Tools/PDLL/ODS/Dialect.h" #include "mlir/Tools/PDLL/ODS/Operation.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" #include using namespace mlir; using namespace mlir::pdll::ods; //===----------------------------------------------------------------------===// // Context //===----------------------------------------------------------------------===// Context::Context() = default; Context::~Context() = default; const AttributeConstraint & Context::insertAttributeConstraint(StringRef name, StringRef summary, StringRef cppClass) { std::unique_ptr &constraint = attributeConstraints[name]; if (!constraint) { constraint.reset(new AttributeConstraint(name, summary, cppClass)); } else { assert(constraint->getCppClass() == cppClass && constraint->getSummary() == summary && "constraint with the same name was already registered with a " "different class"); } return *constraint; } const TypeConstraint &Context::insertTypeConstraint(StringRef name, StringRef summary, StringRef cppClass) { std::unique_ptr &constraint = typeConstraints[name]; if (!constraint) constraint.reset(new TypeConstraint(name, summary, cppClass)); return *constraint; } Dialect &Context::insertDialect(StringRef name) { std::unique_ptr &dialect = dialects[name]; if (!dialect) dialect.reset(new Dialect(name)); return *dialect; } const Dialect *Context::lookupDialect(StringRef name) const { auto it = dialects.find(name); return it == dialects.end() ? nullptr : &*it->second; } std::pair Context::insertOperation(StringRef name, StringRef summary, StringRef desc, StringRef nativeClassName, bool supportsResultTypeInferrence, SMLoc loc) { std::pair dialectAndName = name.split('.'); return insertDialect(dialectAndName.first) .insertOperation(name, summary, desc, nativeClassName, supportsResultTypeInferrence, loc); } const Operation *Context::lookupOperation(StringRef name) const { std::pair dialectAndName = name.split('.'); if (const Dialect *dialect = lookupDialect(dialectAndName.first)) return dialect->lookupOperation(name); return nullptr; } template SmallVector sortMapByName(const llvm::StringMap> &map) { SmallVector storage; for (auto &entry : map) storage.push_back(entry.second.get()); llvm::sort(storage, [](const auto &lhs, const auto &rhs) { return lhs->getName() < rhs->getName(); }); return storage; } void Context::print(raw_ostream &os) const { auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) { switch (kind) { case VariableLengthKind::Optional: os << "Optional<" << cst << ">"; break; case VariableLengthKind::Single: os << cst; break; case VariableLengthKind::Variadic: os << "Variadic<" << cst << ">"; break; } }; llvm::ScopedPrinter printer(os); llvm::DictScope odsScope(printer, "ODSContext"); for (const Dialect *dialect : sortMapByName(dialects)) { printer.startLine() << "Dialect `" << dialect->getName() << "` {\n"; printer.indent(); for (const Operation *op : sortMapByName(dialect->getOperations())) { printer.startLine() << "Operation `" << op->getName() << "` {\n"; printer.indent(); // Attributes. ArrayRef attributes = op->getAttributes(); if (!attributes.empty()) { printer.startLine() << "Attributes { "; llvm::interleaveComma(attributes, os, [&](const Attribute &attr) { os << attr.getName() << " : "; auto kind = attr.isOptional() ? VariableLengthKind::Optional : VariableLengthKind::Single; printVariableLengthCst(attr.getConstraint().getDemangledName(), kind); }); os << " }\n"; } // Operands. ArrayRef operands = op->getOperands(); if (!operands.empty()) { printer.startLine() << "Operands { "; llvm::interleaveComma( operands, os, [&](const OperandOrResult &operand) { os << operand.getName() << " : "; printVariableLengthCst(operand.getConstraint().getDemangledName(), operand.getVariableLengthKind()); }); os << " }\n"; } // Results. ArrayRef results = op->getResults(); if (!results.empty()) { printer.startLine() << "Results { "; llvm::interleaveComma(results, os, [&](const OperandOrResult &result) { os << result.getName() << " : "; printVariableLengthCst(result.getConstraint().getDemangledName(), result.getVariableLengthKind()); }); os << " }\n"; } printer.objectEnd(); } printer.objectEnd(); } for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) { printer.startLine() << "AttributeConstraint `" << cst->getDemangledName() << "` {\n"; printer.indent(); printer.startLine() << "Summary: " << cst->getSummary() << "\n"; printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; printer.objectEnd(); } for (const TypeConstraint *cst : sortMapByName(typeConstraints)) { printer.startLine() << "TypeConstraint `" << cst->getDemangledName() << "` {\n"; printer.indent(); printer.startLine() << "Summary: " << cst->getSummary() << "\n"; printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; printer.objectEnd(); } printer.objectEnd(); }