# bolt/scripts/gennodes.py — generates C++ node-class headers/sources
# from a declarative node-specification file.

#!/usr/bin/env python3
from os import wait
import re
from collections import deque
from pathlib import Path
import argparse
from typing import List, Optional
from sweetener.record import Record
import templaty
here = Path(__file__).parent.resolve()
# Sentinel returned by the scanner when input is exhausted. U+FFFF is a
# Unicode noncharacter, so it can never appear in real input.
EOF = '\uFFFF'

# Token type identifiers shared by Scanner and Parser.
END_OF_FILE = 0
IDENTIFIER = 1
SEMI = 2
EXTERNAL = 3
NODE = 4
LBRACE = 5
RBRACE = 6
LESSTHAN = 7
GREATERTHAN = 8
COLON = 9
LPAREN = 10
RPAREN = 11
VBAR = 12
COMMA = 13
HASH = 14
STRING = 15

# NOTE(review): name contains a typo ('WHTITESPACE'), but it is referenced
# throughout this file, so it is kept as-is.
RE_WHTITESPACE = re.compile(r"[\n\r\t ]")
RE_IDENT_START = re.compile(r"[a-zA-Z_]")
RE_IDENT_PART = re.compile(r"[a-zA-Z_0-9]")

# Reserved words that scan as keyword tokens rather than identifiers.
KEYWORDS = {
    'external': EXTERNAL,
    'node': NODE,
}
def escape_char(ch):
    """Return a printable representation of the character *ch*.

    Printable ASCII characters are returned unchanged; everything else is
    rendered as a ``\\xNN`` (code <= 127) or ``\\uNNNN`` escape.
    """
    code = ord(ch)
    # Printable ASCII is 32..126 inclusive; the previous upper bound
    # (code < 126) wrongly escaped '~'.
    if 32 <= code <= 126:
        return ch
    if code <= 127:
        return f"\\x{code:02X}"
    return f"\\u{code:04X}"
def camel_case(ident: str) -> str:
    """Convert a snake_case identifier to CamelCase.

    'type_expr' -> 'TypeExpr'.  An underscore uppercases the character
    that follows it.  Empty input returns '' (previously raised
    IndexError), as does a trailing underscore (previously IndexError).
    """
    if not ident:
        return ''
    out = ident[0].upper()
    i = 1
    while i < len(ident):
        ch = ident[i]
        i += 1
        if ch == '_':
            # Uppercase the character after the underscore, if any.
            if i < len(ident):
                out += ident[i].upper()
                i += 1
        else:
            out += ch
    return out
class ScanError(RuntimeError):
    """Raised when the scanner meets a character it cannot tokenize."""

    def __init__(self, file, position, actual):
        message = (
            f"{file.name}:{position.line}:{position.column}: "
            f"unexpected character '{escape_char(actual)}'"
        )
        super().__init__(message)
        self.file = file
        self.position = position
        self.actual = actual
# Maps fixed-text token types (punctuation and keywords) back to the exact
# source text they were scanned from; used for diagnostics and Token.text.
TOKEN_TYPE_TO_STRING = {
    LPAREN: '(',
    RPAREN: ')',
    LBRACE: '{',
    RBRACE: '}',
    LESSTHAN: '<',
    GREATERTHAN: '>',
    NODE: 'node',
    EXTERNAL: 'external',
    SEMI: ';',
    COLON: ':',
    COMMA: ',',
    VBAR: '|',
    HASH: '#',
}
class Token:
    """A single lexical token produced by the Scanner."""

    def __init__(self, type, position=None, value=None):
        self.type = type
        self.start_pos = position
        self.value = value

    @property
    def text(self):
        """The source text this token corresponds to."""
        fixed = TOKEN_TYPE_TO_STRING.get(self.type)
        if fixed is not None:
            return fixed
        if self.type == IDENTIFIER:
            return self.value
        if self.type == STRING:
            return f'"{self.value}"'
        if self.type == END_OF_FILE:
            return ''
        return '(unknown token)'
