# bolt/scripts/gennodes.py
# (web-viewer residue removed: "904 lines / 28 KiB / Raw Permalink ..." was
# not part of the program and broke parsing)

#!/usr/bin/env python3
import re
from collections import deque
from pathlib import Path
import argparse
from typing import List, Optional, final
from sweetener.record import Record, warn
# Sentinel character returned by the scanner when input is exhausted.
# U+FFFF is a Unicode noncharacter, so it can never occur in real input.
EOF = '\uFFFF'

# Token-type constants; Token.type holds exactly one of these ints.
END_OF_FILE = 0
IDENTIFIER = 1
SEMI = 2
EXTERNAL_KEYWORD = 3
NODE_KEYWORD = 4
LBRACE = 5
RBRACE = 6
LESSTHAN = 7
GREATERTHAN = 8
COLON = 9
LPAREN = 10
RPAREN = 11
VBAR = 12
COMMA = 13
HASH = 14
STRING = 15
CONST_KEYWORD = 16
NOEXCEPT_KEYWORD = 17

# Character classes used by the hand-written scanner.
_re_whitespace = re.compile(r"[\n\r\t ]")
_re_ident_start = re.compile(r"[a-zA-Z_]")
_re_ident_part = re.compile(r"[a-zA-Z_0-9]")

# Reserved words of the spec language, mapped to their keyword token types.
_keywords = {
    'external': EXTERNAL_KEYWORD,
    'node': NODE_KEYWORD,
    'const': CONST_KEYWORD,
    'noexcept': NOEXCEPT_KEYWORD,
}
def to_upper_camel_case(snake_str):
    """Convert a snake_case identifier to UpperCamelCase (e.g. 'foo_bar' -> 'FooBar')."""
    parts = snake_str.lower().split("_")
    return "".join(map(str.capitalize, parts))
def escape_char(ch):
    """Return a printable representation of a single character for diagnostics.

    Printable ASCII characters are returned unchanged; other ASCII
    characters are rendered as \\xNN and everything else as \\uNNNN.
    """
    code = ord(ch)
    # BUG FIX: was `code < 126`, which needlessly escaped the printable
    # character '~' (126). Printable ASCII is 32..126 inclusive; DEL (127)
    # is not printable and falls through to the \xNN branch.
    if code >= 32 and code < 127:
        return ch
    if code <= 127:
        return f"\\x{code:02X}"
    return f"\\u{code:04X}"
class ScanError(RuntimeError):
    """Raised by Scanner when it encounters a character it cannot tokenize."""

    def __init__(self, file, position, actual):
        location = f"{file.name}:{position.line}:{position.column}"
        message = f"{location}: unexpected character '{escape_char(actual)}'"
        super().__init__(message)
        self.file = file          # TextFile being scanned
        self.position = position  # TextPos of the offending character
        self.actual = actual      # the character itself
# Fixed source spellings of the single-character punctuation tokens.
# Used by Token.text and the error-message helpers below.
_token_type_to_string = {
    LPAREN: '(',
    RPAREN: ')',
    LBRACE: '{',
    RBRACE: '}',
    LESSTHAN: '<',
    GREATERTHAN: '>',
    SEMI: ';',
    COLON: ':',
    COMMA: ',',
    VBAR: '|',
    HASH: '#',
}
class Token:
    """A single lexical token: its type, start position, and optional payload."""

    def __init__(self, type, position=None, value=None):
        self.type = type            # one of the token-type int constants
        self.start_pos = position   # TextPos where the token begins (or None)
        self.value = value          # payload for IDENTIFIER / STRING tokens

    @property
    def text(self):
        """The source text this token stands for, used in diagnostics."""
        fixed = _token_type_to_string.get(self.type)
        if fixed is not None:
            return fixed
        if self.type == STRING:
            return f'"{self.value}"'
        if self.type == IDENTIFIER:
            return self.value
        return '' if self.type == END_OF_FILE else '(unknown token)'
class TextFile:
    """A named text file whose contents are loaded lazily and cached."""

    def __init__(self, filename, text=None):
        self.name = filename
        # When `text` is supplied up front the file is never opened.
        self._cached_text = text

    @property
    def text(self):
        """The file contents, read from disk on first access."""
        if self._cached_text is None:
            with open(self.name, 'r') as handle:
                self._cached_text = handle.read()
        return self._cached_text
class TextPos:
    """A 1-based line/column cursor into a text buffer."""

    def __init__(self, line=1, column=1):
        self.line = line
        self.column = column

    def clone(self):
        """Return an independent copy of this position."""
        return TextPos(self.line, self.column)

    def advance(self, text):
        """Move the cursor past every character in `text`, tracking newlines."""
        for ch in text:
            if ch != '\n':
                self.column += 1
            else:
                self.line += 1
                self.column = 1
