140 lines
4.2 KiB
C++
140 lines
4.2 KiB
C++
|
//===- RegistryManager.cpp - Matcher registry -----------------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// Registry map populated at static initialization time.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "RegistryManager.h"
|
||
|
#include "mlir/Query/Matcher/Registry.h"
|
||
|
|
||
|
#include <set>
|
||
|
#include <utility>
|
||
|
|
||
|
namespace mlir::query::matcher {
|
||
|
namespace {
|
||
|
|
||
|
// This is needed because these matchers are defined as overloaded functions.
|
||
|
using IsConstantOp = detail::constant_op_matcher();
|
||
|
using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
|
||
|
using HasOpName = detail::NameOpMatcher(llvm::StringRef);
|
||
|
|
||
|
// Enum to string for autocomplete.
|
||
|
static std::string asArgString(ArgKind kind) {
|
||
|
switch (kind) {
|
||
|
case ArgKind::Matcher:
|
||
|
return "Matcher";
|
||
|
case ArgKind::String:
|
||
|
return "String";
|
||
|
}
|
||
|
llvm_unreachable("Unhandled ArgKind");
|
||
|
}
|
||
|
|
||
|
} // namespace
|
||
|
|
||
|
void Registry::registerMatcherDescriptor(
|
||
|
llvm::StringRef matcherName,
|
||
|
std::unique_ptr<internal::MatcherDescriptor> callback) {
|
||
|
assert(!constructorMap.contains(matcherName));
|
||
|
constructorMap[matcherName] = std::move(callback);
|
||
|
}
|
||
|
|
||
|
std::optional<MatcherCtor>
|
||
|
RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
|
||
|
const Registry &matcherRegistry) {
|
||
|
auto it = matcherRegistry.constructors().find(matcherName);
|
||
|
return it == matcherRegistry.constructors().end()
|
||
|
? std::optional<MatcherCtor>()
|
||
|
: it->second.get();
|
||
|
}
|
||
|
|
||
|
std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
|
||
|
llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
|
||
|
// Starting with the above seed of acceptable top-level matcher types, compute
|
||
|
// the acceptable type set for the argument indicated by each context element.
|
||
|
std::set<ArgKind> typeSet;
|
||
|
typeSet.insert(ArgKind::Matcher);
|
||
|
|
||
|
for (const auto &ctxEntry : context) {
|
||
|
MatcherCtor ctor = ctxEntry.first;
|
||
|
unsigned argNumber = ctxEntry.second;
|
||
|
std::vector<ArgKind> nextTypeSet;
|
||
|
|
||
|
if (argNumber < ctor->getNumArgs())
|
||
|
ctor->getArgKinds(argNumber, nextTypeSet);
|
||
|
|
||
|
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
|
||
|
}
|
||
|
|
||
|
return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
|
||
|
}
|
||
|
|
||
|
std::vector<MatcherCompletion>
|
||
|
RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
|
||
|
const Registry &matcherRegistry) {
|
||
|
std::vector<MatcherCompletion> completions;
|
||
|
|
||
|
// Search the registry for acceptable matchers.
|
||
|
for (const auto &m : matcherRegistry.constructors()) {
|
||
|
const internal::MatcherDescriptor &matcher = *m.getValue();
|
||
|
llvm::StringRef name = m.getKey();
|
||
|
|
||
|
unsigned numArgs = matcher.getNumArgs();
|
||
|
std::vector<std::vector<ArgKind>> argKinds(numArgs);
|
||
|
|
||
|
for (const ArgKind &kind : acceptedTypes) {
|
||
|
if (kind != ArgKind::Matcher)
|
||
|
continue;
|
||
|
|
||
|
for (unsigned arg = 0; arg != numArgs; ++arg)
|
||
|
matcher.getArgKinds(arg, argKinds[arg]);
|
||
|
}
|
||
|
|
||
|
std::string decl;
|
||
|
llvm::raw_string_ostream os(decl);
|
||
|
|
||
|
std::string typedText = std::string(name);
|
||
|
os << "Matcher: " << name << "(";
|
||
|
|
||
|
for (const std::vector<ArgKind> &arg : argKinds) {
|
||
|
if (&arg != &argKinds[0])
|
||
|
os << ", ";
|
||
|
|
||
|
bool firstArgKind = true;
|
||
|
// Two steps. First all non-matchers, then matchers only.
|
||
|
for (const ArgKind &argKind : arg) {
|
||
|
if (!firstArgKind)
|
||
|
os << "|";
|
||
|
|
||
|
firstArgKind = false;
|
||
|
os << asArgString(argKind);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
os << ")";
|
||
|
typedText += "(";
|
||
|
|
||
|
if (argKinds.empty())
|
||
|
typedText += ")";
|
||
|
else if (argKinds[0][0] == ArgKind::String)
|
||
|
typedText += "\"";
|
||
|
|
||
|
completions.emplace_back(typedText, os.str());
|
||
|
}
|
||
|
|
||
|
return completions;
|
||
|
}
|
||
|
|
||
|
VariantMatcher RegistryManager::constructMatcher(
|
||
|
MatcherCtor ctor, internal::SourceRange nameRange,
|
||
|
llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
|
||
|
return ctor->create(nameRange, args, error);
|
||
|
}
|
||
|
|
||
|
} // namespace mlir::query::matcher
|