class TextFile:
    """A named text file whose contents are loaded lazily and cached."""

    def __init__(self, filename, text=None):
        self.name = filename
        self._cached_text = text

    @property
    def text(self):
        """The file's contents; read from disk on first access only."""
        if self._cached_text is None:
            with open(self.name, 'r') as handle:
                self._cached_text = handle.read()
        return self._cached_text
class TextPos:
    """A 1-based line/column position inside a text file."""

    def __init__(self, line=1, column=1):
        self.line = line
        self.column = column

    def clone(self):
        """Return an independent copy of this position."""
        return TextPos(self.line, self.column)

    def advance(self, text):
        """Move the position forward across every character of *text*."""
        for ch in text:
            if ch != '\n':
                self.column += 1
            else:
                self.line += 1
                self.column = 1
class Scanner:
    """Turns a character stream into Tokens for the node-spec language."""

    def __init__(self, text, text_offset=0, filename=None):
        self._text = text
        self._text_offset = text_offset
        self.file = TextFile(filename, text)
        self._curr_pos = TextPos()

    def _peek_char(self, offset=1):
        """Look *offset* characters ahead (1-based) without consuming."""
        i = self._text_offset + offset - 1
        return self._text[i] if i < len(self._text) else EOF

    def _get_char(self):
        """Consume and return the next character, tracking line/column."""
        if self._text_offset == len(self._text):
            return EOF
        i = self._text_offset
        self._text_offset += 1
        ch = self._text[i]
        self._curr_pos.advance(ch)
        return ch

    def _take_while(self, pred):
        """Consume the longest prefix whose characters satisfy *pred*."""
        out = ''
        while True:
            ch = self._peek_char()
            if not pred(ch):
                break
            self._get_char()
            out += ch
        return out

    def scan(self):
        """Scan and return the next token (END_OF_FILE at end of input)."""
        # Skip whitespace and '//' line comments.
        while True:
            c0 = self._peek_char()
            c1 = self._peek_char(2)
            if c0 == '/' and c1 == '/':
                self._get_char()
                self._get_char()
                while True:
                    c3 = self._get_char()
                    if c3 == '\n' or c3 == EOF:
                        break
                continue
            if RE_WHTITESPACE.match(c0):
                self._get_char()
                continue
            break
        if c0 == EOF:
            return Token(END_OF_FILE, self._curr_pos.clone())
        start_pos = self._curr_pos.clone()
        self._get_char()
        if c0 == ';': return Token(SEMI, start_pos)
        if c0 == '{': return Token(LBRACE, start_pos)
        if c0 == '}': return Token(RBRACE, start_pos)
        if c0 == '(': return Token(LPAREN, start_pos)
        if c0 == ')': return Token(RPAREN, start_pos)
        if c0 == '<': return Token(LESSTHAN, start_pos)
        if c0 == '>': return Token(GREATERTHAN, start_pos)
        if c0 == ':': return Token(COLON, start_pos)
        if c0 == '|': return Token(VBAR, start_pos)
        if c0 == ',': return Token(COMMA, start_pos)
        if c0 == '#': return Token(HASH, start_pos)
        if c0 == '"':
            text = ''
            while True:
                c1 = self._get_char()
                if c1 == '"':
                    break
                if c1 == EOF:
                    # BUG FIX: an unterminated string literal previously
                    # looped forever, appending the EOF sentinel each pass.
                    raise ScanError(self.file, start_pos, EOF)
                text += c1
            return Token(STRING, start_pos, text)
        if RE_IDENT_START.match(c0):
            name = c0 + self._take_while(lambda ch: RE_IDENT_PART.match(ch))
            if name in KEYWORDS:
                return Token(KEYWORDS[name], start_pos)
            return Token(IDENTIFIER, start_pos, name)
        raise ScanError(self.file, start_pos, c0)
# Resolved (semantic) types, inferred from the surface TypeExpr syntax.
class Type(Record):
    pass
# A sequence type: List<element_type>.
class ListType(Type):
    element_type: Type
# A value that may be absent: Option<element_type>.
class OptionalType(Type):
    element_type: Type
# A reference to a declared node, by name.
class NodeType(Type):
    name: str