class Scanner:
    """Hand-written lexer for the node-spec language.

    Produces Token objects one at a time via scan(); tracks the current
    line/column so errors can point at the offending character.
    """

    def __init__(self, text, text_offset=0, filename=None):
        self._text = text
        self._text_offset = text_offset
        self.file = TextFile(filename, text)
        self._curr_pos = TextPos()

    def _peek_char(self, offset=1):
        # 1-based lookahead: offset=1 is the next unconsumed character.
        i = self._text_offset + offset - 1
        return self._text[i] if i < len(self._text) else EOF

    def _get_char(self):
        # Consume one character, advancing the position; EOF at end of input.
        if self._text_offset == len(self._text):
            return EOF
        i = self._text_offset
        self._text_offset += 1
        ch = self._text[i]
        self._curr_pos.advance(ch)
        return ch

    def _take_while(self, pred):
        # Consume and return the longest run of characters satisfying pred.
        # Note: pred also sees the EOF sentinel, which ends any sane run.
        out = ''
        while True:
            ch = self._peek_char()
            if not pred(ch):
                break
            self._get_char()
            out += ch
        return out

    def scan(self):
        """Return the next Token, skipping whitespace and // line comments.

        Raises ScanError on any character that starts no known token.
        """
        # Skip-loop: whitespace restarts it, a // comment is consumed up to
        # the newline (or EOF) and restarts it; anything else breaks out
        # with c0/start_pos holding the first significant character.
        while True:
            start_pos = self._curr_pos.clone()
            c0 = self._get_char()
            c1 = self._peek_char()
            if _re_whitespace.match(c0):
                continue
            if c0 == '/' and c1 == '/':
                while True:
                    c2 = self._get_char()
                    if c2 == EOF or c2 == '\n':
                        break
                continue
            break
        if c0 == EOF:
            return Token(END_OF_FILE, self._curr_pos.clone())
        # Single-character punctuation tokens.
        if c0 == ';': return Token(SEMI, start_pos)
        if c0 == '{': return Token(LBRACE, start_pos)
        if c0 == '}': return Token(RBRACE, start_pos)
        if c0 == '(': return Token(LPAREN, start_pos)
        if c0 == ')': return Token(RPAREN, start_pos)
        if c0 == '<': return Token(LESSTHAN, start_pos)
        if c0 == '>': return Token(GREATERTHAN, start_pos)
        if c0 == ':': return Token(COLON, start_pos)
        if c0 == '|': return Token(VBAR, start_pos)
        if c0 == ',': return Token(COMMA, start_pos)
        if c0 == '#': return Token(HASH, start_pos)
        if c0 == '"':
            # String literal: no escape sequences are supported, and an
            # unterminated string will loop appending the EOF sentinel --
            # NOTE(review): consider raising on EOF here.
            text = ''
            while True:
                c1 = self._get_char()
                if c1 == '"':
                    break
                text += c1
            return Token(STRING, start_pos, text)
        if _re_ident_start.match(c0):
            # Identifier or reserved word.
            name = c0 + self._take_while(lambda ch: _re_ident_part.match(ch))
            return Token(_keywords[name], start_pos) \
                if name in _keywords \
                else Token(IDENTIFIER, start_pos, name)
        raise ScanError(self.file, start_pos, c0)
class Type(Record):
    """Base class for the semantic types inferred for node fields."""
    pass

class ListType(Type):
    # Homogeneous sequence of element_type values (emitted as std::vector).
    element_type: Type

class OptionalType(Type):
    # A possibly-absent value (emitted as std::optional, or a raw pointer
    # when the element is itself a node).
    element_type: Type

class NodeType(Type):
    # Reference to a declared node class by name (emitted as a pointer).
    name: str

class VariantType(Type):
    # A closed sum of alternative types (emitted as std::variant).
    types: List[Type]

class RawType(Type):
    # A type name passed through verbatim into the generated C++.
    text: str
class AST(Record):
    """Base class for everything parsed out of a spec file."""
    pass

class Directive(AST):
    """A preprocessor-style directive (currently only #include)."""
    pass

# How an #include path was written in the spec.
INCLUDEMODE_LOCAL = 0   # #include "path"
INCLUDEMODE_SYSTEM = 1  # #include <path>

