486 lines
17 KiB
C++
486 lines
17 KiB
C++
//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Support/IndentedOstream.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "llvm/ADT/MapVector.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/TableGen/Error.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
#include <regex>
|
|
|
|
using namespace llvm;
|
|
|
|
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
|
|
static llvm::cl::opt<std::string>
|
|
selectedBcDialect("bytecode-dialect",
|
|
llvm::cl::desc("The dialect to gen for"),
|
|
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
|
|
|
|
namespace {
|
|
|
|
/// Helper class to generate C++ bytecode parser helpers.
|
|
class Generator {
|
|
public:
|
|
Generator(raw_ostream &output) : output(output) {}
|
|
|
|
/// Returns whether successfully emitted attribute/type parsers.
|
|
void emitParse(StringRef kind, Record &x);
|
|
|
|
/// Returns whether successfully emitted attribute/type printers.
|
|
void emitPrint(StringRef kind, StringRef type,
|
|
ArrayRef<std::pair<int64_t, Record *>> vec);
|
|
|
|
/// Emits parse dispatch table.
|
|
void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
|
|
|
|
/// Emits print dispatch table.
|
|
void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
|
|
|
|
private:
|
|
/// Emits parse calls to construct given kind.
|
|
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
|
|
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
|
|
StringRef failure, mlir::raw_indented_ostream &ios);
|
|
|
|
/// Emits print instructions.
|
|
void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
|
|
StringRef name, mlir::raw_indented_ostream &ios);
|
|
|
|
raw_ostream &output;
|
|
};
|
|
} // namespace
|
|
|
|
/// Helper to replace set of from strings to target in `s`.
|
|
/// Assumed: non-overlapping replacements.
|
|
static std::string format(StringRef templ,
|
|
std::map<std::string, std::string> &&map) {
|
|
std::string s = templ.str();
|
|
for (const auto &[from, to] : map)
|
|
// All replacements start with $, don't treat as anchor.
|
|
s = std::regex_replace(s, std::regex("\\" + from), to);
|
|
return s;
|
|
}
|
|
|
|
/// Return string with first character capitalized.
|
|
static std::string capitalize(StringRef str) {
|
|
return ((Twine)toUpper(str[0]) + str.drop_front()).str();
|
|
}
|
|
|
|
/// Return the C++ type for the given record.
|
|
static std::string getCType(Record *def) {
|
|
std::string format = "{0}";
|
|
if (def->isSubClassOf("Array")) {
|
|
def = def->getValueAsDef("elemT");
|
|
format = "SmallVector<{0}>";
|
|
}
|
|
|
|
StringRef cType = def->getValueAsString("cType");
|
|
if (cType.empty()) {
|
|
if (def->isAnonymous())
|
|
PrintFatalError(def->getLoc(), "Unable to determine cType");
|
|
|
|
return formatv(format.c_str(), def->getName().str());
|
|
}
|
|
return formatv(format.c_str(), cType.str());
|
|
}
|
|
|
|
/// Emits `static <Kind> read<Kind>(MLIRContext*, DialectBytecodeReader &)`.
///
/// The generated function reads a VarInt code and dispatches to
/// `read<Name>` of the definition at that index in `vec`.
/// `ReservedOrDead` entries keep their index but get no case. Unknown codes
/// and read failures return a default-constructed `<Kind>`.
void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
  mlir::raw_indented_ostream os(output);
  const std::string capitalized = capitalize(kind);
  char const *head =
      R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
  os << formatv(head, capitalized);
  auto funScope = os.scope(" {\n", "}\n\n");

  // The error text names the actual kind ("attribute" or "type") rather than
  // always saying "attribute".
  if (vec.empty()) {
    os << "return reader.emitError() << \"unknown " << kind << "\", "
       << capitalized << "();\n";
    return;
  }

  os << "uint64_t kind;\n";
  os << "if (failed(reader.readVarInt(kind)))\n"
     << "  return " << capitalized << "();\n";
  os << "switch (kind) ";
  {
    auto switchScope = os.scope("{\n", "}\n");
    for (const auto &it : llvm::enumerate(vec)) {
      if (it.value()->getName() == "ReservedOrDead")
        continue;

      os << formatv("case {1}:\n  return read{0}(context, reader);\n",
                    it.value()->getName(), it.index());
    }
    os << "default:\n"
       << "  reader.emitError() << \"unknown " << kind << " code: \" "
       << "<< kind;\n"
       << "  return " << capitalized << "();\n";
  }
  // Unreachable in the generated code, but keeps compilers from warning about
  // a missing return after the switch.
  os << "return " << capitalized << "();\n";
}
|
|
|
|
/// Emits `static ::mlir::Attribute|::mlir::Type read<Name>(...)` for `x`.
///
/// The body, produced by emitParseHelper, reads the record's `members` and
/// returns its trimmed `cBuilder`. On failure it returns a
/// default-constructed value of the record's C++ type.
void Generator::emitParse(StringRef kind, Record &x) {
  // Reserved/dead slots only occupy a bytecode code; they get no reader.
  if (x.getNameInitAsString() == "ReservedOrDead")
    return;

  mlir::raw_indented_ostream os(output);
  std::string returnType = getCType(&x);
  const char *baseType =
      kind == "attribute" ? "::mlir::Attribute" : "::mlir::Type";
  os << formatv("static {0} read{1}(MLIRContext* context, "
                "DialectBytecodeReader &reader) ",
                baseType, x.getName());

  DagInit *members = x.getValueAsDag("members");
  SmallVector<std::string> memberNames;
  for (StringInit *memberName : members->getArgNames())
    memberNames.push_back(memberName->getAsUnquotedString());

  StringRef builder = x.getValueAsString("cBuilder").trim();
  std::string onFailure = returnType + "()";
  emitParseHelper(kind, returnType, builder, members->getArgs(), memberNames,
                  onFailure, os);
  os << "\n\n";
}
|
|
|
|
/// Emits the opening `if (<read 1> &&\n<read 2> ...) {` of a parse helper.
///
/// One read expression is emitted per member that is actually present in the
/// bytecode: arrays, and members with a non-empty `cParser`. The stream is
/// left indented one extra level for the `if` body; the caller emits the
/// body, unindents, and writes the closing brace.
static void printParseConditional(mlir::raw_indented_ostream &ios,
                                  ArrayRef<Init *> args,
                                  ArrayRef<std::string> argNames) {
  ios << "if ";
  auto parenScope = ios.scope("(", ") {");
  ios.indent();

  auto listHelperName = [](StringRef name) {
    return formatv("read{0}", capitalize(name));
  };

  // Filter members and their names together so each read expression uses the
  // member's own name. Filtering only `args` and then zipping with the
  // unfiltered `argNames` pairs members with the wrong names as soon as a
  // member without a parser precedes one with a parser.
  SmallVector<std::pair<Record *, StringRef>> parsedArgs;
  for (auto [arg, name] : zip(args, argNames)) {
    Record *def = cast<DefInit>(arg)->getDef();
    if (def->isSubClassOf("Array") ||
        !def->getValueAsString("cParser").empty())
      parsedArgs.emplace_back(def, name);
  }

  // Nothing is read from the bytecode: emit an always-true condition instead
  // of the ill-formed `if () {`.
  if (parsedArgs.empty()) {
    ios << "true";
    return;
  }

  interleave(
      parsedArgs,
      [&](const std::pair<Record *, StringRef> &entry) {
        Record *attr = entry.first;
        StringRef name = entry.second;
        std::string parser;
        if (auto optParser = attr->getValueAsOptionalString("cParser")) {
          parser = *optParser;
        } else if (attr->isSubClassOf("Array")) {
          Record *def = attr->getValueAsDef("elemT");
          bool composite = def->isSubClassOf("CompositeBytecode");
          if (!composite && def->isSubClassOf("AttributeKind"))
            parser = "succeeded($_reader.readAttributes($_var))";
          else if (!composite && def->isSubClassOf("TypeKind"))
            parser = "succeeded($_reader.readTypes($_var))";
          else
            parser = ("succeeded($_reader.readList($_var, " +
                      listHelperName(name) + "))")
                         .str();
        } else {
          PrintFatalError(attr->getLoc(), "No parser specified");
        }
        std::string type = getCType(attr);
        ios << format(parser, {{"$_reader", "reader"},
                               {"$_resultType", type},
                               {"$_var", name.str()}});
      },
      [&]() { ios << " &&\n"; });
}
|
|
|
|
/// Emits a braced body that reads a value built from `args`.
///
/// Generated layout:
///  - one local declaration per member; consecutive members of the same C++
///    type share a declaration;
///  - a `read<Name>` lambda for every array member whose elements need a
///    custom element reader;
///  - an `if (<all reads succeed>) { return <builder>; }`;
///  - a trailing `return <failure>;`.
/// With no members, the body is just `return get<returnType>(context);`.
/// Members whose name starts with `_` are read but not passed to the builder.
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
                                StringRef builder, ArrayRef<Init *> args,
                                ArrayRef<std::string> argNames,
                                StringRef failure,
                                mlir::raw_indented_ostream &ios) {
  // The closing brace is written when this scope object is destroyed.
  auto bodyScope = ios.scope("{\n", "}");

  if (args.empty()) {
    ios << formatv("return get<{0}>(context);\n", returnType);
    return;
  }

  // Member declarations, e.g. `uint64_t a, b;\nStringAttr c;`.
  std::string prevType;
  for (auto [arg, argName] : zip(args, argNames)) {
    auto *defInit = dyn_cast<DefInit>(arg);
    if (!defInit)
      PrintFatalError("Unexpected type for " + argName);

    std::string memberType = getCType(defInit->getDef());
    if (memberType != prevType) {
      if (!prevType.empty())
        ios << ";\n";
      ios << memberType << " ";
    } else {
      ios << ", ";
    }
    ios << argName;
    prevType = std::move(memberType);
  }
  ios << ";\n";

  // Element-reader lambdas for arrays that cannot use readAttributes/readTypes.
  for (auto [arg, argName] : zip(args, argNames)) {
    Record *member = cast<DefInit>(arg)->getDef();
    if (!member->isSubClassOf("Array"))
      continue;

    // TODO: Dedupe readers.
    Record *elem = member->getValueAsDef("elemT");
    const bool isComposite = elem->isSubClassOf("CompositeBytecode");
    if (!isComposite && (elem->isSubClassOf("AttributeKind") ||
                         elem->isSubClassOf("TypeKind")))
      continue;

    std::string elemType = getCType(elem);
    ios << "auto read" << capitalize(argName) << " = [&]() -> FailureOr<"
        << elemType << "> ";

    SmallVector<Init *> elemArgs;
    SmallVector<std::string> elemArgNames;
    if (isComposite) {
      DagInit *elemMembers = elem->getValueAsDag("members");
      elemArgs = llvm::to_vector(elemMembers->getArgs());
      for (StringInit *elemArgName : elemMembers->getArgNames())
        elemArgNames.push_back(elemArgName->getAsUnquotedString());
    } else {
      // A simple element is read into a single local named `temp`.
      elemArgs.push_back(elem->getDefInit());
      elemArgNames.push_back("temp");
    }

    StringRef elemBuilder = elem->getValueAsString("cBuilder");
    emitParseHelper(kind, elemType, elemBuilder, elemArgs, elemArgNames,
                    "failure()", ios);
    ios << ";\n";
  }

  // Opens `if (...) {` and indents for its body.
  printParseConditional(ios, args, argNames);

  // Builder arguments: every member not marked private with a leading `_`.
  std::string builderArgs;
  raw_string_ostream builderArgsStream(builderArgs);
  bool isFirst = true;
  for (StringRef argName : argNames) {
    if (argName.starts_with("_"))
      continue;
    if (!isFirst)
      builderArgsStream << ", ";
    builderArgsStream << argName;
    isFirst = false;
  }

  ios << "\nreturn "
      << format(builder, {{"$_resultType", returnType.str()},
                          {"$_args", builderArgsStream.str()}})
      << ";\n";
  ios.unindent();

  // TODO: Emit error in debug.
  // This assumes the result types in error case can always be empty
  // constructed.
  ios << "}\nreturn " << failure << ";\n";
}
|
|
|
|
/// Emits `static void write(<type> <kind>, DialectBytecodeWriter &writer)`.
///
/// `vec` lists every (bytecode code, definition) pair whose C++ type is
/// `type`. For each pair the writer emits the code as a VarInt followed by
/// the definition's members. If a definition has a `printerPredicate`, its
/// output is guarded by that predicate, with `$_val` bound to the parameter.
/// When several definitions share the C++ type, every one of them must
/// provide a predicate; each offender is reported and then generation aborts.
void Generator::emitPrint(StringRef kind, StringRef type,
                          ArrayRef<std::pair<int64_t, Record *>> vec) {
  if (type == "ReservedOrDead")
    return;

  char const *head =
      R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
  mlir::raw_indented_ostream os(output);
  os << formatv(head, type, kind);
  auto funScope = os.scope("{\n", "}\n\n");

  // Single pass: report every definition lacking a predicate, then abort.
  if (vec.size() > 1) {
    bool missingPredicate = false;
    for (Record *rec : make_second_range(vec)) {
      if (!rec->getValueAsString("printerPredicate").empty())
        continue;
      PrintError(rec->getLoc(),
                 "Requires printer predicate given common cType");
      missingPredicate = true;
    }
    if (missingPredicate)
      PrintFatalError("Unspecified for shared cType " + type);
  }

  for (auto [index, rec] : vec) {
    StringRef pred = rec->getValueAsString("printerPredicate");
    if (!pred.empty()) {
      os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
      os.indent();
    }

    os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
       << ");\n";

    DagInit *members = rec->getValueAsDag("members");
    for (auto [arg, name] :
         llvm::zip(members->getArgs(), members->getArgNames())) {
      DefInit *def = dyn_cast<DefInit>(arg);
      assert(def && "expected bytecode member to be a def");
      emitPrintHelper(def->getDef(), kind, kind, name->getAsUnquotedString(),
                      os);
    }

    if (!pred.empty()) {
      os.unindent();
      os << "}\n";
    }
  }
}
|
|
|
|
/// Emits the writer calls for the member `name` of the value `parent`.
///
/// The accessor expression is `parent.get<Name>()`, where <Name> is the
/// camel-cased `name`, unless `memberRec` supplies a non-empty `cGetter`
/// template. After that:
///  - an `Array` of attributes or types uses writeAttributes/writeTypes;
///  - any other `Array` becomes a writeList whose lambda writes one element,
///    with the lambda parameter named after `kind`;
///  - a `CompositeBytecode` recurses into its members;
///  - finally, a non-empty `cPrinter` template is emitted as a statement.
void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
                                StringRef parent, StringRef name,
                                mlir::raw_indented_ostream &ios) {
  std::string camelName = convertToCamelFromSnakeCase(name, true);

  std::string getter;
  auto customGetter = memberRec->getValueAsOptionalString("cGetter");
  if (customGetter && !customGetter->empty())
    getter = format(*customGetter, {{"$_attrType", parent.str()},
                                    {"$_member", name.str()},
                                    {"$_getMember", "get" + camelName}});
  else
    getter = formatv("{0}.get{1}()", parent, camelName).str();

  if (memberRec->isSubClassOf("Array")) {
    Record *elem = memberRec->getValueAsDef("elemT");
    const bool isComposite = elem->isSubClassOf("CompositeBytecode");
    if (!isComposite && elem->isSubClassOf("AttributeKind")) {
      ios << "writer.writeAttributes(" << getter << ");\n";
      return;
    }
    if (!isComposite && elem->isSubClassOf("TypeKind")) {
      ios << "writer.writeTypes(" << getter << ");\n";
      return;
    }

    std::string elemType = getCType(elem);
    std::string lambdaParam = kind.str();
    ios << "writer.writeList(" << getter << ", [&](" << elemType << " "
        << lambdaParam << ") ";
    // The lambda body is closed by this scope's destructor, i.e. after the
    // recursive call below has written the element.
    auto lambdaScope = ios.scope("{\n", "});\n");
    emitPrintHelper(elem, kind, lambdaParam, lambdaParam, ios);
    return;
  }

  if (memberRec->isSubClassOf("CompositeBytecode")) {
    DagInit *members = memberRec->getValueAsDag("members");
    for (auto [arg, argName] :
         zip(members->getArgs(), members->getArgNames())) {
      DefInit *def = dyn_cast<DefInit>(arg);
      assert(def);
      emitPrintHelper(def->getDef(), kind, parent,
                      argName->getAsUnquotedString(), ios);
    }
  }

  std::string printer = memberRec->getValueAsString("cPrinter").str();
  if (printer.empty())
    return;
  ios << format(printer, {{"$_writer", "writer"},
                          {"$_name", kind.str()},
                          {"$_getter", getter}})
      << ";\n";
}
|
|
|
|
/// Emits `static LogicalResult write<Kind>(<Kind> <kind>, DialectBytecodeWriter
/// &writer)`.
///
/// The generated function is a TypeSwitch over the C++ types in `vec`, each
/// forwarding to the matching `write` overload and returning success.
/// `ReservedOrDead` is skipped, and unmatched values yield failure.
void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
  mlir::raw_indented_ostream os(output);
  const std::string capitalized = capitalize(kind);
  // Keep the signature on one line. A raw string spanning source lines would
  // copy this file's line break and indentation into the generated code.
  char const *head =
      "static LogicalResult write{0}({0} {1}, DialectBytecodeWriter &writer)";
  os << formatv(head, capitalized, kind);
  auto funScope = os.scope(" {\n", "}\n\n");

  os << "return TypeSwitch<" << capitalized << ", LogicalResult>(" << kind
     << ")";
  auto switchScope = os.scope("", "");
  for (StringRef type : vec) {
    if (type == "ReservedOrDead")
      continue;

    os << "\n.Case([&](" << type << " t)";
    auto caseScope = os.scope(" {\n", "})");
    os << "return write(t, writer), success();\n";
  }
  os << "\n.Default([&](" << capitalized << ") { return failure(); });\n";
}
|
|
|
|
namespace {
|
|
/// Container of Attribute or Type for Dialect.
|
|
struct AttrOrType {
|
|
std::vector<Record *> attr, type;
|
|
};
|
|
} // namespace
|
|
|
|
static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
|
|
MapVector<StringRef, AttrOrType> dialectAttrOrType;
|
|
for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
|
|
if (!selectedBcDialect.empty() &&
|
|
it->getValueAsString("dialect") != selectedBcDialect)
|
|
continue;
|
|
dialectAttrOrType[it->getValueAsString("dialect")].attr =
|
|
it->getValueAsListOfDefs("elems");
|
|
}
|
|
for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
|
|
if (!selectedBcDialect.empty() &&
|
|
it->getValueAsString("dialect") != selectedBcDialect)
|
|
continue;
|
|
dialectAttrOrType[it->getValueAsString("dialect")].type =
|
|
it->getValueAsListOfDefs("elems");
|
|
}
|
|
|
|
if (dialectAttrOrType.size() != 1)
|
|
PrintFatalError("Single dialect per invocation required (either only "
|
|
"one in input file or specified via dialect option)");
|
|
|
|
auto it = dialectAttrOrType.front();
|
|
Generator gen(os);
|
|
|
|
SmallVector<std::vector<Record *> *, 2> vecs;
|
|
SmallVector<std::string, 2> kinds;
|
|
vecs.push_back(&it.second.attr);
|
|
kinds.push_back("attribute");
|
|
vecs.push_back(&it.second.type);
|
|
kinds.push_back("type");
|
|
for (auto [vec, kind] : zip(vecs, kinds)) {
|
|
// Handle Attribute/Type emission.
|
|
std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
|
|
for (auto kt : llvm::enumerate(*vec))
|
|
perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
|
|
for (const auto &jt : perType) {
|
|
for (auto kt : jt.second)
|
|
gen.emitParse(kind, *std::get<1>(kt));
|
|
gen.emitPrint(kind, jt.first, jt.second);
|
|
}
|
|
gen.emitParseDispatch(kind, *vec);
|
|
|
|
SmallVector<std::string> types;
|
|
for (const auto &it : perType) {
|
|
types.push_back(it.first);
|
|
}
|
|
gen.emitPrintDispatch(kind, types);
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Registers the `-gen-bytecode` generator, backed directly by emitBCRW.
static mlir::GenRegistration
    genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
            emitBCRW);
|