# A closed union of alternative types.
class VariantType(Type):
    types: List[Type]
# An opaque (externally defined) type, emitted verbatim as C++.
class RawType(Type):
    text: str
# Base class for everything produced by the parser.
class AST(Record):
    pass
# Base class for preprocessor-style directives.
class Directive(AST):
    pass
# Include modes: '#include "..."' vs '#include <...>'.
INCLUDEMODE_LOCAL = 0
INCLUDEMODE_SYSTEM = 1
# NOTE(review): class name has a typo ('Diretive'); kept as-is because it
# is referenced elsewhere in this file.
class IncludeDiretive(Directive):
    path: str
    mode: int
    def __str__(self):
        # Render back to C++ preprocessor syntax.
        if self.mode == INCLUDEMODE_LOCAL:
            return f"#include \"{self.path}\"\n"
        if self.mode == INCLUDEMODE_SYSTEM:
            return f"#include <{self.path}>\n"
# Surface syntax of a type annotation; 'type' holds the resolved Type and
# is filled in after parsing by the inference pass.
class TypeExpr(AST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Resolved semantic Type, assigned later; None until inferred.
        self.type = None
# A named type reference with optional generic arguments: Name<Args...>.
class RefTypeExpr(TypeExpr):
    name: str
    args: List[TypeExpr]
# A '|'-style union. NOTE(review): nothing in the visible parser code
# constructs this class — confirm whether it is still needed.
class UnionTypeExpr(TypeExpr):
    types: List[TypeExpr]
# 'external Name;' — a type defined outside the specification file.
class External(AST):
    name: str
# One field inside a node declaration: 'TypeExpr name;'.
class NodeDeclField(AST):
    name: str
    type_expr: TypeExpr
# 'node Name : Parents { members }' — a node class declaration.
class NodeDecl(AST):
    name: str
    parents: List[str]
    members: List[NodeDeclField]
def pretty_token(token):
    """Render *token* for use in a diagnostic message."""
    if token.type != END_OF_FILE:
        return f"'{token.text}'"
    return 'end-of-file'
def pretty_token_type(token_type):
    """Render a token-type constant for use in a diagnostic message."""
    if token_type in TOKEN_TYPE_TO_STRING:
        return f"'{TOKEN_TYPE_TO_STRING[token_type]}'"
    if token_type == IDENTIFIER:
        # Fixed typo: was 'an identfier'.
        return 'an identifier'
    if token_type == STRING:
        return 'a string literal'
    if token_type == END_OF_FILE:
        return 'end-of-file'
    return f"(unknown token type {token_type})"
def pretty_alternatives(elements):
    """Join alternatives into prose: 'a', 'a or b', 'a, b or c'.

    Returns 'nothing' when *elements* yields no items.
    """
    items = list(elements)
    if not items:
        return 'nothing'
    if len(items) == 1:
        return items[0]
    return ', '.join(items[:-1]) + ' or ' + items[-1]
class ParseError(RuntimeError):
    """Raised when the parser meets a token it did not expect."""

    def __init__(self, file, actual, expected):
        alternatives = pretty_alternatives(
            pretty_token_type(tt) for tt in expected)
        message = (
            f"{file.name}:{actual.start_pos.line}:{actual.start_pos.column}: "
            f"got {pretty_token(actual)} but expected {alternatives}"
        )
        super().__init__(message)
        self.actual = actual
        self.expected = expected
class Parser:
    # Recursive-descent parser over the token stream of a Scanner.

    def __init__(self, scanner):
        self._scanner = scanner
        # Tokens that were peeked at but not yet consumed.
        self._token_buffer = deque()

    def _peek_token(self, offset=1):
        # Return the token *offset* positions ahead (1-based) without
        # consuming it; fills the buffer from the scanner as needed.
        while len(self._token_buffer) < offset:
            self._token_buffer.append(self._scanner.scan())
        return self._token_buffer[offset-1]

    def _get_token(self):
        # Consume and return the next token, preferring buffered ones.
        if self._token_buffer:
            return self._token_buffer.popleft()
        return self._scanner.scan()

    def _expect_token(self, expected_token_type):
        # Consume one token; raise ParseError unless it has the given type.
        t0 = self._get_token()
        if t0.type != expected_token_type:
            raise ParseError(self._scanner.file, t0, [ expected_token_type ])
        return t0

    def _parse_prim_type_expr(self):
        # Primary type expression: '(' type ')', Name, or Name<Args...>.
        t0 = self._get_token()
        if t0.type == LPAREN:
            result = self.parse_type_expr()
            self._expect_token(RPAREN)
            return result
        if t0.type == IDENTIFIER:
            t1 = self._peek_token()
            args = []
            if t1.type == LESSTHAN:
                self._get_token()
                while True:
                    t2 = self._peek_token()
                    if t2.type == GREATERTHAN:
                        # Accepts 'Name<>' and a trailing comma before '>'.
                        self._get_token()
                        break
                    args.append(self.parse_type_expr())
                    t3 = self._get_token()
                    if t3.type == GREATERTHAN:
                        break
                    if t3.type != COMMA:
                        raise ParseError(self._scanner.file, t3, [ COMMA, GREATERTHAN ])
            return RefTypeExpr(t0.value, args)
        raise ParseError(self._scanner.file, t0, [ LPAREN, IDENTIFIER ])

    def parse_type_expr(self):
        # Only primary type expressions are supported at present (the VBAR
        # token exists but no union syntax is parsed here).
        return self._parse_prim_type_expr()

    def parse_member(self):
        # One node field: 'TypeExpr name;'.
        type_expr = self.parse_type_expr()
        name = self._expect_token(IDENTIFIER)
        self._expect_token(SEMI)
        return NodeDeclField(name.value, type_expr)

    def parse_toplevel(self):
        # One top-level element: 'external X;', a 'node' declaration, or a
        # preprocessor-style '#include' directive.
        t0 = self._get_token()
        if t0.type == EXTERNAL:
            name = self._expect_token(IDENTIFIER)
            self._expect_token(SEMI)
            return External(name.value)
        if t0.type == NODE:
            name = self._expect_token(IDENTIFIER).value
            parents = []
            t1 = self._peek_token()
            if t1.type == COLON:
                # Comma-separated parent list after ':'; terminated by '{'.
                self._get_token()
                while True:
                    parent = self._expect_token(IDENTIFIER).value
                    parents.append(parent)
                    t2 = self._peek_token()
                    if t2.type == COMMA:
                        self._get_token()
                        continue
                    if t2.type == LBRACE:
                        break
                    raise ParseError(self._scanner.file, t2, [ COMMA, LBRACE ])
            self._expect_token(LBRACE)
            members = []
            while True:
                t2 = self._peek_token()
                if t2.type == RBRACE:
                    self._get_token()
                    break
                member = self.parse_member()
                members.append(member)
            return NodeDecl(name, parents, members)
        if t0.type == HASH:
            name = self._expect_token(IDENTIFIER)
            if name.value == 'include':
                t1 = self._get_token()
                if t1.type == LESSTHAN:
                    # The '<path>' form is read straight off the scanner's
                    # character stream, so no tokens may be buffered here.
                    assert(not self._token_buffer)
                    path = self._scanner._take_while(lambda ch: ch != '>')
                    self._scanner._get_char()
                    mode = INCLUDEMODE_SYSTEM
                elif t1.type == STRING:
                    mode = INCLUDEMODE_LOCAL
                    path = t1.value
                else:
                    raise ParseError(self._scanner.file, t1, [ STRING, LESSTHAN ])
                return IncludeDiretive(path, mode)
            raise RuntimeError(f"invalid preprocessor directive '{name.value}'")
        raise ParseError(self._scanner.file, t0, [ EXTERNAL, NODE, HASH ])

    def parse_grammar(self):
        # Parse top-level elements until end-of-file.
        elements = []
        while True:
            t0 = self._peek_token()
            if t0.type == END_OF_FILE:
                break
            element = self.parse_toplevel()
            elements.append(element)
        return elements
class Writer:
    """Accumulates generated source text, auto-indenting by brace depth."""

    def __init__(self, text='', path=None):
        self.path = path
        self.text = text
        self._at_blank_line = True
        self._indentation = ' '
        self._indent_level = 0

    def indent(self, count=1):
        """Increase the indent depth by *count* levels."""
        self._indent_level += count

    def dedent(self, count=1):
        """Decrease the indent depth by *count* levels."""
        self._indent_level -= count

    def write(self, chunk):
        """Append *chunk*, inserting indentation at each line start.

        '}' dedents before being emitted; '{' indents everything after it.
        """
        for char in chunk:
            if char == '}':
                self.dedent()
            if char == '\n':
                self._at_blank_line = True
            else:
                if self._at_blank_line and RE_WHTITESPACE.match(char) is None:
                    # First visible character on this line: emit the indent.
                    self.text += self._indentation * self._indent_level
                    self._at_blank_line = False
            self.text += char
            if char == '{':
                self.indent()

    def save(self, dest_dir):
        """Write the accumulated text to *dest_dir* / self.path."""
        dest_path = dest_dir / self.path
        print(f'Writing file {dest_path} ...')
        with open(dest_path, 'w') as f:
            f.write(self.text)
class DiGraph:
    """A directed graph over hashable nodes, indexing edges both ways."""

    def __init__(self):
        self._out_edges = dict()
        self._in_edges = dict()

    def add_edge(self, a, b):
        """Add the directed edge a -> b."""
        if a not in self._out_edges:
            self._out_edges[a] = set()
        self._out_edges[a].add(b)
        if b not in self._in_edges:
            self._in_edges[b] = set()
        self._in_edges[b].add(a)

    def get_children(self, node):
        """Yield the direct successors of *node* (none if unknown)."""
        if node not in self._out_edges:
            return
        for child in self._out_edges[node]:
            yield child

    def has_children(self, node):
        """Return True when *node* has at least one outgoing edge."""
        return node in self._out_edges

    def is_child_of(self, a, b):
        """Return True when *a* is reachable from *b* (including a == b)."""
        stack = [ b ]
        visited = set()
        while stack:
            node = stack.pop()
            if node in visited:
                # BUG FIX: was 'break', which aborted the entire search as
                # soon as any already-visited node was popped; 'continue'
                # just skips the duplicate entry.
                continue
            visited.add(node)
            if node == a:
                return True
            for child in self.get_children(node):
                stack.append(child)
        return False

    def get_ancestors(self, node):
        """Yield the direct predecessors of *node* (none if unknown)."""
        if node not in self._in_edges:
            return
        for parent in self._in_edges[node]:
            yield parent

    def get_common_ancestor(self, nodes):
        """Return a node from which every node in *nodes* is reachable.

        NOTE(review): assumes such an ancestor exists; the recursion does
        not terminate on inputs where it does not — confirm callers
        guarantee a rooted hierarchy.
        """
        out = nodes[0]
        parents = []
        for node in nodes[1:]:
            if not self.is_child_of(node, out):
                for parent in self.get_ancestors(node):
                    parents.append(parent)
        if not parents:
            return out
        parents.append(out)
        return self.get_common_ancestor(parents)
def main():
    """Entry point: parse a node-spec file and emit C++ header + source."""
    parser = argparse.ArgumentParser()
    parser.add_argument('file', nargs=1, help='The specification file to generate C++ code for')
    parser.add_argument('--namespace', default='', help='What C++ namespace to put generated code under')
    parser.add_argument('--name', default='AST', help='How to name the generated tree')
    parser.add_argument('-I', default='.', help='What path will be used to include generated header files')
    parser.add_argument('--include-root', default='.', help='Where the headers live inside the include directory')
    parser.add_argument('--enable-serde', action='store_true', help='Also write (de)serialization logic')
    parser.add_argument('--source-root', default='.', help='Where to store generated source files')
    parser.add_argument('--node-name', default='Node', help='How the root node of the hierarchy should be called')
    parser.add_argument('--node-prefix', default='', help='String to prepend to the names of node types')
    parser.add_argument('--out-dir', default='.', help='Place the entire folder structure inside this folder')
    parser.add_argument('--dry-run', action='store_true', help='Do not write generated code to the file system')
    args = parser.parse_args()

    filename = args.file[0]
    prefix = args.node_prefix
    cpp_root_node_name = prefix + args.node_name
    include_dir = Path(args.I)
    include_path = Path(args.include_root or '.')
    full_include_path = include_dir / include_path
    source_path = Path(args.source_root)
    # BUG FIX: ''.split('::') yields [''], which emitted a bogus
    # 'namespace  {' block when --namespace was left at its default.
    namespace = args.namespace.split('::') if args.namespace else []
    out_dir = Path(args.out_dir)
    out_name = args.name
    write_serde = args.enable_serde

    with open(filename, 'r') as f:
        text = f.read()
    scanner = Scanner(text, filename=filename)
    parser = Parser(scanner)
    elements = parser.parse_grammar()

    # Index declarations by name and build the inheritance graph.
    types = dict()
    nodes = list()
    leaf_nodes = list()  # NOTE(review): unused in the visible code below.
    graph = DiGraph()
    parent_to_children = dict()
    for element in elements:
        if isinstance(element, (External, NodeDecl)):
            types[element.name] = element
        if isinstance(element, NodeDecl):
            nodes.append(element)
            for parent in element.parents:
                graph.add_edge(parent, element.name)
                if parent not in parent_to_children:
                    parent_to_children[parent] = set()
                parent_to_children[parent].add(element)
    for node in nodes:
        if node.name not in parent_to_children:
            leaf_nodes.append(node)

    def is_null_type_expr(type_expr):
        # A 'null' alternative in a Variant marks the field as optional.
        return isinstance(type_expr, RefTypeExpr) and type_expr.name == 'null'

    def is_node(name):
        # A name is a node if declared as one, or if it only ever appears
        # as a parent of other node declarations.
        if name in types:
            return isinstance(types[name], NodeDecl)
        if name in parent_to_children:
            return True
        return False

    def get_all_variant_elements(type_expr):
        # Flatten nested Variant<...> expressions into one alternative list.
        collected = list()
        def loop(ty):
            if isinstance(ty, RefTypeExpr) and ty.name == 'Variant':
                for arg in ty.args:
                    loop(arg)
            else:
                collected.append(ty)
        loop(type_expr)
        return collected

    def infer_type(type_expr):
        # Map a surface TypeExpr to a resolved semantic Type.
        if isinstance(type_expr, RefTypeExpr):
            if type_expr.name == 'Option':
                assert(len(type_expr.args) == 1)
                return OptionalType(infer_type(type_expr.args[0]))
            if type_expr.name == 'List':
                assert(len(type_expr.args) == 1)
                return ListType(infer_type(type_expr.args[0]))
            if type_expr.name == 'Variant':
                alternatives = get_all_variant_elements(type_expr)
                has_null = any(is_null_type_expr(ty) for ty in alternatives)
                if has_null:
                    alternatives = [ty for ty in alternatives if not is_null_type_expr(ty)]
                if all(isinstance(ty, RefTypeExpr) and is_node(ty.name) for ty in alternatives):
                    # Every alternative is a node: collapse the variant to
                    # the nodes' common ancestor in the hierarchy.
                    node_name = graph.get_common_ancestor([t.name for t in alternatives])
                    return NodeType(node_name)
                if len(alternatives) == 1:
                    out = infer_type(alternatives[0])
                else:
                    # BUG FIX: was a generator expression, leaving the
                    # VariantType holding a one-shot iterator.
                    out = VariantType([infer_type(ty) for ty in alternatives])
                return OptionalType(out) if has_null else out
            if is_node(type_expr.name):
                assert(len(type_expr.args) == 0)
                return NodeType(type_expr.name)
            assert(len(type_expr.args) == 0)
            return RawType(type_expr.name)
        raise RuntimeError(f"unhandled type expression {type_expr}")

    for node in nodes:
        for member in node.members:
            member.type_expr.type = infer_type(member.type_expr)

    def is_type_optional_by_default(ty):
        # Node pointers are already nullable in C++; no std::optional wrap.
        return isinstance(ty, NodeType)

    def gen_cpp_type_expr(ty):
        # Render a resolved Type as C++ source text.
        if isinstance(ty, NodeType):
            return prefix + ty.name + "*"
        if isinstance(ty, ListType):
            return f"std::vector<{gen_cpp_type_expr(ty.element_type)}>"
        if isinstance(ty, OptionalType):
            cpp_expr = gen_cpp_type_expr(ty.element_type)
            if is_type_optional_by_default(ty.element_type):
                return cpp_expr
            return f"std::optional<{cpp_expr}>"
        if isinstance(ty, VariantType):
            # BUG FIX: was 'ty.element_types', but VariantType declares its
            # field as 'types' — the old code raised AttributeError.
            return f"std::variant<{','.join(gen_cpp_type_expr(t) for t in ty.types)}>"
        if isinstance(ty, RawType):
            return ty.text
        raise RuntimeError(f"unhandled Type {ty}")

    def gen_cpp_dtor(expr, ty):
        # Return C++ statements releasing *expr* of type *ty*, or None when
        # no explicit cleanup is needed.
        # NOTE(review): VariantType members fall through to the
        # 'unexpected' branch below — confirm variants never need cleanup.
        if isinstance(ty, NodeType):
            return f'{expr}->unref();\n'
        elif isinstance(ty, ListType):
            dtor = gen_cpp_dtor('Element', ty.element_type)
            if dtor:
                out = ''
                out += f'for (auto& Element: {expr})'
                out += '{\n'
                out += dtor
                out += '}\n'
                return out
        elif isinstance(ty, OptionalType):
            if is_type_optional_by_default(ty.element_type):
                element_expr = expr
            else:
                element_expr = f'(*{expr})'
            dtor = gen_cpp_dtor(element_expr, ty.element_type)
            if dtor:
                out = ''
                out += 'if ('
                out += expr
                out += ') {\n'
                out += dtor
                out += '}\n'
                return out
        elif isinstance(ty, RawType):
            pass # field should be destroyed by class
        else:
            raise RuntimeError(f'unexpected {ty}')

    def gen_cpp_ctor_params(out, node):
        # Write the constructor parameter list for *node*, including every
        # inherited member (breadth-first walk over the parent chain).
        visited = set()
        queue = deque([ node ])
        is_leaf = not graph.has_children(node.name)
        first = True
        if not is_leaf:
            # Non-leaf nodes receive the concrete node kind as first arg.
            out.write(f"{cpp_root_node_name}Type Type")
            first = False
        while queue:
            node = queue.popleft()
            if node.name in visited:
                # BUG FIX: was 'return', which silently truncated the
                # parameter list whenever a shared ancestor was reached
                # twice (diamond inheritance).
                continue
            visited.add(node.name)
            for member in node.members:
                if first:
                    first = False
                else:
                    out.write(', ')
                out.write(gen_cpp_type_expr(member.type_expr.type))
                out.write(' ')
                out.write(camel_case(member.name))
            for parent in node.parents:
                queue.append(types[parent])

    def gen_cpp_ctor_args(out, orig_node: NodeDecl):
        # Write the member-initializer list for orig_node's constructor.
        first = True
        is_leaf = not graph.has_children(orig_node.name)
        if orig_node.parents:
            for parent in orig_node.parents:
                if first:
                    first = False
                else:
                    out.write(', ')
                node = types[parent]
                refs = ''
                if is_leaf:
                    refs += f"{cpp_root_node_name}Type::{orig_node.name}"
                else:
                    refs += 'Type'
                for member in node.members:
                    refs += f", {camel_case(member.name)}"
                out.write(f"{prefix}{node.name}({refs})")
        else:
            if is_leaf:
                out.write(f"{cpp_root_node_name}({cpp_root_node_name}Type::{orig_node.name})")
            else:
                out.write(f"{cpp_root_node_name}(Type)")
            first = False
        for member in orig_node.members:
            if first:
                first = False
            else:
                out.write(', ')
            out.write(f"{camel_case(member.name)}({camel_case(member.name)})")

    # NOTE(review): this templaty result is immediately overwritten by the
    # Writer assignment below and never used — confirm whether the template
    # or the hand-written generation below is the intended path.
    node_hdr = templaty.execute(here / 'CST.hpp.tply', ctx={
        'namespaces': namespace,
        'nodes': nodes,
        'root_node_name': args.node_name
    })
    node_hdr = Writer(path=full_include_path / (out_name + '.hpp'))
    node_src = Writer(path=source_path / (out_name + '.cc'))

    # Generating the header file
    if write_serde:
        node_hdr.write('void encode(Encoder& encoder) const;\n\n')
        node_hdr.write('virtual void encode_fields(Encoder& encoder) const = 0;\n')
        #node_hdr.write('virtual void decode_fields(Decoder& decoder) = 0;\n\n');
    for element in elements:
        if isinstance(element, NodeDecl):
            node = element
            is_leaf = not list(graph.get_children(node.name))
            cpp_node_name = prefix + node.name
            node_hdr.write("class ")
            node_hdr.write(cpp_node_name)
            node_hdr.write(" : ")
            if node.parents:
                node_hdr.write(', '.join('public ' + prefix + parent for parent in node.parents))
            else:
                node_hdr.write('public ' + cpp_root_node_name)
            node_hdr.write(" {\n\n")
            node_hdr.write('public:\n\n')
            node_hdr.write(cpp_node_name + '(')
            gen_cpp_ctor_params(node_hdr, node)
            node_hdr.write('): ')
            gen_cpp_ctor_args(node_hdr, node)
            node_hdr.write(' {}\n\n')
            if node.members:
                for member in node.members:
                    node_hdr.write(gen_cpp_type_expr(member.type_expr.type))
                    node_hdr.write(" ")
                    node_hdr.write(camel_case(member.name))
                    node_hdr.write(";\n")
                node_hdr.write('\n')
            if write_serde and is_leaf:
                node_hdr.write('void encode_fields(Encoder& encoder) const override;\n')
                #node_hdr.write('void decode_fields(Decoder& decoder) override;\n\n');
            # NOTE(review): no closing '};' is emitted for the class here.
            # The scrape this file came from reports more lines than are
            # visible, so the terminator may live in missing code — confirm
            # before shipping.

    # Generating the source file
    node_src.write(f"""#include "{include_path / (out_name + '.hpp')}"\n\n""")
    for name in namespace:
        node_src.write(f"namespace {name} {{\n\n")
    node_src.write(f"""{cpp_root_node_name}::~{cpp_root_node_name}() {{ }}\n\n""")
    if write_serde:
        node_src.write(f"""
void {cpp_root_node_name}::encode(Encoder& encoder) const {{
encoder.start_encode_struct("{cpp_root_node_name}");
encode_fields(encoder);
encoder.end_encode_struct();
}}
""")
    for node in nodes:
        is_leaf = not list(graph.get_children(node.name))
        cpp_node_name = prefix + node.name
        if write_serde and is_leaf:
            node_src.write(f'void {cpp_node_name}::encode_fields(Encoder& encoder) const {{\n')
            for member in node.members:
                # BUG FIX: the value expression used the snake_case spec
                # name, but the C++ fields are declared with camel_case
                # names; the serialized key keeps the spec name.
                node_src.write(f'encoder.encode_field("{member.name}", {camel_case(member.name)});\n')
            node_src.write('}\n\n')
        node_src.write(f'{cpp_node_name}::~{cpp_node_name}() {{\n')
        for member in node.members:
            dtor = gen_cpp_dtor(camel_case(member.name), member.type_expr.type)
            if dtor:
                node_src.write(dtor)
        node_src.write('}\n\n')
    for _ in namespace:
        node_src.write("}\n\n")

    if args.dry_run:
        print('# ' + str(node_hdr.path))
        print(node_hdr.text)
        print('# ' + str(node_src.path))
        print(node_src.text)
    else:
        out_dir.mkdir(exist_ok=True, parents=True)
        node_hdr.save(out_dir)
        node_src.save(out_dir)
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()