# NOTE(review): class name has a typo ("Diretive"); kept because it is
# referenced by the parser and the generator below.
class IncludeDiretive(Directive):
    path: str
    mode: int

    def __str__(self):
        # Render back to C++ include syntax. Falls through (returns None)
        # for an unknown mode; only the two modes above are ever built.
        if self.mode == INCLUDEMODE_LOCAL:
            return f"#include \"{self.path}\"\n"
        if self.mode == INCLUDEMODE_SYSTEM:
            return f"#include <{self.path}>\n"
class TypeExpr(AST):
    """A type expression as written in the spec.

    `self.type` starts as None and is filled in later by type inference
    (see infer_type inside main()).
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.type = None

class RefTypeExpr(TypeExpr):
    # A named type, optionally applied to arguments: Name or Name<A, B>.
    name: str
    args: List[TypeExpr]

class UnionTypeExpr(TypeExpr):
    # A | B | ... (not produced by the current parser).
    types: List[TypeExpr]

class External(AST):
    # `external Name;` -- a type defined outside the spec.
    name: str

class NodeDeclField(AST):
    # One `Type name;` member inside a node declaration.
    name: str
    type_expr: TypeExpr

class NodeDecl(AST):
    # `node Name : Parent, ... { members }`.
    name: str
    parents: List[str]
    members: List[NodeDeclField]

class Expr(AST):
    pass

class BlockElement(AST):
    pass

class RetStmt(BlockElement):
    value: Expr

class Param(AST):
    # A parameter: a type expression plus an optional name token.
    type_expr: TypeExpr
    name: Optional[Token]
def pretty_token(token):
    """Describe an actual token for use in a diagnostic message."""
    return 'end-of-file' if token.type == END_OF_FILE else f"'{token.text}'"
def pretty_token_type(token_type):
    """Describe an expected token type for use in a diagnostic message."""
    if token_type in _token_type_to_string:
        return f"'{_token_type_to_string[token_type]}'"
    if token_type == IDENTIFIER:
        # BUG FIX: message previously read "an identfier" (typo).
        return 'an identifier'
    if token_type == STRING:
        return 'a string literal'
    if token_type == END_OF_FILE:
        return 'end-of-file'
    return f"(unknown token type {token_type})"
def pretty_alternatives(elements):
    """Join alternatives from an iterator as English: 'a, b or c'.

    Returns 'nothing' for an empty iterator and the lone element for a
    singleton. `elements` must be an iterator (it is consumed via next()).
    """
    try:
        out = next(elements)
    except StopIteration:
        return 'nothing'
    try:
        pending = next(elements)
    except StopIteration:
        return out
    # Keep one element pending so the final pair can be joined with 'or'.
    for element in elements:
        out = f"{out}, {pending}"
        pending = element
    return f"{out} or {pending}"
class ParseError(RuntimeError):
    """Raised by Parser when the token stream does not match the grammar."""

    def __init__(self, file, actual, expected):
        pos = actual.start_pos
        choices = pretty_alternatives(pretty_token_type(tt) for tt in expected)
        super().__init__(
            f"{file.name}:{pos.line}:{pos.column}: got {pretty_token(actual)} but expected {choices}")
        self.actual = actual      # the offending Token
        self.expected = expected  # list of acceptable token types
class Parser:
    """Recursive-descent parser for the node-spec language.

    Consumes tokens from a Scanner with arbitrary lookahead buffered in a
    deque, and produces the AST classes defined above.
    """

    def __init__(self, scanner: Scanner):
        self._scanner = scanner
        self._token_buffer = deque[Token]()

    def _peek_token(self, offset=1):
        # 1-based lookahead without consuming; fills the buffer as needed.
        while len(self._token_buffer) < offset:
            self._token_buffer.append(self._scanner.scan())
        return self._token_buffer[offset-1]

    def _get_token(self):
        # Consume the next token, preferring buffered lookahead.
        if self._token_buffer:
            return self._token_buffer.popleft()
        return self._scanner.scan()

    def _expect_token(self, expected_token_type):
        # Consume one token and require it to be of the given type.
        t0 = self._get_token()
        if t0.type != expected_token_type:
            raise ParseError(self._scanner.file, t0, [ expected_token_type ])
        return t0

    def _parse_prim_type_expr(self):
        # Primary type expression: '(' type ')' or Name or Name<arg, ...>.
        t0 = self._get_token()
        if t0.type == LPAREN:
            result = self.parse_type_expr()
            self._expect_token(RPAREN)
            return result
        if t0.type == IDENTIFIER:
            t1 = self._peek_token()
            args = []
            if t1.type == LESSTHAN:
                self._get_token()
                # Argument list; accepts a trailing comma before '>'.
                while True:
                    t2 = self._peek_token()
                    if t2.type == GREATERTHAN:
                        self._get_token()
                        break
                    args.append(self.parse_type_expr())
                    t3 = self._get_token()
                    if t3.type == GREATERTHAN:
                        break
                    if t3.type != COMMA:
                        raise ParseError(self._scanner.file, t3, [ COMMA, GREATERTHAN ])
            return RefTypeExpr(t0.value, args)
        raise ParseError(self._scanner.file, t0, [ LPAREN, IDENTIFIER ])

    def parse_type_expr(self):
        # Currently only primary type expressions exist (no unions yet).
        return self._parse_prim_type_expr()

    def parse_param(self) -> Param:
        """Parse a parameter: a type expression with an optional name."""
        type_expr = self.parse_type_expr()
        t1 = self._peek_token()
        name = None
        if t1.type == IDENTIFIER:
            self._get_token()
            name = t1
        return Param(type_expr, name)

    def parse_node_member(self):
        """Parse one `Type name;` field of a node declaration."""
        type_expr = self.parse_type_expr()
        name = self._expect_token(IDENTIFIER)
        self._expect_token(SEMI)
        return NodeDeclField(name.value, type_expr)

    def parse_toplevel(self):
        """Parse one top-level element: external, node, or #include."""
        t0 = self._get_token()
        if t0.type == EXTERNAL_KEYWORD:
            # external Name;
            name = self._expect_token(IDENTIFIER)
            self._expect_token(SEMI)
            return External(name.value)
        if t0.type == NODE_KEYWORD:
            # node Name [: Parent, ...] { members }
            name = self._expect_token(IDENTIFIER).value
            parents = []
            t1 = self._peek_token()
            if t1.type == COLON:
                self._get_token()
                while True:
                    parent = self._expect_token(IDENTIFIER).value
                    parents.append(parent)
                    t2 = self._peek_token()
                    if t2.type != COMMA:
                        break
                    self._get_token()
            self._expect_token(LBRACE)
            members = []
            while True:
                t2 = self._peek_token()
                if t2.type == RBRACE:
                    self._get_token()
                    break
                member = self.parse_node_member()
                members.append(member)
            return NodeDecl(name, parents, members)
        if t0.type == HASH:
            # Preprocessor-style directive; only #include is recognized.
            name = self._expect_token(IDENTIFIER)
            if name.value == 'include':
                t1 = self._get_token()
                if t1.type == LESSTHAN:
                    # HACK: the <...> path is read straight off the scanner
                    # (bypassing tokenization), so no lookahead may be
                    # buffered at this point -- hence the assert.
                    assert(not self._token_buffer)
                    path = self._scanner._take_while(lambda ch: ch != '>')
                    self._scanner._get_char()
                    mode = INCLUDEMODE_SYSTEM
                elif t1.type == STRING:
                    mode = INCLUDEMODE_LOCAL
                    path = t1.value
                else:
                    raise ParseError(self._scanner.file, t1, [ STRING, LESSTHAN ])
                return IncludeDiretive(path, mode)
            raise RuntimeError(f"invalid preprocessor directive '{name.value}'")
        raise ParseError(self._scanner.file, t0, [ EXTERNAL_KEYWORD, NODE_KEYWORD, HASH ])

    def parse_grammar(self):
        """Parse the whole file; returns the list of top-level elements."""
        elements = []
        while True:
            t0 = self._peek_token()
            if t0.type == END_OF_FILE:
                break
            element = self.parse_toplevel()
            elements.append(element)
        return elements
class Writer:
    """Accumulates generated source text with automatic brace-driven indentation.

    Every '{' written increases the indent level *after* it is emitted and
    every '}' decreases it *before* -- so chunks can be written without any
    leading whitespace and come out correctly indented.
    """

    def __init__(self, text='', path=None):
        self.path = path                # relative output path (joined in save())
        self.text = text                # accumulated output
        self._at_blank_line = True      # True right after a newline
        # NOTE(review): a single space per level as written; possibly a
        # wider indent collapsed by whitespace mangling -- confirm.
        self._indentation = ' '
        self._indent_level = 0

    def indent(self, count=1):
        self._indent_level += count

    def dedent(self, count=1):
        self._indent_level -= count

    def write_dedent(self, chunk):
        # Emit one chunk a level shallower (used for C++ 'public:' labels).
        self.dedent()
        self.write(chunk)
        self.indent()

    def write(self, chunk):
        # Character-at-a-time state machine; the ordering below is load-
        # bearing: '}' dedents before being emitted, '{' indents after,
        # and indentation is inserted at the first non-space character
        # of each line.
        for ch in chunk:
            if ch == '}':
                self.dedent()
            if ch == '\n':
                self._at_blank_line = True
            elif self._at_blank_line and not _re_whitespace.match(ch):
                self.text += self._indentation * self._indent_level
                self._at_blank_line = False
            self.text += ch
            if ch == '{':
                self.indent()

    def save(self, dest_dir):
        """Write the accumulated text to dest_dir / self.path."""
        dest_path = dest_dir / self.path
        print(f'Writing file {dest_path} ...')
        with open(dest_path, 'w') as f:
            f.write(self.text)
class DiGraph:
    """A directed graph over hashable node labels, stored as adjacency sets."""

    def __init__(self):
        self._out_edges = dict()  # node -> set of successors
        self._in_edges = dict()   # node -> set of predecessors

    def add_edge(self, a, b):
        """Add the directed edge a -> b (idempotent)."""
        if a not in self._out_edges:
            self._out_edges[a] = set()
        self._out_edges[a].add(b)
        if b not in self._in_edges:
            self._in_edges[b] = set()
        self._in_edges[b].add(a)

    def get_children(self, node):
        """Yield the direct successors of `node` (nothing if unknown)."""
        for child in self._out_edges.get(node, ()):
            yield child

    def is_child_of(self, a, b):
        """Return True when `a` is reachable from `b` (including a == b)."""
        stack = [ b ]
        visited = set()
        while stack:
            node = stack.pop()
            if node in visited:
                # BUG FIX: this was `break`, which aborted the entire
                # search the first time an already-visited node was popped
                # again (e.g. via a diamond), missing reachable nodes that
                # were still on the stack. Skip only this entry instead.
                continue
            visited.add(node)
            if node == a:
                return True
            for child in self.get_children(node):
                stack.append(child)
        return False

    def get_ancestors(self, node):
        """Yield the direct predecessors of `node` (nothing if unknown)."""
        for parent in self._in_edges.get(node, ()):
            yield parent

    def get_common_ancestor(self, nodes):
        """Heuristically find a node from which every node in `nodes` is reachable.

        Repeatedly widens the candidate set with parents of nodes not yet
        dominated by the current candidate. NOTE(review): can recurse
        without bound on cyclic or rootless inputs -- confirm inputs are
        DAGs with a common root.
        """
        out = nodes[0]
        parents = []
        for node in nodes[1:]:
            if not self.is_child_of(node, out):
                for parent in self.get_ancestors(node):
                    parents.append(parent)
        if not parents:
            return out
        parents.append(out)
        return self.get_common_ancestor(parents)
def main():
    """Read a node-spec file and generate a C++ header/source pair.

    Pipeline: parse CLI arguments -> scan/parse the spec -> build the
    inheritance graph -> infer field types -> emit <Name>.hpp (class
    declarations) and <Name>.cc (destructors and optional serde bodies).
    """
    # ---- command line -------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('file', nargs=1, help='The specification file to generate C++ code for')
    parser.add_argument('--namespace', default='', help='What C++ namespace to put generated code under')
    parser.add_argument('--name', default='AST', help='How to name the generated tree')
    parser.add_argument('-I', default='.', help='What path will be used to include generated header files')
    parser.add_argument('--include-root', default='.', help='Where the headers live inside the include directroy')
    parser.add_argument('--enable-serde', action='store_true', help='Also write (de)serialization logic')
    parser.add_argument('--source-root', default='.', help='Where to store generated souce files')
    parser.add_argument('--node-name', default='Node', help='How the root node of the hierachy should be called')
    parser.add_argument('--out-dir', default='.', help='Place the endire folder structure inside this folder')
    parser.add_argument('--dry-run', action='store_true', help='Do not write generated code to the file system')
    args = parser.parse_args()
    variable_casing = to_upper_camel_case
    # NOTE(review): type_expr_casing is never used below.
    type_expr_casing = to_upper_camel_case
    filename = args.file[0]
    prefix = ''  # optional C++ class-name prefix (currently always empty)
    cpp_root_node_name = prefix + args.node_name
    include_dir = Path(args.I)
    include_path = Path(args.include_root or '.')
    full_include_path = include_dir / include_path
    source_path = Path(args.source_root)
    namespace = args.namespace.split('::')
    out_dir = Path(args.out_dir)
    out_name = args.name
    write_serde = args.enable_serde
    # ---- parse the spec ----------------------------------------------
    with open(filename, 'r') as f:
        text = f.read()
    scanner = Scanner(text, filename=filename)
    parser = Parser(scanner)
    elements = parser.parse_grammar()
    # ---- index declarations and build the inheritance graph ----------
    types = dict()              # name -> External | NodeDecl
    nodes = list()              # all NodeDecls, in declaration order
    leaf_nodes = list()         # NodeDecls that nothing inherits from
    graph = DiGraph()           # edges parent-name -> child-name
    parent_to_children = dict() # parent-name -> set of child NodeDecls
    for element in elements:
        if isinstance(element, External) \
        or isinstance(element, NodeDecl):
            types[element.name] = element
        if isinstance(element, NodeDecl):
            nodes.append(element)
            for parent in element.parents:
                graph.add_edge(parent, element.name)
                if parent not in parent_to_children:
                    parent_to_children[parent] = set()
                children = parent_to_children[parent]
                children.add(element)
    for node in nodes:
        if node.name not in parent_to_children:
            leaf_nodes.append(node)
    # ---- type inference helpers --------------------------------------
    def is_null_type_expr(type_expr):
        # The spec spells an absent alternative as the type `null`.
        return isinstance(type_expr, RefTypeExpr) and type_expr.name == 'null'
    def is_node(name):
        # True when `name` refers to a declared node (or is only known as
        # a parent of one).
        if name in types:
            return isinstance(types[name], NodeDecl)
        if name in parent_to_children:
            return True
        return False
    def get_all_variant_elements(type_expr):
        # Flatten nested Variant<...> into a single list of alternatives.
        types = list()
        def loop(ty):
            if isinstance(ty, RefTypeExpr) and ty.name == 'Variant':
                for arg in ty.args:
                    loop(arg)
            else:
                types.append(ty)
        loop(type_expr)
        return types
    def infer_type(type_expr):
        # Map a surface TypeExpr onto the semantic Type hierarchy.
        if isinstance(type_expr, RefTypeExpr):
            if type_expr.name == 'List':
                assert(len(type_expr.args) == 1)
                return ListType(infer_type(type_expr.args[0]))
            if type_expr.name == 'Variant':
                types = get_all_variant_elements(type_expr)
                has_null = False
                # NOTE(review): deleting while enumerating skips the
                # element after each removed `null`; harmless only when a
                # variant contains at most one `null` -- confirm.
                for i, ty in enumerate(types):
                    if is_null_type_expr(ty):
                        has_null = True
                        del types[i]
                if all(isinstance(ty, RefTypeExpr) and is_node(ty.name) for ty in types):
                    # All alternatives are nodes: collapse to their common
                    # ancestor so a single base-class pointer suffices.
                    node_name = graph.get_common_ancestor(list(t.name for t in types))
                    out = NodeType(node_name)
                else:
                    out = VariantType(types)
                if has_null:
                    out = OptionalType(out)
                return out
            if type_expr.name == 'Option':
                assert(len(type_expr.args) == 1)
                return OptionalType(infer_type(type_expr.args[0]))
            if is_node(type_expr.name):
                assert(len(type_expr.args) == 0)
                return NodeType(type_expr.name)
            # Anything else is passed through to C++ verbatim.
            assert(len(type_expr.args) == 0)
            return RawType(type_expr.name)
        raise RuntimeError(f"unhandled type expression {type_expr}")
    for node in nodes:
        for member in node.members:
            member.type_expr.type = infer_type(member.type_expr)
    # ---- C++ code-generation helpers ---------------------------------
    def is_optional_by_default(ty):
        # Node references are plain pointers, so nullptr already encodes
        # absence and no std::optional wrapper is needed.
        return isinstance(ty, NodeType)
    def gen_cpp_type_expr(ty):
        # Render a semantic Type as C++ source text.
        if isinstance(ty, NodeType):
            return prefix + ty.name + "*"
        if isinstance(ty, ListType):
            return f"std::vector<{gen_cpp_type_expr(ty.element_type)}>"
        # NOTE(review): dead branch -- NodeType is already handled above,
        # so this 'class X*' form is never produced.
        if isinstance(ty, NodeType):
            return 'class ' + ty.name + '*'
        if isinstance(ty, OptionalType):
            cpp_expr = gen_cpp_type_expr(ty.element_type)
            if is_optional_by_default(ty.element_type):
                return cpp_expr
            return f"std::optional<{cpp_expr}>"
        if isinstance(ty, VariantType):
            return f"std::variant<{','.join(gen_cpp_type_expr(t) for t in ty.types)}>"
        if isinstance(ty, RawType):
            return ty.text
        raise RuntimeError(f"unhandled Type {ty}")
    def gen_cpp_dtor(expr, ty):
        # Emit destructor statements releasing node references held by
        # `expr`; returns None when the type owns no node references.
        if isinstance(ty, NodeType) or (isinstance(ty, OptionalType) and isinstance(ty.element_type, NodeType)):
            return f'{expr}->unref();\n'
        if isinstance(ty, ListType):
            dtor = gen_cpp_dtor('Element', ty.element_type)
            if dtor:
                out = ''
                out += f'for (auto& Element: {expr})'
                out += '{\n'
                out += dtor
                out += '}\n'
                return out
        if isinstance(ty, OptionalType):
            dtor = gen_cpp_dtor('(*' + expr + ')', ty.element_type)
            if dtor:
                out = ''
                out += 'if ('
                out += expr
                out += ') {\n'
                out += dtor
                out += '}\n'
                return out
    def gen_cpp_ctor_params(out, node):
        # Write the constructor parameter list: this node's members plus
        # all inherited members, gathered breadth-first up the hierarchy.
        visited = set()
        queue = deque([ node ])
        first = True
        while queue:
            node = queue.popleft()
            if node.name in visited:
                # NOTE(review): `return` abandons the remaining queue on a
                # repeated ancestor (diamond inheritance); `continue` was
                # likely intended -- confirm against multi-parent specs.
                return
            visited.add(node.name)
            for member in node.members:
                if first:
                    first = False
                else:
                    out.write(', ')
                out.write(gen_cpp_type_expr(member.type_expr.type))
                out.write(' ')
                out.write(variable_casing(member.name))
            for parent in node.parents:
                queue.append(types[parent])
    def gen_cpp_ctor_args(out, orig_node):
        # Write the constructor initializer list: delegate to each parent
        # constructor (or the root node), then init own members.
        first = True
        if orig_node.parents:
            for parent in orig_node.parents:
                if first:
                    first = False
                else:
                    out.write(', ')
                node = types[parent]
                out.write(prefix + node.name)
                out.write('(')
                out.write(', '.join(variable_casing(member.name) for member in node.members))
                out.write(')')
        else:
            out.write(cpp_root_node_name + '()')
            first = False
        for member in orig_node.members:
            if first:
                first = False
            else:
                out.write(', ')
            out.write(variable_casing(member.name))
            out.write('(')
            out.write(variable_casing(member.name))
            out.write(')')
    # ---- emit the header ---------------------------------------------
    node_hdr = Writer(path=full_include_path / (out_name + '.hpp'))
    node_src = Writer(path=source_path / (out_name + '.cc'))
    macro_prefix = '_'.join(namespace).upper() + '_'
    node_hdr_guard_name = f'{macro_prefix}{out_name.upper()}_HPP'
    node_hdr.write(f'#ifndef {node_hdr_guard_name}')
    node_hdr.write('\n')
    node_hdr.write(f'#define {node_hdr_guard_name}')
    node_hdr.write('\n\n')
    # Pass through the spec's own #include directives.
    for element in elements:
        if isinstance(element, IncludeDiretive):
            node_hdr.write(str(element))
    for name in namespace:
        node_hdr.write("namespace ")
        node_hdr.write(name)
        node_hdr.write(" {\n\n")
    if write_serde:
        node_hdr.write('class Encoder;\n')
        node_hdr.write('class Decoder;\n')
        node_hdr.write('\n')
    # Forward-declare every node class.
    for element in elements:
        if isinstance(element, NodeDecl):
            node_hdr.write(f'class {element.name};\n')
    node_hdr.write('\n')
    # One enum constant per concrete (leaf) node.
    node_hdr.write("enum class ")
    node_hdr.write(prefix)
    node_hdr.write("NodeType {\n")
    for node in leaf_nodes:
        node_hdr.write(node.name)
        node_hdr.write(",\n")
    node_hdr.write("};\n\n")
    # Root node class: intrusive ref-counting plus the type tag.
    node_hdr.write("class ")
    node_hdr.write(cpp_root_node_name)
    node_hdr.write(" {\n\n")
    node_hdr.write('unsigned RefCount = 0;\n\n')
    node_hdr.write_dedent('public:\n\n')
    node_hdr.write('inline void ref() {\n++RefCount;\n}\n\n')
    node_hdr.write('inline void unref() {\n--RefCount;\nif (RefCount == 0) {\ndelete this;\n}\n}\n\n')
    node_hdr.write(f"{cpp_root_node_name}Type Type;\n\n")
    node_hdr.write('void setParents();\n\n')
    node_hdr.write('SourceFile* getSourceFile();\n\n');
    node_hdr.write('const SourceFile* getSourceFile() const;\n\n');
    if write_serde:
        node_hdr.write('void encode(Encoder& encoder) const;\n\n')
        node_hdr.write('virtual void encode_fields(Encoder& encoder) const = 0;\n');
        node_hdr.write('virtual void decode_fields(Decoder& decoder) = 0;\n\n');
    node_hdr.write(f'virtual ~{cpp_root_node_name}();')
    node_hdr.write('\n')
    node_hdr.write("};\n\n")
    # One class per node declaration.
    for element in elements:
        if isinstance(element, NodeDecl):
            node = element
            is_leaf = not list(graph.get_children(node.name))
            cpp_node_name = prefix + node.name
            node_hdr.write("class ")
            node_hdr.write(cpp_node_name)
            node_hdr.write(" : public ")
            if node.parents:
                node_hdr.write(', '.join(prefix + parent for parent in node.parents))
            else:
                node_hdr.write(cpp_root_node_name)
            node_hdr.write(" {\n")
            node_hdr.write_dedent('public:\n\n')
            node_hdr.write(cpp_node_name + '(')
            gen_cpp_ctor_params(node_hdr, node)
            node_hdr.write('): ')
            gen_cpp_ctor_args(node_hdr, node)
            node_hdr.write(' {}\n\n')
            if node.members:
                for member in node.members:
                    node_hdr.write(gen_cpp_type_expr(member.type_expr.type))
                    node_hdr.write(" ");
                    node_hdr.write(variable_casing(member.name))
                    node_hdr.write(";\n");
                node_hdr.write('\n')
            if write_serde and is_leaf:
                node_hdr.write('void encode_fields(Encoder& encoder) const override;\n');
                node_hdr.write('void decode_fields(Decoder& decoder) override;\n\n');
            if node.name == 'SourceFile':
                # Hard-coded accessors for the SourceFile node's File member.
                node_hdr.write('TextFile& getTextFile() {\nreturn File;\n}\n\n')
                node_hdr.write('const TextFile& getTextFile() const {\nreturn File;\n}\n\n')
            # elif node.name == 'Token':
            # node_hdr.write('TextRange getRange() const {\nreturn TextRange(StartLoc, EndLoc);\n}\n\n')
            node_hdr.write('~')
            node_hdr.write(cpp_node_name)
            node_hdr.write('();\n\n')
            node_hdr.write("};\n\n")
    for _ in namespace:
        node_hdr.write("}\n\n")
    node_hdr.write('#endif\n')
    # ---- emit the source file ----------------------------------------
    node_src.write('#include "')
    node_src.write(str(include_path / (out_name + '.hpp')))
    node_src.write('"\n')
    node_src.write('\n')
    if write_serde:
        node_src.write('#include "')
        node_src.write(str(include_path / 'Encoder.hpp'))
        node_src.write('"\n')
        node_src.write('\n')
    for name in namespace:
        node_src.write("namespace ")
        node_src.write(name)
        node_src.write(" {\n\n")
    if write_serde:
        node_src.write(f'void {cpp_root_node_name}::encode(Encoder& encoder) const')
        node_src.write(' {\n')
        node_src.write(f'encoder.start_encode_struct("{cpp_root_node_name}");')
        node_src.write('\n')
        node_src.write('encode_fields(encoder);\n')
        node_src.write('encoder.end_encode_struct();\n')
        node_src.write('}\n\n')
    for node in nodes:
        is_leaf = not list(graph.get_children(node.name))
        cpp_node_name = prefix + node.name
        if write_serde and is_leaf:
            node_src.write(f'void {cpp_node_name}::encode_fields(Encoder& encoder) const')
            node_src.write(' { \n')
            for member in node.members:
                node_src.write(f'encoder.encode_field("{member.name}", {member.name});')
                node_src.write('\n')
            node_src.write('}\n\n')
        # Out-of-line destructor releasing all owned node references.
        node_src.write(f'{cpp_node_name}::~{cpp_node_name}()')
        node_src.write('{ \n')
        for member in node.members:
            dtor = gen_cpp_dtor(variable_casing(member.name), member.type_expr.type)
            if dtor:
                node_src.write(dtor)
        node_src.write('}\n\n')
    for _ in namespace:
        node_src.write("}\n\n")
    # ---- write or print the results ----------------------------------
    if args.dry_run:
        print('# ' + str(node_hdr.path))
        print(node_hdr.text)
        print('# ' + str(node_src.path))
        print(node_src.text)
    else:
        out_dir.mkdir(exist_ok=True, parents=True)
        node_hdr.save(out_dir)
        node_src.save(out_dir)
# Run only when executed as a script, so the module stays importable.
if __name__ == '__main__':
    main()