Add experimental support for type classes and many more enhancements

This commit is contained in:
Sam Vervaeck 2023-05-20 23:48:26 +02:00
parent a7fdc59440
commit db26fd3b18
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
23 changed files with 3176 additions and 2446 deletions

4
.vscode/launch.json vendored
View file

@ -9,9 +9,9 @@
"request": "launch",
"name": "Debug",
"program": "${workspaceFolder}/build/bolt",
"args": ["test.bolt"],
"args": [ "test.bolt" ],
"cwd": "${workspaceFolder}",
"preLaunchTask": "CMake: build"
}
]
}
}

51
.vscode/settings.json vendored
View file

@ -27,6 +27,53 @@
"initializer_list": "cpp",
"numeric": "cpp",
"ostream": "cpp",
"system_error": "cpp"
}
"system_error": "cpp",
"cctype": "cpp",
"clocale": "cpp",
"cstdarg": "cpp",
"cstddef": "cpp",
"cstdio": "cpp",
"cstring": "cpp",
"ctime": "cpp",
"cwchar": "cpp",
"cwctype": "cpp",
"any": "cpp",
"atomic": "cpp",
"strstream": "cpp",
"bit": "cpp",
"bitset": "cpp",
"cinttypes": "cpp",
"codecvt": "cpp",
"compare": "cpp",
"complex": "cpp",
"concepts": "cpp",
"condition_variable": "cpp",
"coroutine": "cpp",
"cstdint": "cpp",
"map": "cpp",
"set": "cpp",
"algorithm": "cpp",
"iterator": "cpp",
"memory_resource": "cpp",
"optional": "cpp",
"ratio": "cpp",
"tuple": "cpp",
"type_traits": "cpp",
"utility": "cpp",
"iomanip": "cpp",
"iostream": "cpp",
"mutex": "cpp",
"new": "cpp",
"numbers": "cpp",
"semaphore": "cpp",
"shared_mutex": "cpp",
"stdexcept": "cpp",
"stop_token": "cpp",
"thread": "cpp",
"cfenv": "cpp",
"typeindex": "cpp",
"variant": "cpp",
"__nullptr": "cpp"
},
"C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools"
}

5
.vscode/tasks.json vendored
View file

@ -8,7 +8,10 @@
"targets": [
"all"
],
"group": "build",
"group": {
"kind": "build",
"isDefault": true
},
"problemMatcher": [],
"detail": "CMake template build task"
}

View file

@ -73,11 +73,11 @@ if (BOLT_ENABLE_TESTS)
)
endif()
#add_custom_command(
# OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/include/bolt/CST.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/CST.cc"
# COMMAND scripts/gennodes.py --name=CST ./bolt-cst-spec.txt -Iinclude/ --include-root=bolt --source-root=src/ --namespace=bolt
# DEPENDS scripts/gennodes.py
# MAIN_DEPENDENCY "${CMAKE_CURRENT_SOURCE_DIR}/bolt-cst-spec.txt"
# WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
#)
# add_custom_command(
# OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/include/bolt/CST.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/CST.cc"
# COMMAND scripts/gennodes.py --name=CST ./bolt-cst-spec.txt -Iinclude/ --include-root=bolt --source-root=src/ --namespace=bolt
# DEPENDS scripts/gennodes.py
# MAIN_DEPENDENCY "${CMAKE_CURRENT_SOURCE_DIR}/bolt-cst-spec.txt"
# WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
# )

View file

@ -1,167 +0,0 @@
#include <vector>
#include <optional>
#include "bolt/Text.hpp"
#include "bolt/Integer.hpp"
#include "bolt/ByteString.hpp"
external Integer;
external ByteString;
external TextRange;
// Tokens
node Token {
TextLoc start_loc;
}
node Equals : Token {}
node Colon : Token {}
node Dot : Token {}
node LParen : Token {}
node RParen : Token {}
node LBracket : Token {}
node RBracket : Token {}
node LBrace : Token {}
node RBrace : Token {}
node LetKeyword : Token {}
node MutKeyword : Token {}
node PubKeyword : Token {}
node TypeKeyword : Token {}
node ReturnKeyword : Token {}
node ModKeyword : Token {}
node StructKeyword : Token {}
node Invalid : Token {}
node EndOfFile : Token {}
node BlockStart : Token {}
node BlockEnd : Token {}
node LineFoldEnd : Token {}
node CustomOperator : Token {
ByteString text;
}
node Identifier : Token {
ByteString text;
}
node StringLiteral : Token {
ByteString text;
}
node IntegerLiteral : Token {
Integer value;
}
node QualifiedName {
List<Identifier> module_path;
Identifier name;
}
node SourceElement {}
node LetBodyElement {}
// Type expressions
node TypeExpression {}
node ReferenceTypeExpression : TypeExpression {
QualifiedName name;
}
// Patterns
node Pattern {}
node BindPattern : Pattern {
Identifier name;
}
// Expresssions
node Expression {}
node ReferenceExpression : Expression {
Identifier name;
}
node ConstantExpression : Expression {
Variant<StringLiteral, IntegerLiteral> token;
}
node CallExpression : Expression {
Expression function;
List<Expression> args;
}
// Statements
node Statement : LetBodyElement {}
node ExpressionStatement : Statement, SourceElement {
Expression expression;
}
node ReturnStatement : Statement {
ReturnKeyword return_keyword;
Expression expression;
}
// Other nodes
node TypeAssert {
Colon colon;
TypeExpression type_expression;
}
node Param {
Pattern pattern;
TypeAssert type_assert;
}
// Declarations
node LetBody {}
node LetBlockBody : LetBody {
BlockStart block_start;
List<LetBodyElement> elements;
}
node LetExprBody : LetBody {
Equals equals;
Expression expression;
}
node LetDeclaration : SourceElement, LetBodyElement {
Option<PubKeyword> pub_keyword;
LetKeyword let_keywod;
Option<MutKeyword> mut_keyword;
Pattern pattern;
List<Param> params;
Option<TypeAssert> type_assert;
Option<LetBody> body;
}
node StructDeclField {
Identifier name;
Colon colon;
TypeExpression type_expression;
}
node StructDecl : SourceElement {
StructKeyword struct_keyword;
Identifier name;
Dot dot;
List<StructDeclField> fields;
}
node SourceFile {
List<SourceElement> elements;
}

File diff suppressed because it is too large Load diff

951
include/bolt/CSTVisitor.hpp Normal file
View file

@ -0,0 +1,951 @@
#pragma once
#include "bolt/CST.hpp"
namespace bolt {
template<typename D, typename R = void>
class CSTVisitor {
public:
void visit(Node* N) {
switch (N->getKind()) {
case NodeKind::Equals:
return static_cast<D*>(this)->visitEquals(static_cast<Equals*>(N));
case NodeKind::Colon:
return static_cast<D*>(this)->visitColon(static_cast<Colon*>(N));
case NodeKind::Comma:
return static_cast<D*>(this)->visitComma(static_cast<Comma*>(N));
case NodeKind::Dot:
return static_cast<D*>(this)->visitDot(static_cast<Dot*>(N));
case NodeKind::DotDot:
return static_cast<D*>(this)->visitDotDot(static_cast<DotDot*>(N));
case NodeKind::Tilde:
return static_cast<D*>(this)->visitTilde(static_cast<Tilde*>(N));
case NodeKind::LParen:
return static_cast<D*>(this)->visitLParen(static_cast<LParen*>(N));
case NodeKind::RParen:
return static_cast<D*>(this)->visitRParen(static_cast<RParen*>(N));
case NodeKind::LBracket:
return static_cast<D*>(this)->visitLBracket(static_cast<LBracket*>(N));
case NodeKind::RBracket:
return static_cast<D*>(this)->visitRBracket(static_cast<RBracket*>(N));
case NodeKind::LBrace:
return static_cast<D*>(this)->visitLBrace(static_cast<LBrace*>(N));
case NodeKind::RBrace:
return static_cast<D*>(this)->visitRBrace(static_cast<RBrace*>(N));
case NodeKind::RArrow:
return static_cast<D*>(this)->visitRArrow(static_cast<RArrow*>(N));
case NodeKind::RArrowAlt:
return static_cast<D*>(this)->visitRArrowAlt(static_cast<RArrowAlt*>(N));
case NodeKind::LetKeyword:
return static_cast<D*>(this)->visitLetKeyword(static_cast<LetKeyword*>(N));
case NodeKind::MutKeyword:
return static_cast<D*>(this)->visitMutKeyword(static_cast<MutKeyword*>(N));
case NodeKind::PubKeyword:
return static_cast<D*>(this)->visitPubKeyword(static_cast<PubKeyword*>(N));
case NodeKind::TypeKeyword:
return static_cast<D*>(this)->visitTypeKeyword(static_cast<TypeKeyword*>(N));
case NodeKind::ReturnKeyword:
return static_cast<D*>(this)->visitReturnKeyword(static_cast<ReturnKeyword*>(N));
case NodeKind::ModKeyword:
return static_cast<D*>(this)->visitModKeyword(static_cast<ModKeyword*>(N));
case NodeKind::StructKeyword:
return static_cast<D*>(this)->visitStructKeyword(static_cast<StructKeyword*>(N));
case NodeKind::ClassKeyword:
return static_cast<D*>(this)->visitClassKeyword(static_cast<ClassKeyword*>(N));
case NodeKind::InstanceKeyword:
return static_cast<D*>(this)->visitInstanceKeyword(static_cast<InstanceKeyword*>(N));
case NodeKind::ElifKeyword:
return static_cast<D*>(this)->visitElifKeyword(static_cast<ElifKeyword*>(N));
case NodeKind::IfKeyword:
return static_cast<D*>(this)->visitIfKeyword(static_cast<IfKeyword*>(N));
case NodeKind::ElseKeyword:
return static_cast<D*>(this)->visitElseKeyword(static_cast<ElseKeyword*>(N));
case NodeKind::Invalid:
return static_cast<D*>(this)->visitInvalid(static_cast<Invalid*>(N));
case NodeKind::EndOfFile:
return static_cast<D*>(this)->visitEndOfFile(static_cast<EndOfFile*>(N));
case NodeKind::BlockStart:
return static_cast<D*>(this)->visitBlockStart(static_cast<BlockStart*>(N));
case NodeKind::BlockEnd:
return static_cast<D*>(this)->visitBlockEnd(static_cast<BlockEnd*>(N));
case NodeKind::LineFoldEnd:
return static_cast<D*>(this)->visitLineFoldEnd(static_cast<LineFoldEnd*>(N));
case NodeKind::CustomOperator:
return static_cast<D*>(this)->visitCustomOperator(static_cast<CustomOperator*>(N));
case NodeKind::Assignment:
return static_cast<D*>(this)->visitAssignment(static_cast<Assignment*>(N));
case NodeKind::Identifier:
return static_cast<D*>(this)->visitIdentifier(static_cast<Identifier*>(N));
case NodeKind::StringLiteral:
return static_cast<D*>(this)->visitStringLiteral(static_cast<StringLiteral*>(N));
case NodeKind::IntegerLiteral:
return static_cast<D*>(this)->visitIntegerLiteral(static_cast<IntegerLiteral*>(N));
case NodeKind::QualifiedName:
return static_cast<D*>(this)->visitQualifiedName(static_cast<QualifiedName*>(N));
case NodeKind::TypeclassConstraintExpression:
return static_cast<D*>(this)->visitTypeclassConstraintExpression(static_cast<TypeclassConstraintExpression*>(N));
case NodeKind::EqualityConstraintExpression:
return static_cast<D*>(this)->visitEqualityConstraintExpression(static_cast<EqualityConstraintExpression*>(N));
case NodeKind::QualifiedTypeExpression:
return static_cast<D*>(this)->visitQualifiedTypeExpression(static_cast<QualifiedTypeExpression*>(N));
case NodeKind::ReferenceTypeExpression:
return static_cast<D*>(this)->visitReferenceTypeExpression(static_cast<ReferenceTypeExpression*>(N));
case NodeKind::ArrowTypeExpression:
return static_cast<D*>(this)->visitArrowTypeExpression(static_cast<ArrowTypeExpression*>(N));
case NodeKind::VarTypeExpression:
return static_cast<D*>(this)->visitVarTypeExpression(static_cast<VarTypeExpression*>(N));
case NodeKind::BindPattern:
return static_cast<D*>(this)->visitBindPattern(static_cast<BindPattern*>(N));
case NodeKind::ReferenceExpression:
return static_cast<D*>(this)->visitReferenceExpression(static_cast<ReferenceExpression*>(N));
case NodeKind::NestedExpression:
return static_cast<D*>(this)->visitNestedExpression(static_cast<NestedExpression*>(N));
case NodeKind::ConstantExpression:
return static_cast<D*>(this)->visitConstantExpression(static_cast<ConstantExpression*>(N));
case NodeKind::CallExpression:
return static_cast<D*>(this)->visitCallExpression(static_cast<CallExpression*>(N));
case NodeKind::InfixExpression:
return static_cast<D*>(this)->visitInfixExpression(static_cast<InfixExpression*>(N));
case NodeKind::PrefixExpression:
return static_cast<D*>(this)->visitPrefixExpression(static_cast<PrefixExpression*>(N));
case NodeKind::ExpressionStatement:
return static_cast<D*>(this)->visitExpressionStatement(static_cast<ExpressionStatement*>(N));
case NodeKind::ReturnStatement:
return static_cast<D*>(this)->visitReturnStatement(static_cast<ReturnStatement*>(N));
case NodeKind::IfStatement:
return static_cast<D*>(this)->visitIfStatement(static_cast<IfStatement*>(N));
case NodeKind::IfStatementPart:
return static_cast<D*>(this)->visitIfStatementPart(static_cast<IfStatementPart*>(N));
case NodeKind::TypeAssert:
return static_cast<D*>(this)->visitTypeAssert(static_cast<TypeAssert*>(N));
case NodeKind::Parameter:
return static_cast<D*>(this)->visitParameter(static_cast<Parameter*>(N));
case NodeKind::LetBlockBody:
return static_cast<D*>(this)->visitLetBlockBody(static_cast<LetBlockBody*>(N));
case NodeKind::LetExprBody:
return static_cast<D*>(this)->visitLetExprBody(static_cast<LetExprBody*>(N));
case NodeKind::LetDeclaration:
return static_cast<D*>(this)->visitLetDeclaration(static_cast<LetDeclaration*>(N));
case NodeKind::StructDeclarationField:
return static_cast<D*>(this)->visitStructDeclarationField(static_cast<StructDeclarationField*>(N));
case NodeKind::StructDeclaration:
return static_cast<D*>(this)->visitStructDeclaration(static_cast<StructDeclaration*>(N));
case NodeKind::ClassDeclaration:
return static_cast<D*>(this)->visitClassDeclaration(static_cast<ClassDeclaration*>(N));
case NodeKind::InstanceDeclaration:
return static_cast<D*>(this)->visitInstanceDeclaration(static_cast<InstanceDeclaration*>(N));
case NodeKind::SourceFile:
return static_cast<D*>(this)->visitSourceFile(static_cast<SourceFile*>(N));
}
}
protected:
void visitNode(Node* N) {
visitEachChild(N);
}
void visitToken(Token* N) {
visitNode(N);
}
void visitEquals(Equals* N) {
visitToken(N);
}
void visitColon(Colon* N) {
visitToken(N);
}
void visitComma(Comma* N) {
visitToken(N);
}
void visitDot(Dot* N) {
visitToken(N);
}
void visitDotDot(DotDot* N) {
visitToken(N);
}
void visitTilde(Tilde* N) {
visitToken(N);
}
void visitLParen(LParen* N) {
visitToken(N);
}
void visitRParen(RParen* N) {
visitToken(N);
}
void visitLBracket(LBracket* N) {
visitToken(N);
}
void visitRBracket(RBracket* N) {
visitToken(N);
}
void visitLBrace(LBrace* N) {
visitToken(N);
}
void visitRBrace(RBrace* N) {
visitToken(N);
}
void visitRArrow(RArrow* N) {
visitToken(N);
}
void visitRArrowAlt(RArrowAlt* N) {
visitToken(N);
}
void visitLetKeyword(LetKeyword* N) {
visitToken(N);
}
void visitMutKeyword(MutKeyword* N) {
visitToken(N);
}
void visitPubKeyword(PubKeyword* N) {
visitToken(N);
}
void visitTypeKeyword(TypeKeyword* N) {
visitToken(N);
}
void visitReturnKeyword(ReturnKeyword* N) {
visitToken(N);
}
void visitModKeyword(ModKeyword* N) {
visitToken(N);
}
void visitStructKeyword(StructKeyword* N) {
visitToken(N);
}
void visitClassKeyword(ClassKeyword* N) {
visitToken(N);
}
void visitInstanceKeyword(InstanceKeyword* N) {
visitToken(N);
}
void visitElifKeyword(ElifKeyword* N) {
visitToken(N);
}
void visitIfKeyword(IfKeyword* N) {
visitToken(N);
}
void visitElseKeyword(ElseKeyword* N) {
visitToken(N);
}
void visitInvalid(Invalid* N) {
visitToken(N);
}
void visitEndOfFile(EndOfFile* N) {
visitToken(N);
}
void visitBlockStart(BlockStart* N) {
visitToken(N);
}
void visitBlockEnd(BlockEnd* N) {
visitToken(N);
}
void visitLineFoldEnd(LineFoldEnd* N) {
visitToken(N);
}
void visitCustomOperator(CustomOperator* N) {
visitToken(N);
}
void visitAssignment(Assignment* N) {
visitToken(N);
}
void visitIdentifier(Identifier* N) {
visitToken(N);
}
void visitStringLiteral(StringLiteral* N) {
visitToken(N);
}
void visitIntegerLiteral(IntegerLiteral* N) {
visitToken(N);
}
void visitQualifiedName(QualifiedName* N) {
visitNode(N);
}
void visitConstraintExpression(ConstraintExpression* N) {
visitNode(N);
}
void visitTypeclassConstraintExpression(TypeclassConstraintExpression* N) {
visitConstraintExpression(N);
}
void visitEqualityConstraintExpression(EqualityConstraintExpression* N) {
visitConstraintExpression(N);
}
void visitTypeExpression(TypeExpression* N) {
visitNode(N);
}
void visitQualifiedTypeExpression(QualifiedTypeExpression* N) {
visitTypeExpression(N);
}
void visitReferenceTypeExpression(ReferenceTypeExpression* N) {
visitTypeExpression(N);
}
void visitArrowTypeExpression(ArrowTypeExpression* N) {
visitTypeExpression(N);
}
void visitVarTypeExpression(VarTypeExpression* N) {
visitTypeExpression(N);
}
void visitPattern(Pattern* N) {
visitNode(N);
}
void visitBindPattern(BindPattern* N) {
visitPattern(N);
}
void visitExpression(Expression* N) {
visitNode(N);
}
void visitReferenceExpression(ReferenceExpression* N) {
visitExpression(N);
}
void visitNestedExpression(NestedExpression* N) {
visitExpression(N);
}
void visitConstantExpression(ConstantExpression* N) {
visitExpression(N);
}
void visitCallExpression(CallExpression* N) {
visitExpression(N);
}
void visitInfixExpression(InfixExpression* N) {
visitExpression(N);
}
void visitPrefixExpression(PrefixExpression* N) {
visitExpression(N);
}
void visitStatement(Statement* N) {
visitNode(N);
}
void visitExpressionStatement(ExpressionStatement* N) {
visitStatement(N);
}
void visitReturnStatement(ReturnStatement* N) {
visitStatement(N);
}
void visitIfStatement(IfStatement* N) {
visitStatement(N);
}
void visitIfStatementPart(IfStatementPart* N) {
visitNode(N);
}
void visitTypeAssert(TypeAssert* N) {
visitNode(N);
}
void visitParameter(Parameter* N) {
visitNode(N);
}
void visitLetBody(LetBody* N) {
visitNode(N);
}
void visitLetBlockBody(LetBlockBody* N) {
visitLetBody(N);
}
void visitLetExprBody(LetExprBody* N) {
visitLetBody(N);
}
void visitLetDeclaration(LetDeclaration* N) {
visitNode(N);
}
void visitStructDeclarationField(StructDeclarationField* N) {
visitNode(N);
}
void visitStructDeclaration(StructDeclaration* N) {
visitNode(N);
}
void visitClassDeclaration(ClassDeclaration* N) {
visitNode(N);
}
void visitInstanceDeclaration(InstanceDeclaration* N) {
visitNode(N);
}
void visitSourceFile(SourceFile* N) {
visitNode(N);
}
public:
void visitEachChild(Node* N) {
switch (N->getKind()) {
case NodeKind::Equals:
visitEachChild(static_cast<Equals*>(N));
break;
case NodeKind::Colon:
visitEachChild(static_cast<Colon*>(N));
break;
case NodeKind::Comma:
visitEachChild(static_cast<Comma*>(N));
break;
case NodeKind::Dot:
visitEachChild(static_cast<Dot*>(N));
break;
case NodeKind::DotDot:
visitEachChild(static_cast<DotDot*>(N));
break;
case NodeKind::Tilde:
visitEachChild(static_cast<Tilde*>(N));
break;
case NodeKind::LParen:
visitEachChild(static_cast<LParen*>(N));
break;
case NodeKind::RParen:
visitEachChild(static_cast<RParen*>(N));
break;
case NodeKind::LBracket:
visitEachChild(static_cast<LBracket*>(N));
break;
case NodeKind::RBracket:
visitEachChild(static_cast<RBracket*>(N));
break;
case NodeKind::LBrace:
visitEachChild(static_cast<LBrace*>(N));
break;
case NodeKind::RBrace:
visitEachChild(static_cast<RBrace*>(N));
break;
case NodeKind::RArrow:
visitEachChild(static_cast<RArrow*>(N));
break;
case NodeKind::RArrowAlt:
visitEachChild(static_cast<RArrowAlt*>(N));
break;
case NodeKind::LetKeyword:
visitEachChild(static_cast<LetKeyword*>(N));
break;
case NodeKind::MutKeyword:
visitEachChild(static_cast<MutKeyword*>(N));
break;
case NodeKind::PubKeyword:
visitEachChild(static_cast<PubKeyword*>(N));
break;
case NodeKind::TypeKeyword:
visitEachChild(static_cast<TypeKeyword*>(N));
break;
case NodeKind::ReturnKeyword:
visitEachChild(static_cast<ReturnKeyword*>(N));
break;
case NodeKind::ModKeyword:
visitEachChild(static_cast<ModKeyword*>(N));
break;
case NodeKind::StructKeyword:
visitEachChild(static_cast<StructKeyword*>(N));
break;
case NodeKind::ClassKeyword:
visitEachChild(static_cast<ClassKeyword*>(N));
break;
case NodeKind::InstanceKeyword:
visitEachChild(static_cast<InstanceKeyword*>(N));
break;
case NodeKind::ElifKeyword:
visitEachChild(static_cast<ElifKeyword*>(N));
break;
case NodeKind::IfKeyword:
visitEachChild(static_cast<IfKeyword*>(N));
break;
case NodeKind::ElseKeyword:
visitEachChild(static_cast<ElseKeyword*>(N));
break;
case NodeKind::Invalid:
visitEachChild(static_cast<Invalid*>(N));
break;
case NodeKind::EndOfFile:
visitEachChild(static_cast<EndOfFile*>(N));
break;
case NodeKind::BlockStart:
visitEachChild(static_cast<BlockStart*>(N));
break;
case NodeKind::BlockEnd:
visitEachChild(static_cast<BlockEnd*>(N));
break;
case NodeKind::LineFoldEnd:
visitEachChild(static_cast<LineFoldEnd*>(N));
break;
case NodeKind::CustomOperator:
visitEachChild(static_cast<CustomOperator*>(N));
break;
case NodeKind::Assignment:
visitEachChild(static_cast<Assignment*>(N));
break;
case NodeKind::Identifier:
visitEachChild(static_cast<Identifier*>(N));
break;
case NodeKind::StringLiteral:
visitEachChild(static_cast<StringLiteral*>(N));
break;
case NodeKind::IntegerLiteral:
visitEachChild(static_cast<IntegerLiteral*>(N));
break;
case NodeKind::QualifiedName:
visitEachChild(static_cast<QualifiedName*>(N));
break;
case NodeKind::TypeclassConstraintExpression:
visitEachChild(static_cast<TypeclassConstraintExpression*>(N));
break;
case NodeKind::EqualityConstraintExpression:
visitEachChild(static_cast<EqualityConstraintExpression*>(N));
break;
case NodeKind::QualifiedTypeExpression:
visitEachChild(static_cast<QualifiedTypeExpression*>(N));
break;
case NodeKind::ReferenceTypeExpression:
visitEachChild(static_cast<ReferenceTypeExpression*>(N));
break;
case NodeKind::ArrowTypeExpression:
visitEachChild(static_cast<ArrowTypeExpression*>(N));
break;
case NodeKind::VarTypeExpression:
visitEachChild(static_cast<VarTypeExpression*>(N));
break;
case NodeKind::BindPattern:
visitEachChild(static_cast<BindPattern*>(N));
break;
case NodeKind::ReferenceExpression:
visitEachChild(static_cast<ReferenceExpression*>(N));
break;
case NodeKind::NestedExpression:
visitEachChild(static_cast<NestedExpression*>(N));
break;
case NodeKind::ConstantExpression:
visitEachChild(static_cast<ConstantExpression*>(N));
break;
case NodeKind::CallExpression:
visitEachChild(static_cast<CallExpression*>(N));
break;
case NodeKind::InfixExpression:
visitEachChild(static_cast<InfixExpression*>(N));
break;
case NodeKind::PrefixExpression:
visitEachChild(static_cast<PrefixExpression*>(N));
break;
case NodeKind::ExpressionStatement:
visitEachChild(static_cast<ExpressionStatement*>(N));
break;
case NodeKind::ReturnStatement:
visitEachChild(static_cast<ReturnStatement*>(N));
break;
case NodeKind::IfStatement:
visitEachChild(static_cast<IfStatement*>(N));
break;
case NodeKind::IfStatementPart:
visitEachChild(static_cast<IfStatementPart*>(N));
break;
case NodeKind::TypeAssert:
visitEachChild(static_cast<TypeAssert*>(N));
break;
case NodeKind::Parameter:
visitEachChild(static_cast<Parameter*>(N));
break;
case NodeKind::LetBlockBody:
visitEachChild(static_cast<LetBlockBody*>(N));
break;
case NodeKind::LetExprBody:
visitEachChild(static_cast<LetExprBody*>(N));
break;
case NodeKind::LetDeclaration:
visitEachChild(static_cast<LetDeclaration*>(N));
break;
case NodeKind::StructDeclaration:
visitEachChild(static_cast<StructDeclaration*>(N));
break;
case NodeKind::StructDeclarationField:
visitEachChild(static_cast<StructDeclarationField*>(N));
break;
case NodeKind::ClassDeclaration:
visitEachChild(static_cast<ClassDeclaration*>(N));
break;
case NodeKind::InstanceDeclaration:
visitEachChild(static_cast<InstanceDeclaration*>(N));
break;
case NodeKind::SourceFile:
visitEachChild(static_cast<SourceFile*>(N));
break;
default:
ZEN_UNREACHABLE
}
}
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
void visitEachChild(Equals* N) {
}
void visitEachChild(Colon* N) {
}
void visitEachChild(Comma* N) {
}
void visitEachChild(Dot* N) {
}
void visitEachChild(DotDot* N) {
}
void visitEachChild(Tilde* N) {
}
void visitEachChild(LParen* N) {
}
void visitEachChild(RParen* N) {
}
void visitEachChild(LBracket* N) {
}
void visitEachChild(RBracket* N) {
}
void visitEachChild(LBrace* N) {
}
void visitEachChild(RBrace* N) {
}
void visitEachChild(RArrow* N) {
}
void visitEachChild(RArrowAlt* N) {
}
void visitEachChild(LetKeyword* N) {
}
void visitEachChild(MutKeyword* N) {
}
void visitEachChild(PubKeyword* N) {
}
void visitEachChild(TypeKeyword* N) {
}
void visitEachChild(ReturnKeyword* N) {
}
void visitEachChild(ModKeyword* N) {
}
void visitEachChild(StructKeyword* N) {
}
void visitEachChild(ClassKeyword* N) {
}
void visitEachChild(InstanceKeyword* N) {
}
void visitEachChild(ElifKeyword* N) {
}
void visitEachChild(IfKeyword* N) {
}
void visitEachChild(ElseKeyword* N) {
}
void visitEachChild(Invalid* N) {
}
void visitEachChild(EndOfFile* N) {
}
void visitEachChild(BlockStart* N) {
}
void visitEachChild(BlockEnd* N) {
}
void visitEachChild(LineFoldEnd* N) {
}
void visitEachChild(CustomOperator* N) {
}
void visitEachChild(Assignment* N) {
}
void visitEachChild(Identifier* N) {
}
void visitEachChild(StringLiteral* N) {
}
void visitEachChild(IntegerLiteral* N) {
}
void visitEachChild(QualifiedName* N) {
for (auto Name: N->ModulePath) {
BOLT_VISIT(Name);
}
BOLT_VISIT(N->Name);
}
void visitEachChild(TypeclassConstraintExpression* N) {
BOLT_VISIT(N->Name);
for (auto TE: N->TEs) {
BOLT_VISIT(TE);
}
}
void visitEachChild(EqualityConstraintExpression* N) {
BOLT_VISIT(N->Left);
BOLT_VISIT(N->Tilde);
BOLT_VISIT(N->Right);
}
void visitEachChild(QualifiedTypeExpression* N) {
for (auto [CE, Comma]: N->Constraints) {
BOLT_VISIT(CE);
if (Comma) {
BOLT_VISIT(Comma);
}
}
BOLT_VISIT(N->RArrowAlt);
BOLT_VISIT(N->TE);
}
void visitEachChild(ReferenceTypeExpression* N) {
BOLT_VISIT(N->Name);
}
void visitEachChild(ArrowTypeExpression* N) {
for (auto PT: N->ParamTypes) {
BOLT_VISIT(PT);
}
BOLT_VISIT(N->ReturnType);
}
void visitEachChild(VarTypeExpression* N) {
BOLT_VISIT(N->Name);
}
void visitEachChild(BindPattern* N) {
BOLT_VISIT(N->Name);
}
void visitEachChild(ReferenceExpression* N) {
BOLT_VISIT(N->Name);
}
void visitEachChild(NestedExpression* N) {
BOLT_VISIT(N->LParen);
BOLT_VISIT(N->Inner);
BOLT_VISIT(N->RParen);
}
void visitEachChild(ConstantExpression* N) {
BOLT_VISIT(N->Token);
}
void visitEachChild(CallExpression* N) {
BOLT_VISIT(N->Function);
for (auto Arg: N->Args) {
BOLT_VISIT(Arg);
}
}
void visitEachChild(InfixExpression* N) {
BOLT_VISIT(N->LHS);
BOLT_VISIT(N->Operator);
BOLT_VISIT(N->RHS);
}
void visitEachChild(PrefixExpression* N) {
BOLT_VISIT(N->Operator);
BOLT_VISIT(N->Argument);
}
void visitEachChild(ExpressionStatement* N) {
BOLT_VISIT(N->Expression);
}
void visitEachChild(ReturnStatement* N) {
BOLT_VISIT(N->ReturnKeyword);
BOLT_VISIT(N->Expression);
}
void visitEachChild(IfStatement* N) {
for (auto Part: N->Parts) {
BOLT_VISIT(Part);
}
}
void visitEachChild(IfStatementPart* N) {
BOLT_VISIT(N->Keyword);
if (N->Test != nullptr) {
BOLT_VISIT(N->Test);
}
BOLT_VISIT(N->BlockStart);
for (auto Element: N->Elements) {
BOLT_VISIT(Element);
}
}
void visitEachChild(TypeAssert* N) {
BOLT_VISIT(N->Colon);
BOLT_VISIT(N->TypeExpression);
}
void visitEachChild(Parameter* N) {
BOLT_VISIT(N->Pattern);
if (N->TypeAssert != nullptr) {
BOLT_VISIT(N->TypeAssert);
}
}
void visitEachChild(LetBlockBody* N) {
BOLT_VISIT(N->BlockStart);
for (auto Element: N->Elements) {
BOLT_VISIT(Element);
}
}
void visitEachChild(LetExprBody* N) {
BOLT_VISIT(N->Equals);
BOLT_VISIT(N->Expression);
}
void visitEachChild(LetDeclaration* N) {
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
BOLT_VISIT(N->LetKeyword);
if (N->MutKeyword) {
BOLT_VISIT(N->MutKeyword);
}
BOLT_VISIT(N->Pattern);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(StructDeclarationField* N) {
BOLT_VISIT(N->Name);
BOLT_VISIT(N->Colon);
BOLT_VISIT(N->TypeExpression);
}
void visitEachChild(StructDeclaration* N) {
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
BOLT_VISIT(N->StructKeyword);
BOLT_VISIT(N->Name);
BOLT_VISIT(N->StructKeyword);
for (auto Field: N->Fields) {
BOLT_VISIT(Field);
}
}
void visitEachChild(ClassDeclaration* N) {
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
BOLT_VISIT(N->ClassKeyword);
BOLT_VISIT(N->Name);
for (auto Name: N->TypeVars) {
BOLT_VISIT(Name);
}
BOLT_VISIT(N->BlockStart);
for (auto Element: N->Elements) {
BOLT_VISIT(Element);
}
}
void visitEachChild(InstanceDeclaration* N) {
BOLT_VISIT(N->InstanceKeyword);
BOLT_VISIT(N->Name);
for (auto TE: N->TypeExps) {
BOLT_VISIT(TE);
}
BOLT_VISIT(N->BlockStart);
for (auto Element: N->Elements) {
BOLT_VISIT(Element);
}
}
void visitEachChild(SourceFile* N) {
for (auto Element: N->Elements) {
BOLT_VISIT(Element);
}
}
};
}

View file

@ -5,6 +5,7 @@
#include "bolt/ByteString.hpp"
#include "bolt/CST.hpp"
#include "bolt/Diagnostics.hpp"
#include <istream>
#include <unordered_map>
@ -14,6 +15,30 @@
namespace bolt {
class LanguageConfig {
enum ConfigFlags {
ConfigFlags_TypeVarsRequireForall = 1 << 0,
};
unsigned Flags;
public:
void setTypeVarsRequireForall(bool Enable) {
if (Enable) {
Flags |= ConfigFlags_TypeVarsRequireForall;
} else {
Flags |= ~ConfigFlags_TypeVarsRequireForall;
}
}
bool typeVarsRequireForall() const noexcept {
return Flags & ConfigFlags_TypeVarsRequireForall;
}
};
class DiagnosticEngine;
class Node;
@ -23,11 +48,12 @@ namespace bolt {
using TVSub = std::unordered_map<TVar*, Type*>;
using TVSet = std::unordered_set<TVar*>;
using TypeclassContext = std::unordered_set<TypeclassId>;
enum class TypeKind : unsigned char {
Var,
Con,
Arrow,
Any,
Tuple,
};
@ -70,15 +96,45 @@ namespace bolt {
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Con;
}
};
enum class VarKind {
Rigid,
Unification,
};
class TVar : public Type {
public:
const size_t Id;
VarKind VK;
inline TVar(size_t Id):
Type(TypeKind::Var), Id(Id) {}
TypeclassContext Contexts;
inline TVar(size_t Id, VarKind VK):
Type(TypeKind::Var), Id(Id), VK(VK) {}
inline VarKind getVarKind() const noexcept {
return VK;
}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Var;
}
};
class TVarRigid : public TVar {
public:
ByteString Name;
inline TVarRigid(size_t Id, ByteString Name):
TVar(Id, VarKind::Rigid), Name(Name) {}
};
@ -95,6 +151,10 @@ namespace bolt {
ParamTypes(ParamTypes),
ReturnType(ReturnType) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Arrow;
}
};
class TTuple : public Type {
@ -105,13 +165,9 @@ namespace bolt {
inline TTuple(std::vector<Type*> ElementTypes):
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
};
class TAny : public Type {
public:
inline TAny():
Type(TypeKind::Any) {}
static bool classof(const Type* Ty) {
return Ty->getKind() == TypeKind::Tuple;
}
};
@ -126,26 +182,6 @@ namespace bolt {
using ConstraintSet = std::vector<Constraint*>;
class Forall {
public:
TVSet* TVs;
ConstraintSet* Constraints;
Type* Type;
inline Forall(class Type* Type):
TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {}
inline Forall(
TVSet& TVs,
ConstraintSet& Constraints,
class Type* Type
): TVs(&TVs),
Constraints(&Constraints),
Type(Type) {}
};
enum class SchemeKind : unsigned char {
Forall,
};
@ -154,61 +190,102 @@ namespace bolt {
const SchemeKind Kind;
union {
Forall F;
};
protected:
inline Scheme(SchemeKind Kind):
Kind(Kind) {}
public:
inline Scheme(Forall F):
Kind(SchemeKind::Forall), F(F) {}
inline Scheme(const Scheme& Other):
Kind(Other.Kind) {
switch (Kind) {
case SchemeKind::Forall:
F = Other.F;
break;
}
}
inline Scheme(Scheme&& Other):
Kind(std::move(Other.Kind)) {
switch (Kind) {
case SchemeKind::Forall:
F = std::move(Other.F);
break;
}
}
template<typename T>
T& as();
template<>
Forall& as<Forall>() {
ZEN_ASSERT(Kind == SchemeKind::Forall);
return F;
}
inline SchemeKind getKind() const noexcept {
return Kind;
}
~Scheme() {
switch (Kind) {
case SchemeKind::Forall:
F.~Forall();
break;
}
virtual ~Scheme() {}
};
class Forall : public Scheme {
public:
TVSet* TVs;
ConstraintSet* Constraints;
class Type* Type;
inline Forall(class Type* Type):
Scheme(SchemeKind::Forall), TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {}
inline Forall(
TVSet* TVs,
ConstraintSet* Constraints,
class Type* Type
): Scheme(SchemeKind::Forall),
TVs(TVs),
Constraints(Constraints),
Type(Type) {}
static bool classof(const Scheme* Scm) {
return Scm->getKind() == SchemeKind::Forall;
}
};
using TypeEnv = std::unordered_map<ByteString, Scheme>;
/* class Scheme { */
/* const SchemeKind Kind; */
/* public: */
/* inline Scheme(Forall F): */
/* Kind(SchemeKind::Forall), F(F) {} */
/* inline Scheme(const Scheme& Other): */
/* Kind(Other.Kind) { */
/* switch (Kind) { */
/* case SchemeKind::Forall: */
/* F = Other.F; */
/* break; */
/* } */
/* } */
/* inline Scheme(Scheme&& Other): */
/* Kind(std::move(Other.Kind)) { */
/* switch (Kind) { */
/* case SchemeKind::Forall: */
/* F = std::move(Other.F); */
/* break; */
/* } */
/* } */
/* inline SchemeKind getKind() const noexcept { */
/* return Kind; */
/* } */
/* template<typename T> */
/* T& as(); */
/* template<> */
/* Forall& as<Forall>() { */
/* ZEN_ASSERT(Kind == SchemeKind::Forall); */
/* return F; */
/* } */
/* ~Scheme() { */
/* switch (Kind) { */
/* case SchemeKind::Forall: */
/* F.~Forall(); */
/* break; */
/* } */
/* } */
/* }; */
using TypeEnv = std::unordered_map<ByteString, Scheme*>;
enum class ConstraintKind {
Equal,
Class,
Many,
Empty,
};
@ -249,8 +326,8 @@ namespace bolt {
ConstraintSet& Elements;
inline CMany(ConstraintSet& Constraints):
Constraint(ConstraintKind::Many), Elements(Constraints) {}
inline CMany(ConstraintSet& Elements):
Constraint(ConstraintKind::Many), Elements(Elements) {}
};
@ -262,32 +339,76 @@ namespace bolt {
};
class InferContext {
class CClass : public Constraint {
public:
TVSet TVs;
ConstraintSet Constraints;
TypeEnv Env;
Type* ReturnType;
ByteString Name;
std::vector<Type*> Types;
InferContext* Parent;
inline CClass(ByteString Name, std::vector<Type*> Types):
Constraint(ConstraintKind::Class), Name(Name), Types(Types) {}
};
enum {
/**
* Indicates that the typing environment of the current context will not
* hold on to any bindings.
*
* Concretely, bindings that are assigned fall through to the parent
* context, where this process is repeated until an environment is found
* that is not pervious.
*/
InferContextFlags_PerviousEnv = 1 << 0,
};
using InferContextFlagsMask = unsigned;
class InferContext {
InferContextFlagsMask Flags = 0;
public:
/**
* A heap-allocated list of type variables that eventually will become part of a Forall scheme.
*/
TVSet* TVs;
/**
* A heap-allocated list of constraints that eventually will become part of a Forall scheme.
*/
ConstraintSet* Constraints;
TypeEnv Env;
Type* ReturnType = nullptr;
std::vector<TypeclassSignature> Classes;
inline void setIsEnvPervious(bool Enable) noexcept {
if (Enable) {
Flags |= InferContextFlags_PerviousEnv;
} else {
Flags &= ~InferContextFlags_PerviousEnv;
}
}
inline bool isEnvPervious() const noexcept {
return Flags & InferContextFlags_PerviousEnv;
}
//inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType):
// Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {}
inline InferContext(InferContext* Parent = nullptr):
Parent(Parent), ReturnType(nullptr) {}
};
class Checker {
const LanguageConfig& Config;
DiagnosticEngine& DE;
size_t nextConTypeId = 0;
size_t nextTypeVarId = 0;
std::unordered_map<Node*, Type*> Mapping;
size_t NextConTypeId = 0;
size_t NextTypeVarId = 0;
std::unordered_map<Node*, InferContext*> CallGraph;
@ -295,44 +416,83 @@ namespace bolt {
Type* IntType;
Type* StringType;
TVSub Solution;
std::vector<InferContext*> Contexts;
/**
* Holds the current inferred type class contexts in a given LetDeclaration body.
*/
// std::vector<TypeclassContext*> TCCs;
InferContext& getContext();
void addConstraint(Constraint* Constraint);
void addClass(TypeclassSignature Sig);
void forwardDeclare(Node* Node);
Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE);
void inferBindings(Pattern* Pattern, Type* T, ConstraintSet& Constraints, TVSet& Tvs);
void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs);
void inferBindings(Pattern* Pattern, Type* T);
void infer(Node* node);
Constraint* convertToConstraint(ConstraintExpression* C);
TCon* createPrimConType();
TVar* createTypeVar();
TVarRigid* createRigidVar(ByteString Name);
InferContext* createInferContext();
void addBinding(ByteString Name, Scheme Scm);
void addBinding(ByteString Name, Scheme* Scm);
Scheme* lookup(ByteString Name);
/**
* Looks up a type/variable and ensures that it is a monomorphic type.
*
* This method is mainly syntactic sugar to make it clear in the code when a
* monomorphic type is expected.
*
* Note that if the type is not monomorphic the program will abort with a
* stack trace. It wil **not** print a user-friendly error message.
*
* \returns If the type/variable could not be found `nullptr` is returned.
* Otherwise, a [Type] is returned.
*/
Type* lookupMono(ByteString Name);
InferContext* lookupCall(Node* Source, SymbolPath Path);
/**
* Get the return type for the current context. If none could be found, the program will abort.
*/
Type* getReturnType();
Scheme* lookup(ByteString Name);
Type* instantiate(Scheme* S, Node* Source);
Type* instantiate(Scheme& S, Node* Source);
/* void addToTypeclassContexts(Node* N, std::vector<TypeclassContext>& Contexts); */
bool unify(Type* A, Type* B, TVSub& Solution);
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
std::vector<TypeclassContext> findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source);
void propagateClasses(TypeclassContext& Classes, Type* Ty, Node* Source);
void propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source);
void checkTypeclassSigs(Node* N);
bool unify(Type* A, Type* B, Node* Source);
void solveCEqual(CEqual* C);
void solve(Constraint* Constraint, TVSub& Solution);
public:
Checker(DiagnosticEngine& DE);
Checker(const LanguageConfig& Config, DiagnosticEngine& DE);
TVSub check(SourceFile* SF);
void check(SourceFile* SF);
inline Type* getBoolType() {
return BoolType;
@ -346,7 +506,7 @@ namespace bolt {
return IntType;
}
Type* getType(Node* Node, const TVSub& Solution);
Type* getType(TypedNode* Node);
};

View file

@ -13,12 +13,30 @@
namespace bolt {
class Type;
class TCon;
class TVar;
using TypeclassId = ByteString;
struct TypeclassSignature {
using TypeclassId = ByteString;
TypeclassId Id;
std::vector<TVar*> Params;
bool operator<(const TypeclassSignature& Other) const;
bool operator==(const TypeclassSignature& Other) const;
};
enum class DiagnosticKind : unsigned char {
UnexpectedToken,
UnexpectedString,
BindingNotFound,
UnificationError,
TypeclassMissing,
InstanceNotFound,
ClassNotFound,
};
class Diagnostic : std::runtime_error {
@ -31,7 +49,7 @@ namespace bolt {
public:
DiagnosticKind getKind() const noexcept {
inline DiagnosticKind getKind() const noexcept {
return Kind;
}
@ -42,9 +60,9 @@ namespace bolt {
TextFile& File;
Token* Actual;
std::vector<NodeType> Expected;
std::vector<NodeKind> Expected;
inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector<NodeType> Expected):
inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector<NodeKind> Expected):
Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {}
};
@ -84,6 +102,39 @@ namespace bolt {
};
class TypeclassMissingDiagnostic : public Diagnostic {
public:
TypeclassSignature Sig;
LetDeclaration* Decl;
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, LetDeclaration* Decl):
Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {}
};
class InstanceNotFoundDiagnostic : public Diagnostic {
public:
ByteString TypeclassName;
TCon* Ty;
Node* Source;
inline InstanceNotFoundDiagnostic(ByteString TypeclassName, TCon* Ty, Node* Source):
Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {}
};
class ClassNotFoundDiagnostic : public Diagnostic {
public:
ByteString Name;
inline ClassNotFoundDiagnostic(ByteString Name):
Diagnostic(DiagnosticKind::ClassNotFound), Name(Name) {}
};
class DiagnosticEngine {
protected:

View file

@ -2,9 +2,10 @@
#pragma once
#include <unordered_map>
#include <optional>
#include <optional>
#include "bolt/CST.hpp"
#include "bolt/Stream.hpp"
namespace bolt {
@ -68,14 +69,24 @@ namespace bolt {
Token* peekFirstTokenAfterModifiers();
Token* expectToken(NodeType Ty);
Token* expectToken(NodeKind Ty);
template<typename T>
T* expectToken() {
return static_cast<T*>(expectToken(getNodeType<T>()));
}
Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence);
TypeExpression* parsePrimitiveTypeExpression();
Expression* parsePrimitiveExpression();
ConstraintExpression* parseConstraintExpression();
TypeExpression* parsePrimitiveTypeExpression();
TypeExpression* parseQualifiedTypeExpression();
TypeExpression* parseArrowTypeExpression();
VarTypeExpression* parseVarTypeExpression();
public:
Parser(TextFile& File, Stream<Token*>& S);
@ -86,7 +97,7 @@ namespace bolt {
Pattern* parsePattern();
Param* parseParam();
Parameter* parseParam();
ReferenceExpression* parseReferenceExpression();
@ -106,6 +117,12 @@ namespace bolt {
LetDeclaration* parseLetDeclaration();
Node* parseClassElement();
ClassDeclaration* parseClassDeclaration();
InstanceDeclaration* parseInstanceDeclaration();
Node* parseSourceElement();
SourceFile* parseSourceFile();

View file

@ -8,78 +8,12 @@
#include "bolt/Text.hpp"
#include "bolt/String.hpp"
#include "bolt/Stream.hpp"
namespace bolt {
class Token;
template<typename T>
class Stream {
public:
virtual T get() = 0;
virtual T peek(std::size_t Offset = 0) = 0;
virtual ~Stream() {}
};
template<typename ContainerT, typename T = typename ContainerT::value_type>
class VectorStream : public Stream<T> {
public:
using value_type = T;
ContainerT& Data;
value_type Sentry;
std::size_t Offset;
VectorStream(ContainerT& Data, value_type Sentry, std::size_t Offset = 0):
Data(Data), Sentry(Sentry), Offset(Offset) {}
value_type get() override {
return Offset < Data.size() ? Data[Offset++] : Sentry;
}
value_type peek(std::size_t Offset2) override {
auto I = Offset + Offset2;
return I < Data.size() ? Data[I] : Sentry;
}
};
template<typename T>
class BufferedStream : public Stream<T> {
std::deque<T> Buffer;
protected:
virtual T read() = 0;
public:
using value_type = T;
value_type get() override {
if (Buffer.empty()) {
return read();
} else {
auto Keep = Buffer.front();
Buffer.pop_front();
return Keep;
}
}
value_type peek(std::size_t Offset = 0) override {
while (Buffer.size() <= Offset) {
Buffer.push_back(read());
}
return Buffer[Offset];
}
};
class Scanner : public BufferedStream<Token*> {
TextFile& File;

View file

@ -36,8 +36,7 @@ namespace bolt {
};
class TextRange {
public:
struct TextRange {
TextLoc Start;
TextLoc End;
};

View file

@ -1,66 +0,0 @@
{!
root_node_name = root_node.name
def gen_cpp_unref(expr, ty):
if isinstance(ty, NodeType):
return f'{expr}->unref();\n'
elif isinstance(ty, ListType):
dtor = gen_cpp_unref('Element', ty.element_type)
if dtor:
out = ''
out += f'for (auto& Element: {expr})'
out += '{\n'
out += dtor
out += '}\n'
return out
elif isinstance(ty, OptionalType):
if is_type_optional_by_default(ty.element_type):
element_expr = expr
else:
element_expr = f'(*{expr})'
dtor = gen_cpp_unref(element_expr, ty.element_type)
if dtor:
out = ''
out += 'if ('
out += expr
out += ') {\n'
out += dtor
out += '}\n'
return out
elif isinstance(ty, RawType):
pass # field should be destroyed by class
else:
raise RuntimeError(f'unexpected {ty}')
!}
#include "{{include_path}}/{{name}}.hpp"
{% for namespace in namespaces %}
namespace {{namespace}} { {! indent() !}
{% endfor %}
{{root_node_name}}:~{{root_node_name}}() {}
SourceFile* {root_node.name}::getSourceFile() {
auto CurrNode = this;
for (;;) {
if (CurrNode->Type == NodeType::SourceFile) {
return static_cast<SourceFile*>(this);
}
CurrNode = CurrNode->Parent;
ZEN_ASSERT(CurrNode != nullptr);
}
}
{% for node in nodes %}
{{node.name}}::~{{node.name}}() {
{% for name, ty in node.fields %}
{{gen_cpp_unref(name, ty)}
{% endfor %}
}
{% endfor %}
{% for namespace in namespaces %}
} {! dedent() !}
{% endfor %}

View file

@ -1,118 +0,0 @@
{!
macro_prefix = '_'.join(namespaces).upper() + '_'
variant_name = root_node_name + 'Type'
!}
#pragma once
{% for namespace in namespaces %}
namespace {{namespace}} { {! indent() !}
{% endfor %}
class {{base_node.name}};
class {{root_node_name}} {
unsigned RefCount = 0;
{{root_node_name}}* Parent = nullptr;
public:
inline void ref() {
++RefCount;
}
inline void unref() {
--RefCount;
if (RefCount == 0) {
delete this;
}
}
const {{variant_name}} Type;
inline {{root_node_name}}({{variant_name}}Type):
Type(Type) {}
{{base_node.name}}* get{{base_node.name}}();
virtual void setParents();
virtual ~Node();
};
{% for node in nodes %}
{!
def gen_cpp_ctor_params(out, node):
visited = set()
queue = deque([ node ])
is_leaf = not graph.has_children(node.name)
first = True
if not is_leaf:
out.write(f"{cpp_root_node_name}Type Type")
first = False
while queue:
node = queue.popleft()
if node.name in visited:
return
visited.add(node.name)
for member in node.members:
if first:
first = False
else:
out.write(', ')
out.write(gen_cpp_type_expr(member.type_expr.type))
out.write(' ')
out.write(camel_case(member.name))
for parent in node.parents:
queue.append(types[parent])
def gen_cpp_ctor_args(out, orig_node: NodeDecl):
first = True
is_leaf = not graph.has_children(orig_node.name)
if orig_node.parents:
for parent in orig_node.parents:
if first:
first = False
else:
out.write(', ')
node = types[parent]
refs = ''
if is_leaf:
refs += f"{cpp_root_node_name}Type::{orig_node.name}"
else:
refs += 'Type'
for member in node.members:
refs += f", {camel_case(member.name)}"
out.write(f"{prefix}{node.name}({refs})")
else:
if is_leaf:
out.write(f"{cpp_root_node_name}({cpp_root_node_name}Type::{orig_node.name})")
else:
out.write(f"{cpp_root_node_name}(Type)")
first = False
for member in orig_node.members:
if first:
first = False
else:
out.write(', ')
out.write(f"{camel_case(member.name)}({camel_case(member.name)})")
!}
class {{node.name}} : public {{node.parent.name}} {
{{node.name}}(
{{cpp_ctor_params}}
): {{node.parent.name}}({{variant_name}}::{{node.name}}{{cpp_ctor_args}} {}
~{{node.name}}();
};
{% endfor %}
{% for namespace in namespaces %}
} {! dedent() !}
{% endfor %}

View file

@ -1,848 +0,0 @@
#!/usr/bin/env python3
from os import wait
import re
from collections import deque
from pathlib import Path
import argparse
from typing import List, Optional
from sweetener.record import Record
import templaty
here = Path(__file__).parent.resolve()
EOF = '\uFFFF'
END_OF_FILE = 0
IDENTIFIER = 1
SEMI = 2
EXTERNAL = 3
NODE = 4
LBRACE = 5
RBRACE = 6
LESSTHAN = 7
GREATERTHAN = 8
COLON = 9
LPAREN = 10
RPAREN = 11
VBAR = 12
COMMA = 13
HASH = 14
STRING = 15
RE_WHTITESPACE = re.compile(r"[\n\r\t ]")
RE_IDENT_START = re.compile(r"[a-zA-Z_]")
RE_IDENT_PART = re.compile(r"[a-zA-Z_0-9]")
KEYWORDS = {
'external': EXTERNAL,
'node': NODE,
}
def escape_char(ch):
code = ord(ch)
if code >= 32 and code < 126:
return ch
if code <= 127:
return f"\\x{code:02X}"
return f"\\u{code:04X}"
def camel_case(ident: str) -> str:
out = ident[0].upper()
i = 1
while i < len(ident):
ch = ident[i]
i += 1
if ch == '_':
c1 = ident[i]
i += 1
out += c1.upper()
else:
out += ch
return out
class ScanError(RuntimeError):
def __init__(self, file, position, actual):
super().__init__(f"{file.name}:{position.line}:{position.column}: unexpected character '{escape_char(actual)}'")
self.file = file
self.position = position
self.actual = actual
TOKEN_TYPE_TO_STRING = {
LPAREN: '(',
RPAREN: ')',
LBRACE: '{',
RBRACE: '}',
LESSTHAN: '<',
GREATERTHAN: '>',
NODE: 'node',
EXTERNAL: 'external',
SEMI: ';',
COLON: ':',
COMMA: ',',
VBAR: '|',
HASH: '#',
}
class Token:
def __init__(self, type, position=None, value=None):
self.type = type
self.start_pos = position
self.value = value
@property
def text(self):
if self.type in TOKEN_TYPE_TO_STRING:
return TOKEN_TYPE_TO_STRING[self.type]
if self.type == IDENTIFIER:
return self.value
if self.type == STRING:
return f'"{self.value}"'
if self.type == END_OF_FILE:
return ''
return '(unknown token)'
class TextFile:
def __init__(self, filename, text=None):
self.name = filename
self._cached_text = text
@property
def text(self):
if self._cached_text is None:
with open(self.name, 'r') as f:
self._cached_text = f.read()
return self._cached_text
class TextPos:
def __init__(self, line=1, column=1):
self.line = line
self.column = column
def clone(self):
return TextPos(self.line, self.column)
def advance(self, text):
for ch in text:
if ch == '\n':
self.line += 1
self.column = 1
else:
self.column += 1
class Scanner:
def __init__(self, text, text_offset=0, filename=None):
self._text = text
self._text_offset = text_offset
self.file = TextFile(filename, text)
self._curr_pos = TextPos()
def _peek_char(self, offset=1):
i = self._text_offset + offset - 1
return self._text[i] if i < len(self._text) else EOF
def _get_char(self):
if self._text_offset == len(self._text):
return EOF
i = self._text_offset
self._text_offset += 1
ch = self._text[i]
self._curr_pos.advance(ch)
return ch
def _take_while(self, pred):
out = ''
while True:
ch = self._peek_char()
if not pred(ch):
break
self._get_char()
out += ch
return out
def scan(self):
while True:
c0 = self._peek_char()
c1 = self._peek_char(2)
if c0 == '/' and c1 == '/':
self._get_char()
self._get_char()
while True:
c3 = self._get_char()
if c3 == '\n' or c3 == EOF:
break
continue
if RE_WHTITESPACE.match(c0):
self._get_char()
continue
break
if c0 == EOF:
return Token(END_OF_FILE, self._curr_pos.clone())
start_pos = self._curr_pos.clone()
self._get_char()
if c0 == ';': return Token(SEMI, start_pos)
if c0 == '{': return Token(LBRACE, start_pos)
if c0 == '}': return Token(RBRACE, start_pos)
if c0 == '(': return Token(LPAREN, start_pos)
if c0 == ')': return Token(RPAREN, start_pos)
if c0 == '<': return Token(LESSTHAN, start_pos)
if c0 == '>': return Token(GREATERTHAN, start_pos)
if c0 == ':': return Token(COLON, start_pos)
if c0 == '|': return Token(VBAR, start_pos)
if c0 == ',': return Token(COMMA, start_pos)
if c0 == '#': return Token(HASH, start_pos)
if c0 == '"':
text = ''
while True:
c1 = self._get_char()
if c1 == '"':
break
text += c1
return Token(STRING, start_pos, text)
if RE_IDENT_START.match(c0):
name = c0 + self._take_while(lambda ch: RE_IDENT_PART.match(ch))
return Token(KEYWORDS[name], start_pos) \
if name in KEYWORDS \
else Token(IDENTIFIER, start_pos, name)
raise ScanError(self.file, start_pos, c0)
class Type(Record):
pass
class ListType(Type):
element_type: Type
class OptionalType(Type):
element_type: Type
class NodeType(Type):
name: str
class VariantType(Type):
types: List[Type]
class RawType(Type):
text: str
class AST(Record):
pass
class Directive(AST):
pass
INCLUDEMODE_LOCAL = 0
INCLUDEMODE_SYSTEM = 1
class IncludeDiretive(Directive):
path: str
mode: int
def __str__(self):
if self.mode == INCLUDEMODE_LOCAL:
return f"#include \"{self.path}\"\n"
if self.mode == INCLUDEMODE_SYSTEM:
return f"#include <{self.path}>\n"
class TypeExpr(AST):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.type = None
class RefTypeExpr(TypeExpr):
name: str
args: List[TypeExpr]
class UnionTypeExpr(TypeExpr):
types: List[TypeExpr]
class External(AST):
name: str
class NodeDeclField(AST):
name: str
type_expr: TypeExpr
class NodeDecl(AST):
name: str
parents: List[str]
members: List[NodeDeclField]
def pretty_token(token):
if token.type == END_OF_FILE:
return 'end-of-file'
return f"'{token.text}'"
def pretty_token_type(token_type):
if token_type in TOKEN_TYPE_TO_STRING:
return f"'{TOKEN_TYPE_TO_STRING[token_type]}'"
if token_type == IDENTIFIER:
return 'an identfier'
if token_type == STRING:
return 'a string literal'
if token_type == END_OF_FILE:
return 'end-of-file'
return f"(unknown token type {token_type})"
def pretty_alternatives(elements):
try:
out = next(elements)
except StopIteration:
return 'nothing'
try:
prev_element = next(elements)
except StopIteration:
return out
while True:
try:
element = next(elements)
except StopIteration:
break
out += ', ' + prev_element
prev_element = element
return out + ' or ' + prev_element
class ParseError(RuntimeError):
def __init__(self, file, actual, expected):
super().__init__(f"{file.name}:{actual.start_pos.line}:{actual.start_pos.column}: got {pretty_token(actual)} but expected {pretty_alternatives(pretty_token_type(tt) for tt in expected)}")
self.actual = actual
self.expected = expected
class Parser:
def __init__(self, scanner):
self._scanner = scanner
self._token_buffer = deque()
def _peek_token(self, offset=1):
while len(self._token_buffer) < offset:
self._token_buffer.append(self._scanner.scan())
return self._token_buffer[offset-1]
def _get_token(self):
if self._token_buffer:
return self._token_buffer.popleft()
return self._scanner.scan()
def _expect_token(self, expected_token_type):
t0 = self._get_token()
if t0.type != expected_token_type:
raise ParseError(self._scanner.file, t0, [ expected_token_type ])
return t0
def _parse_prim_type_expr(self):
t0 = self._get_token()
if t0.type == LPAREN:
result = self.parse_type_expr()
self._expect_token(RPAREN)
return result
if t0.type == IDENTIFIER:
t1 = self._peek_token()
args = []
if t1.type == LESSTHAN:
self._get_token()
while True:
t2 = self._peek_token()
if t2.type == GREATERTHAN:
self._get_token()
break
args.append(self.parse_type_expr())
t3 = self._get_token()
if t3.type == GREATERTHAN:
break
if t3.type != COMMA:
raise ParseError(self._scanner.file, t3, [ COMMA, GREATERTHAN ])
return RefTypeExpr(t0.value, args)
raise ParseError(self._scanner.file, t0, [ LPAREN, IDENTIFIER ])
def parse_type_expr(self):
return self._parse_prim_type_expr()
def parse_member(self):
type_expr = self.parse_type_expr()
name = self._expect_token(IDENTIFIER)
self._expect_token(SEMI)
return NodeDeclField(name.value, type_expr)
def parse_toplevel(self):
t0 = self._get_token()
if t0.type == EXTERNAL:
name = self._expect_token(IDENTIFIER)
self._expect_token(SEMI)
return External(name.value)
if t0.type == NODE:
name = self._expect_token(IDENTIFIER).value
parents = []
t1 = self._peek_token()
if t1.type == COLON:
self._get_token()
while True:
parent = self._expect_token(IDENTIFIER).value
parents.append(parent)
t2 = self._peek_token()
if t2.type == COMMA:
self._get_token()
continue
if t2.type == LBRACE:
break
raise ParseError(self._scanner.file, t2, [ COMMA, LBRACE ])
self._expect_token(LBRACE)
members = []
while True:
t2 = self._peek_token()
if t2.type == RBRACE:
self._get_token()
break
member = self.parse_member()
members.append(member)
return NodeDecl(name, parents, members)
if t0.type == HASH:
name = self._expect_token(IDENTIFIER)
if name.value == 'include':
t1 = self._get_token()
if t1.type == LESSTHAN:
assert(not self._token_buffer)
path = self._scanner._take_while(lambda ch: ch != '>')
self._scanner._get_char()
mode = INCLUDEMODE_SYSTEM
elif t1.type == STRING:
mode = INCLUDEMODE_LOCAL
path = t1.value
else:
raise ParseError(self._scanner.file, t1, [ STRING, LESSTHAN ])
return IncludeDiretive(path, mode)
raise RuntimeError(f"invalid preprocessor directive '{name.value}'")
raise ParseError(self._scanner.file, t0, [ EXTERNAL, NODE, HASH ])
def parse_grammar(self):
elements = []
while True:
t0 = self._peek_token()
if t0.type == END_OF_FILE:
break
element = self.parse_toplevel()
elements.append(element)
return elements
class Writer:
def __init__(self, text='', path=None):
self.path = path
self.text = text
self._at_blank_line = True
self._indentation = ' '
self._indent_level = 0
def indent(self, count=1):
self._indent_level += count
def dedent(self, count=1):
self._indent_level -= count
def write(self, chunk):
for ch in chunk:
if ch == '}':
self.dedent()
if ch == '\n':
self._at_blank_line = True
elif self._at_blank_line and not RE_WHTITESPACE.match(ch):
self.text += self._indentation * self._indent_level
self._at_blank_line = False
self.text += ch
if ch == '{':
self.indent()
def save(self, dest_dir):
dest_path = dest_dir / self.path
print(f'Writing file {dest_path} ...')
with open(dest_path, 'w') as f:
f.write(self.text)
class DiGraph:
def __init__(self):
self._out_edges = dict()
self._in_edges = dict()
def add_edge(self, a, b):
if a not in self._out_edges:
self._out_edges[a] = set()
self._out_edges[a].add(b)
if b not in self._in_edges:
self._in_edges[b] = set()
self._in_edges[b].add(a)
def get_children(self, node):
if node not in self._out_edges:
return
for child in self._out_edges[node]:
yield child
def has_children(self, node):
return node in self._out_edges
def is_child_of(self, a, b):
stack = [ b ]
visited = set()
while stack:
node = stack.pop()
if node in visited:
break
visited.add(node)
if node == a:
return True
for child in self.get_children(node):
stack.append(child)
return False
def get_ancestors(self, node):
if node not in self._in_edges:
return
for parent in self._in_edges[node]:
yield parent
def get_common_ancestor(self, nodes):
out = nodes[0]
parents = []
for node in nodes[1:]:
if not self.is_child_of(node, out):
for parent in self.get_ancestors(node):
parents.append(parent)
if not parents:
return out
parents.append(out)
return self.get_common_ancestor(parents)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('file', nargs=1, help='The specification file to generate C++ code for')
parser.add_argument('--namespace', default='', help='What C++ namespace to put generated code under')
parser.add_argument('--name', default='AST', help='How to name the generated tree')
parser.add_argument('-I', default='.', help='What path will be used to include generated header files')
parser.add_argument('--include-root', default='.', help='Where the headers live inside the include directroy')
parser.add_argument('--enable-serde', action='store_true', help='Also write (de)serialization logic')
parser.add_argument('--source-root', default='.', help='Where to store generated souce files')
parser.add_argument('--node-name', default='Node', help='How the root node of the hierachy should be called')
parser.add_argument('--node-prefix', default='', help='String to prepend to the names of node types')
parser.add_argument('--out-dir', default='.', help='Place the endire folder structure inside this folder')
parser.add_argument('--dry-run', action='store_true', help='Do not write generated code to the file system')
args = parser.parse_args()
filename = args.file[0]
prefix = args.node_prefix
cpp_root_node_name = prefix + args.node_name
include_dir = Path(args.I)
include_path = Path(args.include_root or '.')
full_include_path = include_dir / include_path
source_path = Path(args.source_root)
namespace = args.namespace.split('::')
out_dir = Path(args.out_dir)
out_name = args.name
write_serde = args.enable_serde
with open(filename, 'r') as f:
text = f.read()
scanner = Scanner(text, filename=filename)
parser = Parser(scanner)
elements = parser.parse_grammar()
types = dict()
nodes = list()
leaf_nodes = list()
graph = DiGraph()
parent_to_children = dict()
for element in elements:
if isinstance(element, External) \
or isinstance(element, NodeDecl):
types[element.name] = element
if isinstance(element, NodeDecl):
nodes.append(element)
for parent in element.parents:
graph.add_edge(parent, element.name)
if parent not in parent_to_children:
parent_to_children[parent] = set()
children = parent_to_children[parent]
children.add(element)
for node in nodes:
if node.name not in parent_to_children:
leaf_nodes.append(node)
def is_null_type_expr(type_expr):
return isinstance(type_expr, RefTypeExpr) and type_expr.name == 'null'
def is_node(name):
if name in types:
return isinstance(types[name], NodeDecl)
if name in parent_to_children:
return True
return False
def get_all_variant_elements(type_expr):
types = list()
def loop(ty):
if isinstance(ty, RefTypeExpr) and ty.name == 'Variant':
for arg in ty.args:
loop(arg)
else:
types.append(ty)
loop(type_expr)
return types
def infer_type(type_expr):
if isinstance(type_expr, RefTypeExpr):
if type_expr.name == 'Option':
assert(len(type_expr.args) == 1)
return OptionalType(infer_type(type_expr.args[0]))
if type_expr.name == 'List':
assert(len(type_expr.args) == 1)
return ListType(infer_type(type_expr.args[0]))
if type_expr.name == 'Variant':
types = get_all_variant_elements(type_expr)
has_null = False
if any(is_null_type_expr(ty) for ty in types):
has_null = True
types = list(ty for ty in types if not is_null_type_expr(ty))
if all(isinstance(ty, RefTypeExpr) and is_node(ty.name) for ty in types):
node_name = graph.get_common_ancestor(list(t.name for t in types))
return NodeType(node_name)
if len(types) == 1:
out = infer_type(types[0])
else:
out = VariantType(infer_type(ty) for ty in types)
return OptionalType(out) if has_null else out
if is_node(type_expr.name):
assert(len(type_expr.args) == 0)
return NodeType(type_expr.name)
assert(len(type_expr.args) == 0)
return RawType(type_expr.name)
raise RuntimeError(f"unhandled type expression {type_expr}")
for node in nodes:
for member in node.members:
member.type_expr.type = infer_type(member.type_expr)
def is_type_optional_by_default(ty):
return isinstance(ty, NodeType)
def gen_cpp_type_expr(ty):
if isinstance(ty, NodeType):
return prefix + ty.name + "*"
if isinstance(ty, ListType):
return f"std::vector<{gen_cpp_type_expr(ty.element_type)}>"
if isinstance(ty, NodeType):
return ty.name + '*'
if isinstance(ty, OptionalType):
cpp_expr = gen_cpp_type_expr(ty.element_type)
if is_type_optional_by_default(ty.element_type):
return cpp_expr
return f"std::optional<{cpp_expr}>"
if isinstance(ty, VariantType):
return f"std::variant<{','.join(gen_cpp_type_expr(t) for t in ty.element_types)}>"
if isinstance(ty, RawType):
return ty.text
raise RuntimeError(f"unhandled Type {ty}")
def gen_cpp_dtor(expr, ty):
if isinstance(ty, NodeType):
return f'{expr}->unref();\n'
elif isinstance(ty, ListType):
dtor = gen_cpp_dtor('Element', ty.element_type)
if dtor:
out = ''
out += f'for (auto& Element: {expr})'
out += '{\n'
out += dtor
out += '}\n'
return out
elif isinstance(ty, OptionalType):
if is_type_optional_by_default(ty.element_type):
element_expr = expr
else:
element_expr = f'(*{expr})'
dtor = gen_cpp_dtor(element_expr, ty.element_type)
if dtor:
out = ''
out += 'if ('
out += expr
out += ') {\n'
out += dtor
out += '}\n'
return out
elif isinstance(ty, RawType):
pass # field should be destroyed by class
else:
raise RuntimeError(f'unexpected {ty}')
def gen_cpp_ctor_params(out, node):
visited = set()
queue = deque([ node ])
is_leaf = not graph.has_children(node.name)
first = True
if not is_leaf:
out.write(f"{cpp_root_node_name}Type Type")
first = False
while queue:
node = queue.popleft()
if node.name in visited:
return
visited.add(node.name)
for member in node.members:
if first:
first = False
else:
out.write(', ')
out.write(gen_cpp_type_expr(member.type_expr.type))
out.write(' ')
out.write(camel_case(member.name))
for parent in node.parents:
queue.append(types[parent])
def gen_cpp_ctor_args(out, orig_node: NodeDecl):
first = True
is_leaf = not graph.has_children(orig_node.name)
if orig_node.parents:
for parent in orig_node.parents:
if first:
first = False
else:
out.write(', ')
node = types[parent]
refs = ''
if is_leaf:
refs += f"{cpp_root_node_name}Type::{orig_node.name}"
else:
refs += 'Type'
for member in node.members:
refs += f", {camel_case(member.name)}"
out.write(f"{prefix}{node.name}({refs})")
else:
if is_leaf:
out.write(f"{cpp_root_node_name}({cpp_root_node_name}Type::{orig_node.name})")
else:
out.write(f"{cpp_root_node_name}(Type)")
first = False
for member in orig_node.members:
if first:
first = False
else:
out.write(', ')
out.write(f"{camel_case(member.name)}({camel_case(member.name)})")
node_hdr = templaty.execute(here / 'CST.hpp.tply', ctx={
'namespaces': namespace,
'nodes': nodes,
'root_node_name': args.node_name
})
node_hdr = Writer(path=full_include_path / (out_name + '.hpp'))
node_src = Writer(path=source_path / (out_name + '.cc'))
# Generating the header file
if write_serde:
node_hdr.write('void encode(Encoder& encoder) const;\n\n')
node_hdr.write('virtual void encode_fields(Encoder& encoder) const = 0;\n');
#node_hdr.write('virtual void decode_fields(Decoder& decoder) = 0;\n\n');
for element in elements:
if isinstance(element, NodeDecl):
node = element
is_leaf = not list(graph.get_children(node.name))
cpp_node_name = prefix + node.name
node_hdr.write("class ")
node_hdr.write(cpp_node_name)
node_hdr.write(" : ")
if node.parents:
node_hdr.write(', '.join('public ' + prefix + parent for parent in node.parents))
else:
node_hdr.write('public ' + cpp_root_node_name)
node_hdr.write(" {\n\n")
node_hdr.write('public:\n\n')
node_hdr.write(cpp_node_name + '(')
gen_cpp_ctor_params(node_hdr, node)
node_hdr.write('): ')
gen_cpp_ctor_args(node_hdr, node)
node_hdr.write(' {}\n\n')
if node.members:
for member in node.members:
node_hdr.write(gen_cpp_type_expr(member.type_expr.type))
node_hdr.write(" ");
node_hdr.write(camel_case(member.name))
node_hdr.write(";\n");
node_hdr.write('\n')
if write_serde and is_leaf:
node_hdr.write('void encode_fields(Encoder& encoder) const override;\n');
#node_hdr.write('void decode_fields(Decoder& decoder) override;\n\n');
# Generating the source file
node_src.write(f"""#include "{include_path / (out_name + '.hpp')}"\n\n""")
for name in namespace:
node_src.write(f"namespace {name} {{\n\n")
node_src.write(f"""{cpp_root_node_name}::~{cpp_root_node_name}() {{ }}\n\n""")
if write_serde:
node_src.write(f"""
void {cpp_root_node_name}::encode(Encoder& encoder) const {{
encoder.start_encode_struct("{cpp_root_node_name}");
encode_fields(encoder);
encoder.end_encode_struct();
}}
""")
for node in nodes:
is_leaf = not list(graph.get_children(node.name))
cpp_node_name = prefix + node.name
if write_serde and is_leaf:
node_src.write(f'void {cpp_node_name}::encode_fields(Encoder& encoder) const {{\n')
for member in node.members:
node_src.write(f'encoder.encode_field("{member.name}", {member.name});\n')
node_src.write('}\n\n')
node_src.write(f'{cpp_node_name}::~{cpp_node_name}() {{\n')
for member in node.members:
dtor = gen_cpp_dtor(camel_case(member.name), member.type_expr.type)
if dtor:
node_src.write(dtor)
node_src.write('}\n\n')
for _ in namespace:
node_src.write("}\n\n")
if args.dry_run:
print('# ' + str(node_hdr.path))
print(node_hdr.text)
print('# ' + str(node_src.path))
print(node_src.text)
else:
out_dir.mkdir(exist_ok=True, parents=True)
node_hdr.save(out_dir)
node_src.save(out_dir)
if __name__ == '__main__':
main()

View file

@ -2,6 +2,7 @@
#include "zen/config.hpp"
#include "bolt/CST.hpp"
#include "bolt/CSTVisitor.hpp"
namespace bolt {
@ -11,23 +12,34 @@ namespace bolt {
}
void Scope::scan(Node* X) {
switch (X->Type) {
case NodeType::ExpressionStatement:
case NodeType::ReturnStatement:
case NodeType::IfStatement:
switch (X->getKind()) {
case NodeKind::ExpressionStatement:
case NodeKind::ReturnStatement:
case NodeKind::IfStatement:
break;
case NodeType::SourceFile:
case NodeKind::SourceFile:
{
auto Y = static_cast<SourceFile*>(X);
for (auto Element: Y->Elements) {
auto File = static_cast<SourceFile*>(X);
for (auto Element: File->Elements) {
scan(Element);
}
break;
}
case NodeType::LetDeclaration:
case NodeKind::ClassDeclaration:
{
auto Y = static_cast<LetDeclaration*>(X);
addBindings(Y->Pattern, Y);
auto Decl = static_cast<ClassDeclaration*>(X);
for (auto Element: Decl->Elements) {
scan(Element);
}
break;
}
case NodeKind::InstanceDeclaration:
// FIXME is this right?
break;
case NodeKind::LetDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
addBindings(Decl->Pattern, Decl);
break;
}
default:
@ -36,8 +48,8 @@ namespace bolt {
}
void Scope::addBindings(Pattern* X, Node* ToInsert) {
switch (X->Type) {
case NodeType::BindPattern:
switch (X->getKind()) {
case NodeKind::BindPattern:
{
auto Y = static_cast<BindPattern*>(X);
Mapping.emplace(Y->Name->Text, ToInsert);
@ -49,6 +61,7 @@ namespace bolt {
}
Node* Scope::lookup(SymbolPath Path) {
ZEN_ASSERT(Path.Modules.empty());
auto Curr = this;
do {
auto Match = Curr->Mapping.find(Path.Name);
@ -70,7 +83,7 @@ namespace bolt {
SourceFile* Node::getSourceFile() {
auto CurrNode = this;
for (;;) {
if (CurrNode->Type == NodeType::SourceFile) {
if (CurrNode->Kind == NodeKind::SourceFile) {
return static_cast<SourceFile*>(CurrNode);
}
CurrNode = CurrNode->Parent;
@ -95,435 +108,49 @@ namespace bolt {
return EndLoc;
}
void Token::setParents() {
}
void Node::setParents() {
void QualifiedName::setParents() {
for (auto Name: ModulePath) {
Name->Parent = this;
}
Name->Parent = this;
}
struct SetParentsVisitor : public CSTVisitor<SetParentsVisitor> {
void ReferenceTypeExpression::setParents() {
Name->Parent = this;
Name->setParents();
}
void ArrowTypeExpression::setParents() {
for (auto ParamType: ParamTypes) {
ParamType->Parent = this;
ParamType->setParents();
}
ReturnType->Parent = this;
ReturnType->setParents();
}
std::vector<Node*> Parents { nullptr };
void BindPattern::setParents() {
Name->Parent = this;
}
void visit(Node* N) {
N->Parent = Parents.back();
Parents.push_back(N);
visitEachChild(N);
Parents.pop_back();
}
void ReferenceExpression::setParents() {
Name->Parent = this;
}
};
void NestedExpression::setParents() {
LParen->Parent = this;
Inner->Parent = this;
Inner->setParents();
RParen->Parent = this;
}
SetParentsVisitor V;
V.visit(this);
void ConstantExpression::setParents() {
Token->Parent = this;
}
void CallExpression::setParents() {
Function->Parent = this;
Function->setParents();
for (auto Arg: Args) {
Arg->Parent = this;
Arg->setParents();
}
}
void InfixExpression::setParents() {
LHS->Parent = this;
LHS->setParents();
Operator->Parent = this;
RHS->Parent = this;
RHS->setParents();
}
void UnaryExpression::setParents() {
Operator->Parent = this;
Argument->Parent = this;
Argument->setParents();
}
void ExpressionStatement::setParents() {
Expression->Parent = this;
Expression->setParents();
}
void ReturnStatement::setParents() {
ReturnKeyword->Parent = this;
Expression->Parent = this;
Expression->setParents();
}
void IfStatementPart::setParents() {
Keyword->Parent = this;
if (Test) {
Test->Parent = this;
Test->setParents();
}
BlockStart->Parent = this;
for (auto Element: Elements) {
Element->Parent = this;
Element->setParents();
}
}
void IfStatement::setParents() {
for (auto Part: Parts) {
Part->Parent = this;
Part->setParents();
}
}
void TypeAssert::setParents() {
Colon->Parent = this;
TypeExpression->Parent = this;
TypeExpression->setParents();
}
void LetBlockBody::setParents() {
BlockStart->Parent = this;
for (auto Element: Elements) {
Element->Parent = this;
Element->setParents();
}
}
void LetExprBody::setParents() {
Equals->Parent = this;
Expression->Parent = this;
Expression->setParents();
}
void Param::setParents() {
Pattern->Parent = this;
Pattern->setParents();
if (TypeAssert) {
TypeAssert->Parent = this;
TypeAssert->setParents();
}
}
void LetDeclaration::setParents() {
if (PubKeyword) {
PubKeyword->Parent = this;
}
LetKeyword->Parent = this;
if (MutKeyword) {
MutKeyword->Parent = this;
}
Pattern->Parent = this;
Pattern->setParents();
for (auto Param: Params) {
Param->Parent = this;
Param->setParents();
}
if (TypeAssert) {
TypeAssert->Parent = this;
TypeAssert->setParents();
}
if (Body) {
Body->Parent = this;
Body->setParents();
}
}
void StructDeclField::setParents() {
Name->Parent = this;
Colon->Parent = this;
TypeExpression->Parent = this;
TypeExpression->setParents();
}
void StructDecl::setParents() {
StructKeyword->Parent = this;
Name->Parent = this;
BlockStart->Parent = this;
for (auto Field: Fields) {
Field->Parent = this;
Field->setParents();
}
}
void SourceFile::setParents() {
for (auto Element: Elements) {
Element->Parent = this;
Element->setParents();
}
}
Node::~Node() {
struct UnrefVisitor : public CSTVisitor<UnrefVisitor> {
void visit(Node* N) {
N->unref();
visitEachChild(N);
}
};
UnrefVisitor V;
V.visitEachChild(this);
}
Token::~Token() {
}
Equals::~Equals() {
}
Colon::~Colon() {
}
RArrow::~RArrow() {
}
Dot::~Dot() {
}
DotDot::~DotDot() {
}
LParen::~LParen() {
}
RParen::~RParen() {
}
LBracket::~LBracket() {
}
RBracket::~RBracket() {
}
LBrace::~LBrace() {
}
RBrace::~RBrace() {
}
LetKeyword::~LetKeyword() {
}
MutKeyword::~MutKeyword() {
}
PubKeyword::~PubKeyword() {
}
TypeKeyword::~TypeKeyword() {
}
ReturnKeyword::~ReturnKeyword() {
}
IfKeyword::~IfKeyword() {
}
ElifKeyword::~ElifKeyword() {
}
ElseKeyword::~ElseKeyword() {
}
ModKeyword::~ModKeyword() {
}
StructKeyword::~StructKeyword() {
}
Invalid::~Invalid() {
}
EndOfFile::~EndOfFile() {
}
BlockStart::~BlockStart() {
}
BlockEnd::~BlockEnd() {
}
LineFoldEnd::~LineFoldEnd() {
}
CustomOperator::~CustomOperator() {
}
Assignment::~Assignment() {
}
Identifier::~Identifier() {
}
StringLiteral::~StringLiteral() {
}
IntegerLiteral::~IntegerLiteral() {
}
QualifiedName::~QualifiedName() {
for (auto& Element: ModulePath){
Element->unref();
}
Name->unref();
}
TypeExpression::~TypeExpression() {
}
ReferenceTypeExpression::~ReferenceTypeExpression() {
Name->unref();
}
ArrowTypeExpression::~ArrowTypeExpression() {
for (auto ParamType: ParamTypes) {
ParamType->unref();
}
ReturnType->unref();
}
Pattern::~Pattern() {
}
BindPattern::~BindPattern() {
Name->unref();
}
Expression::~Expression() {
}
ReferenceExpression::~ReferenceExpression() {
Name->unref();
}
NestedExpression::~NestedExpression() {
LParen->unref();
Inner->unref();
RParen->unref();
}
ConstantExpression::~ConstantExpression() {
Token->unref();
}
CallExpression::~CallExpression() {
Function->unref();
for (auto& Element: Args){
Element->unref();
}
}
InfixExpression::~InfixExpression() {
LHS->unref();
Operator->unref();
RHS->unref();
}
UnaryExpression::~UnaryExpression() {
Operator->unref();
Argument->unref();
}
Statement::~Statement() {
}
ExpressionStatement::~ExpressionStatement() {
Expression->unref();
}
ReturnStatement::~ReturnStatement() {
ReturnKeyword->unref();
Expression->unref();
}
IfStatementPart::~IfStatementPart() {
Keyword->unref();
if (Test) {
Test->unref();
}
BlockStart->unref();
for (auto Element: Elements) {
Element->unref();
}
}
IfStatement::~IfStatement() {
for (auto Part: Parts) {
Part->unref();
}
}
TypeAssert::~TypeAssert() {
Colon->unref();
TypeExpression->unref();
}
Param::~Param() {
Pattern->unref();
TypeAssert->unref();
}
LetBody::~LetBody() {
}
LetBlockBody::~LetBlockBody() {
BlockStart->unref();
for (auto& Element: Elements){
Element->unref();
}
}
LetExprBody::~LetExprBody() {
Equals->unref();
Expression->unref();
}
LetDeclaration::~LetDeclaration() {
if (PubKeyword) {
PubKeyword->unref();
}
LetKeyword->unref();
if (MutKeyword) {
MutKeyword->unref();
}
Pattern->unref();
for (auto& Element: Params){
Element->unref();
}
if (TypeAssert) {
TypeAssert->unref();
}
if (Body) {
Body->unref();
}
}
StructDeclField::~StructDeclField() {
Name->unref();
Colon->unref();
TypeExpression->unref();
}
StructDecl::~StructDecl() {
StructKeyword->unref();
Name->unref();
BlockStart->unref();
for (auto& Element: Fields){
Element->unref();
}
}
SourceFile::~SourceFile() {
for (auto& Element: Elements){
Element->unref();
bool Identifier::isTypeVar() const {
for (auto C: Text) {
if (!((C >= 97 && C <= 122) || C == '_')) {
return false;
}
}
return true;
}
Token* QualifiedName::getFirstToken() {
@ -537,6 +164,36 @@ namespace bolt {
return Name;
}
Token* TypeclassConstraintExpression::getFirstToken() {
return Name;
}
Token* TypeclassConstraintExpression::getLastToken() {
if (!TEs.empty()) {
return TEs.back()->getLastToken();
}
return Name;
}
Token* EqualityConstraintExpression::getFirstToken() {
return Left->getFirstToken();
}
Token* EqualityConstraintExpression::getLastToken() {
return Left->getLastToken();
}
Token* QualifiedTypeExpression::getFirstToken() {
if (!Constraints.empty()) {
return std::get<0>(Constraints.front())->getFirstToken();
}
return TE->getFirstToken();
}
Token* QualifiedTypeExpression::getLastToken() {
return TE->getLastToken();
}
Token* ReferenceTypeExpression::getFirstToken() {
return Name->getFirstToken();
}
@ -556,6 +213,14 @@ namespace bolt {
return ReturnType->getLastToken();
}
Token* VarTypeExpression::getLastToken() {
return Name;
}
Token* VarTypeExpression::getFirstToken() {
return Name;
}
Token* BindPattern::getFirstToken() {
return Name;
}
@ -607,11 +272,11 @@ namespace bolt {
return RHS->getLastToken();
}
Token* UnaryExpression::getFirstToken() {
Token* PrefixExpression::getFirstToken() {
return Operator;
}
Token* UnaryExpression::getLastToken() {
Token* PrefixExpression::getLastToken() {
return Argument->getLastToken();
}
@ -663,11 +328,11 @@ namespace bolt {
return TypeExpression->getLastToken();
}
Token* Param::getFirstToken() {
Token* Parameter::getFirstToken() {
return Pattern->getFirstToken();
}
Token* Param::getLastToken() {
Token* Parameter::getLastToken() {
if (TypeAssert) {
return TypeAssert->getLastToken();
}
@ -713,28 +378,53 @@ namespace bolt {
return Pattern->getLastToken();
}
Token* StructDeclField::getFirstToken() {
Token* StructDeclarationField::getFirstToken() {
return Name;
}
Token* StructDeclField::getLastToken() {
Token* StructDeclarationField::getLastToken() {
return TypeExpression->getLastToken();
}
Token* StructDecl::getFirstToken() {
Token* StructDeclaration::getFirstToken() {
if (PubKeyword) {
return PubKeyword;
}
return StructKeyword;
}
Token* StructDecl::getLastToken() {
Token* StructDeclaration::getLastToken() {
if (Fields.size()) {
Fields.back()->getLastToken();
}
return BlockStart;
}
Token* InstanceDeclaration::getFirstToken() {
return InstanceKeyword;
}
Token* InstanceDeclaration::getLastToken() {
if (!Elements.empty()) {
return Elements.back()->getLastToken();
}
return BlockStart;
}
Token* ClassDeclaration::getFirstToken() {
if (PubKeyword != nullptr) {
return PubKeyword;
}
return ClassKeyword;
}
Token* ClassDeclaration::getLastToken() {
if (!Elements.empty()) {
return Elements.back()->getLastToken();
}
return BlockStart;
}
Token* SourceFile::getFirstToken() {
if (Elements.size()) {
return Elements.front()->getFirstToken();
@ -757,10 +447,18 @@ namespace bolt {
return ":";
}
std::string Comma::getText() const {
return ",";
}
std::string RArrow::getText() const {
return "->";
}
std::string RArrowAlt::getText() const {
return "=>";
}
std::string Dot::getText() const {
return ".";
}
@ -873,6 +571,18 @@ namespace bolt {
return "..";
}
std::string Tilde::getText() const {
return "~";
}
std::string ClassKeyword::getText() const {
return "class";
}
std::string InstanceKeyword::getText() const {
return "instance";
}
SymbolPath QualifiedName::getSymbolPath() const {
std::vector<ByteString> ModuleNames;
for (auto Ident: ModulePath) {

File diff suppressed because it is too large Load diff

View file

@ -44,51 +44,61 @@ namespace bolt {
Diagnostic::Diagnostic(DiagnosticKind Kind):
std::runtime_error("a compiler error occurred without being caught"), Kind(Kind) {}
static std::string describe(NodeType Type) {
static std::string describe(NodeKind Type) {
switch (Type) {
case NodeType::Identifier:
case NodeKind::Identifier:
return "an identifier";
case NodeType::CustomOperator:
case NodeKind::CustomOperator:
return "an operator";
case NodeType::IntegerLiteral:
case NodeKind::IntegerLiteral:
return "an integer literal";
case NodeType::EndOfFile:
case NodeKind::EndOfFile:
return "end-of-file";
case NodeType::BlockStart:
case NodeKind::BlockStart:
return "the start of a new indented block";
case NodeType::BlockEnd:
case NodeKind::BlockEnd:
return "the end of the current indented block";
case NodeType::LineFoldEnd:
case NodeKind::LineFoldEnd:
return "the end of the current line-fold";
case NodeType::LParen:
case NodeKind::LParen:
return "'('";
case NodeType::RParen:
case NodeKind::RParen:
return "')'";
case NodeType::LBrace:
case NodeKind::LBrace:
return "'['";
case NodeType::RBrace:
case NodeKind::RBrace:
return "']'";
case NodeType::LBracket:
case NodeKind::LBracket:
return "'{'";
case NodeType::RBracket:
case NodeKind::RBracket:
return "'}'";
case NodeType::Colon:
case NodeKind::Colon:
return "':'";
case NodeType::Equals:
case NodeKind::Comma:
return "','";
case NodeKind::Equals:
return "'='";
case NodeType::StringLiteral:
case NodeKind::StringLiteral:
return "a string literal";
case NodeType::Dot:
case NodeKind::Dot:
return "'.'";
case NodeType::PubKeyword:
case NodeKind::DotDot:
return "'..'";
case NodeKind::Tilde:
return "'~'";
case NodeKind::RArrow:
return "'->'";
case NodeKind::RArrowAlt:
return "'=>'";
case NodeKind::PubKeyword:
return "'pub'";
case NodeType::LetKeyword:
case NodeKind::LetKeyword:
return "'let'";
case NodeType::MutKeyword:
case NodeKind::MutKeyword:
return "'mut'";
case NodeType::ReturnKeyword:
case NodeKind::ReturnKeyword:
return "'return'";
case NodeType::TypeKeyword:
case NodeKind::TypeKeyword:
return "'type'";
default:
ZEN_UNREACHABLE
@ -97,10 +107,14 @@ namespace bolt {
std::string describe(const Type* Ty) {
switch (Ty->getKind()) {
case TypeKind::Any:
return "any";
case TypeKind::Var:
return "a" + std::to_string(static_cast<const TVar*>(Ty)->Id);
{
auto TV = static_cast<const TVar*>(Ty);
if (TV->getVarKind() == VarKind::Rigid) {
return static_cast<const TVarRigid*>(TV)->Name;
}
return "a" + std::to_string(TV->Id);
}
case TypeKind::Arrow:
{
auto Y = static_cast<const TArrow*>(Ty);
@ -342,7 +356,7 @@ namespace bolt {
writeExcerpt(E.Initiator->getSourceFile()->getTextFile(), Range, Range, Color::Red);
Out << "\n";
}
break;
return;
}
case DiagnosticKind::UnexpectedToken:
@ -366,7 +380,7 @@ namespace bolt {
default:
auto Iter = E.Expected.begin();
Out << describe(*Iter++);
NodeType Prev = *Iter++;
NodeKind Prev = *Iter++;
while (Iter != E.Expected.end()) {
Out << ", " << describe(Prev);
Prev = *Iter++;
@ -377,7 +391,7 @@ namespace bolt {
Out << " but instead got '" << E.Actual->getText() << "'\n\n";
writeExcerpt(E.File, E.Actual->getRange(), E.Actual->getRange(), Color::Red);
Out << "\n";
break;
return;
}
case DiagnosticKind::UnexpectedString:
@ -405,7 +419,7 @@ namespace bolt {
TextRange Range { E.Location, E.Location + E.Actual };
writeExcerpt(E.File, Range, Range, Color::Red);
Out << "\n";
break;
return;
}
case DiagnosticKind::UnificationError:
@ -423,11 +437,56 @@ namespace bolt {
writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red);
Out << "\n";
}
break;
return;
}
case DiagnosticKind::TypeclassMissing:
{
auto E = static_cast<const TypeclassMissingDiagnostic&>(D);
setForegroundColor(Color::Red);
setBold(true);
Out << "error: ";
resetStyles();
Out << "the type class " << ANSI_FG_YELLOW << E.Sig.Id;
for (auto TV: E.Sig.Params) {
Out << " " << describe(TV);
}
Out << ANSI_RESET << " is missing from the declaration's type signature\n\n";
auto Range = E.Decl->getRange();
writeExcerpt(E.Decl->getSourceFile()->getTextFile(), Range, Range, Color::Yellow);
Out << "\n\n";
return;
}
case DiagnosticKind::InstanceNotFound:
{
auto E = static_cast<const InstanceNotFoundDiagnostic&>(D);
setForegroundColor(Color::Red);
setBold(true);
Out << "error: ";
resetStyles();
Out << "a type class instance " << ANSI_FG_YELLOW << E.TypeclassName << " " << describe(E.Ty) << ANSI_RESET " was not found.\n\n";
auto Range = E.Source->getRange();
//std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n";
writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red);
Out << "\n";
return;
}
case DiagnosticKind::ClassNotFound:
{
auto E = static_cast<const ClassNotFoundDiagnostic&>(D);
setForegroundColor(Color::Red);
setBold(true);
Out << "error: ";
resetStyles();
Out << "the type class " << ANSI_FG_YELLOW << E.Name << ANSI_RESET " was not found.\n\n";
return;
}
}
ZEN_UNREACHABLE
}
}

View file

@ -11,9 +11,9 @@ namespace bolt {
void IPRGraph::populate(Node* X, Node* Decl) {
switch (X->Type) {
switch (X->getKind()) {
case NodeType::SourceFile:
case NodeKind::SourceFile:
{
auto Y = static_cast<SourceFile*>(X);
for (auto Element: Y->Elements) {
@ -22,7 +22,7 @@ namespace bolt {
break;
}
case NodeType::IfStatement:
case NodeKind::IfStatement:
{
auto Y = static_cast<IfStatement*>(X);
for (auto Part: Y->Parts) {
@ -33,12 +33,12 @@ namespace bolt {
break;
}
case NodeType::LetDeclaration:
case NodeKind::LetDeclaration:
{
auto Y = static_cast<LetDeclaration*>(X);
if (Y->Body) {
switch (Y->Body->Type) {
case NodeType::LetBlockBody:
switch (Y->Body->getKind()) {
case NodeKind::LetBlockBody:
{
auto Z = static_cast<LetBlockBody*>(Y->Body);
for (auto Element: Z->Elements) {
@ -46,7 +46,7 @@ namespace bolt {
}
break;
}
case NodeType::LetExprBody:
case NodeKind::LetExprBody:
{
auto Z = static_cast<LetExprBody*>(Y->Body);
populate(Z->Expression, Y);
@ -59,10 +59,10 @@ namespace bolt {
break;
}
case NodeType::ConstantExpression:
case NodeKind::ConstantExpression:
break;
case NodeType::CallExpression:
case NodeKind::CallExpression:
{
auto Y = static_cast<CallExpression*>(X);
populate(Y->Function, Decl);
@ -72,7 +72,7 @@ namespace bolt {
break;
}
case NodeType::ReferenceExpression:
case NodeKind::ReferenceExpression:
{
auto Y = static_cast<ReferenceExpression*>(X);
auto Def = Y->getScope()->lookup(Y->Name->getSymbolPath());

View file

@ -1,10 +1,13 @@
#include <exception>
#include <vector>
#include "llvm/Support/Casting.h"
#include "bolt/CST.hpp"
#include "bolt/Scanner.hpp"
#include "bolt/Parser.hpp"
#include "bolt/Diagnostics.hpp"
#include <exception>
#include <vector>
namespace bolt {
@ -57,9 +60,9 @@ namespace bolt {
std::size_t I = 0;
for (;;) {
auto T0 = Tokens.peek(I++);
switch (T0->Type) {
case NodeType::PubKeyword:
case NodeType::MutKeyword:
switch (T0->getKind()) {
case NodeKind::PubKeyword:
case NodeKind::MutKeyword:
continue;
default:
return T0;
@ -70,71 +73,141 @@ namespace bolt {
#define BOLT_EXPECT_TOKEN(name) \
{ \
auto __Token = Tokens.get(); \
if (__Token->Type != NodeType::name) { \
throw UnexpectedTokenDiagnostic(File, __Token, std::vector<NodeType> { NodeType::name }); \
if (!llvm::isa<name>(__Token)) { \
throw UnexpectedTokenDiagnostic(File, __Token, std::vector<NodeKind> { NodeKind::name }); \
} \
}
Token* Parser::expectToken(NodeType Type) {
Token* Parser::expectToken(NodeKind Kind) {
auto T = Tokens.get();
if (T->Type != Type) {
throw UnexpectedTokenDiagnostic(File, T, std::vector<NodeType> { Type }); \
if (T->getKind() != Kind) {
throw UnexpectedTokenDiagnostic(File, T, std::vector<NodeKind> { Kind }); \
}
return T;
}
Pattern* Parser::parsePattern() {
auto T0 = Tokens.peek();
switch (T0->Type) {
case NodeType::Identifier:
switch (T0->getKind()) {
case NodeKind::Identifier:
Tokens.get();
return new BindPattern(static_cast<Identifier*>(T0));
default:
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::Identifier });
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier });
}
}
QualifiedName* Parser::parseQualifiedName() {
std::vector<Identifier*> ModulePath;
auto Name = expectToken(NodeType::Identifier);
auto Name = expectToken(NodeKind::Identifier);
for (;;) {
auto T1 = Tokens.peek();
if (T1->Type != NodeType::Dot) {
if (T1->getKind() != NodeKind::Dot) {
break;
}
Tokens.get();
ModulePath.push_back(static_cast<Identifier*>(Name));
Name = Tokens.get();
if (Name->Type != NodeType::Identifier) {
throw UnexpectedTokenDiagnostic(File, Name, std::vector { NodeType::Identifier });
if (Name->getKind() != NodeKind::Identifier) {
throw UnexpectedTokenDiagnostic(File, Name, std::vector { NodeKind::Identifier });
}
}
return new QualifiedName(ModulePath, static_cast<Identifier*>(Name));
}
TypeExpression* Parser::parseTypeExpression() {
return parseQualifiedTypeExpression();
}
TypeExpression* Parser::parseQualifiedTypeExpression() {
bool HasConstraints = false;
auto T0 = Tokens.peek();
if (llvm::isa<LParen>(T0)) {
std::size_t I = 1;
for (;;) {
auto T0 = Tokens.peek(I++);
switch (T0->getKind()) {
case NodeKind::RArrowAlt:
HasConstraints = true;
goto after_scan;
case NodeKind::Equals:
case NodeKind::BlockStart:
case NodeKind::LineFoldEnd:
case NodeKind::EndOfFile:
goto after_scan;
default:
break;
}
}
}
after_scan:
if (!HasConstraints) {
return parseArrowTypeExpression();
}
Tokens.get();
LParen* LParen = static_cast<class LParen*>(T0);
std::vector<std::tuple<ConstraintExpression*, Comma*>> Constraints;
RParen* RParen;
RArrowAlt* RArrowAlt;
for (;;) {
ConstraintExpression* C;
auto T0 = Tokens.peek();
switch (T0->getKind()) {
case NodeKind::RParen:
Tokens.get();
RParen = static_cast<class RParen*>(T0);
RArrowAlt = expectToken<class RArrowAlt>();
goto after_constraints;
default:
C = parseConstraintExpression();
break;
}
Comma* Comma = nullptr;
auto T1 = Tokens.get();
switch (T1->getKind()) {
case NodeKind::Comma:
Constraints.push_back(std::make_tuple(C, static_cast<class Comma*>(T1)));
continue;
case NodeKind::RParen:
RArrowAlt = static_cast<class RArrowAlt*>(T1);
Constraints.push_back(std::make_tuple(C, nullptr));
RArrowAlt = expectToken<class RArrowAlt>();
goto after_constraints;
default:
throw UnexpectedTokenDiagnostic(File, T1, std::vector { NodeKind::Comma, NodeKind::RArrowAlt });
}
}
after_constraints:
auto TE = parseArrowTypeExpression();
return new QualifiedTypeExpression(Constraints, RArrowAlt, TE);
}
TypeExpression* Parser::parsePrimitiveTypeExpression() {
auto T0 = Tokens.peek();
switch (T0->Type) {
case NodeType::Identifier:
switch (T0->getKind()) {
case NodeKind::Identifier:
if (static_cast<Identifier*>(T0)->isTypeVar()) {
return parseVarTypeExpression();
}
return new ReferenceTypeExpression(parseQualifiedName());
default:
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::Identifier });
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier });
}
}
TypeExpression* Parser::parseTypeExpression() {
TypeExpression* Parser::parseArrowTypeExpression() {
auto RetType = parsePrimitiveTypeExpression();
std::vector<TypeExpression*> ParamTypes;
for (;;) {
auto T1 = Tokens.peek();
if (T1->Type != NodeType::RArrow) {
if (T1->getKind() != NodeKind::RArrow) {
break;
}
Tokens.get();
ParamTypes.push_back(RetType);
RetType = parsePrimitiveTypeExpression();
}
if (ParamTypes.size()) {
if (!ParamTypes.empty()) {
return new ArrowTypeExpression(ParamTypes, RetType);
}
return RetType;
@ -142,25 +215,25 @@ namespace bolt {
Expression* Parser::parsePrimitiveExpression() {
auto T0 = Tokens.peek();
switch (T0->Type) {
case NodeType::Identifier:
switch (T0->getKind()) {
case NodeKind::Identifier:
{
auto Name = parseQualifiedName();
return new ReferenceExpression(Name);
}
case NodeType::LParen:
case NodeKind::LParen:
{
Tokens.get();
auto E = parseExpression();
auto T2 = static_cast<RParen*>(expectToken(NodeType::RParen));
auto T2 = static_cast<RParen*>(expectToken(NodeKind::RParen));
return new NestedExpression(static_cast<LParen*>(T0), E, T2);
}
case NodeType::IntegerLiteral:
case NodeType::StringLiteral:
case NodeKind::IntegerLiteral:
case NodeKind::StringLiteral:
Tokens.get();
return new ConstantExpression(T0);
default:
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::Identifier, NodeType::IntegerLiteral, NodeType::StringLiteral });
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier, NodeKind::IntegerLiteral, NodeKind::StringLiteral });
}
}
@ -169,7 +242,7 @@ namespace bolt {
std::vector<Expression*> Args;
for (;;) {
auto T1 = Tokens.peek();
if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::RParen || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) {
if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || ExprOperators.isInfix(T1)) {
break;
}
Args.push_back(parsePrimitiveExpression());
@ -192,7 +265,7 @@ namespace bolt {
}
auto E = parseCallExpression();
for (auto Iter = Prefix.rbegin(); Iter != Prefix.rend(); Iter++) {
E = new UnaryExpression(*Iter, E);
E = new PrefixExpression(*Iter, E);
}
return E;
}
@ -230,10 +303,10 @@ namespace bolt {
}
ReturnStatement* Parser::parseReturnStatement() {
auto T0 = static_cast<ReturnKeyword*>(expectToken(NodeType::ReturnKeyword));
auto T0 = static_cast<ReturnKeyword*>(expectToken(NodeKind::ReturnKeyword));
Expression* Expression = nullptr;
auto T1 = Tokens.peek();
if (T1->Type != NodeType::LineFoldEnd) {
if (T1->getKind() != NodeKind::LineFoldEnd) {
Expression = parseExpression();
}
BOLT_EXPECT_TOKEN(LineFoldEnd);
@ -242,13 +315,13 @@ namespace bolt {
IfStatement* Parser::parseIfStatement() {
std::vector<IfStatementPart*> Parts;
auto T0 = expectToken(NodeType::IfKeyword);
auto T0 = expectToken(NodeKind::IfKeyword);
auto Test = parseExpression();
auto T1 = static_cast<BlockStart*>(expectToken(NodeType::BlockStart));
auto T1 = static_cast<BlockStart*>(expectToken(NodeKind::BlockStart));
std::vector<Node*> Then;
for (;;) {
auto T2 = Tokens.peek();
if (T2->Type == NodeType::BlockEnd) {
if (T2->getKind() == NodeKind::BlockEnd) {
Tokens.get();
break;
}
@ -257,13 +330,13 @@ namespace bolt {
Parts.push_back(new IfStatementPart(T0, Test, T1, Then));
BOLT_EXPECT_TOKEN(LineFoldEnd)
auto T3 = Tokens.peek();
if (T3->Type == NodeType::ElseKeyword) {
if (T3->getKind() == NodeKind::ElseKeyword) {
Tokens.get();
auto T4 = static_cast<BlockStart*>(expectToken(NodeType::BlockStart));
auto T4 = static_cast<BlockStart*>(expectToken(NodeKind::BlockStart));
std::vector<Node*> Else;
for (;;) {
auto T5 = Tokens.peek();
if (T5->Type == NodeType::BlockEnd) {
if (T5->getKind() == NodeKind::BlockEnd) {
Tokens.get();
break;
}
@ -281,41 +354,41 @@ namespace bolt {
LetKeyword* Let;
MutKeyword* Mut = nullptr;
auto T0 = Tokens.get();
if (T0->Type == NodeType::PubKeyword) {
if (T0->getKind() == NodeKind::PubKeyword) {
Pub = static_cast<PubKeyword*>(T0);
T0 = Tokens.get();
}
if (T0->Type != NodeType::LetKeyword) {
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::LetKeyword });
if (T0->getKind() != NodeKind::LetKeyword) {
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::LetKeyword });
}
Let = static_cast<LetKeyword*>(T0);
auto T1 = Tokens.peek();
if (T1->Type == NodeType::MutKeyword) {
if (T1->getKind() == NodeKind::MutKeyword) {
Mut = static_cast<MutKeyword*>(T1);
Tokens.get();
}
auto Patt = parsePattern();
std::vector<Param*> Params;
std::vector<Parameter*> Params;
Token* T2;
for (;;) {
T2 = Tokens.peek();
switch (T2->Type) {
case NodeType::LineFoldEnd:
case NodeType::BlockStart:
case NodeType::Equals:
case NodeType::Colon:
switch (T2->getKind()) {
case NodeKind::LineFoldEnd:
case NodeKind::BlockStart:
case NodeKind::Equals:
case NodeKind::Colon:
goto after_params;
default:
Params.push_back(new Param(parsePattern(), nullptr));
Params.push_back(new Parameter(parsePattern(), nullptr));
}
}
after_params:
TypeAssert* TA = nullptr;
if (T2->Type == NodeType::Colon) {
if (T2->getKind() == NodeKind::Colon) {
Tokens.get();
auto TE = parseTypeExpression();
TA = new TypeAssert(static_cast<Colon*>(T2), TE);
@ -323,14 +396,14 @@ after_params:
}
LetBody* Body;
switch (T2->Type) {
case NodeType::BlockStart:
switch (T2->getKind()) {
case NodeKind::BlockStart:
{
Tokens.get();
std::vector<Node*> Elements;
for (;;) {
auto T3 = Tokens.peek();
if (T3->Type == NodeType::BlockEnd) {
if (T3->getKind() == NodeKind::BlockEnd) {
break;
}
Elements.push_back(parseLetBodyElement());
@ -339,20 +412,20 @@ after_params:
Body = new LetBlockBody(static_cast<BlockStart*>(T2), Elements);
break;
}
case NodeType::Equals:
case NodeKind::Equals:
Tokens.get();
Body = new LetExprBody(static_cast<Equals*>(T2), parseExpression());
break;
case NodeType::LineFoldEnd:
case NodeKind::LineFoldEnd:
Body = nullptr;
break;
default:
std::vector<NodeType> Expected { NodeType::BlockStart, NodeType::LineFoldEnd, NodeType::Equals };
std::vector<NodeKind> Expected { NodeKind::BlockStart, NodeKind::LineFoldEnd, NodeKind::Equals };
if (TA == nullptr) {
// First tokens of TypeAssert
Expected.push_back(NodeType::Colon);
Expected.push_back(NodeKind::Colon);
// First tokens of Pattern
Expected.push_back(NodeType::Identifier);
Expected.push_back(NodeKind::Identifier);
}
throw UnexpectedTokenDiagnostic(File, T2, Expected);
}
@ -372,25 +445,161 @@ after_params:
Node* Parser::parseLetBodyElement() {
auto T0 = peekFirstTokenAfterModifiers();
switch (T0->Type) {
case NodeType::LetKeyword:
switch (T0->getKind()) {
case NodeKind::LetKeyword:
return parseLetDeclaration();
case NodeType::ReturnKeyword:
case NodeKind::ReturnKeyword:
return parseReturnStatement();
case NodeType::IfKeyword:
case NodeKind::IfKeyword:
return parseIfStatement();
default:
return parseExpressionStatement();
}
}
ConstraintExpression* Parser::parseConstraintExpression() {
bool HasTilde = false;
for (std::size_t I = 0; ; I++) {
auto Tok = Tokens.peek(I);
switch (Tok->getKind()) {
case NodeKind::Tilde:
HasTilde = true;
goto after_seek;
case NodeKind::RParen:
case NodeKind::Comma:
case NodeKind::RArrowAlt:
case NodeKind::EndOfFile:
goto after_seek;
default:
continue;
}
}
after_seek:
if (HasTilde) {
auto Left = parseArrowTypeExpression();
auto Tilde = expectToken<class Tilde>();
auto Right = parseArrowTypeExpression();
return new EqualityConstraintExpression { Left, Tilde, Right };
}
auto Name = expectToken<Identifier>();
std::vector<VarTypeExpression*> TEs;
for (;;) {
auto T1 = Tokens.peek();
switch (T1->getKind()) {
case NodeKind::RParen:
case NodeKind::RArrowAlt:
case NodeKind::Comma:
goto after_vars;
case NodeKind::Identifier:
Tokens.get();
TEs.push_back(new VarTypeExpression { static_cast<Identifier*>(T1) });
break;
default:
throw UnexpectedTokenDiagnostic(File, T1, std::vector { NodeKind::RParen, NodeKind::RArrowAlt, NodeKind::Comma, NodeKind::Identifier });
}
}
after_vars:
return new TypeclassConstraintExpression { Name, TEs };
}
VarTypeExpression* Parser::parseVarTypeExpression() {
auto Name = expectToken<Identifier>();
// TODO reject constructor symbols (starting with a capital letter)
return new VarTypeExpression { Name };
}
InstanceDeclaration* Parser::parseInstanceDeclaration() {
auto InstanceKeyword = expectToken<class InstanceKeyword>();
auto Name = expectToken<Identifier>();
std::vector<TypeExpression*> TypeExps;
for (;;) {
auto T1 = Tokens.peek();
if (T1->is<BlockStart>()) {
break;
}
TypeExps.push_back(parseTypeExpression());
}
auto BlockStart = expectToken<class BlockStart>();
std::vector<Node*> Elements;
for (;;) {
auto T2 = Tokens.peek();
if (T2->is<BlockEnd>()) {
Tokens.get();
break;
}
Elements.push_back(parseClassElement());
}
expectToken(NodeKind::LineFoldEnd);
return new InstanceDeclaration(
InstanceKeyword,
Name,
TypeExps,
BlockStart,
Elements
);
}
ClassDeclaration* Parser::parseClassDeclaration() {
PubKeyword* PubKeyword = nullptr;
auto T0 = Tokens.peek();
if (T0->getKind() == NodeKind::PubKeyword) {
Tokens.get();
PubKeyword = static_cast<class PubKeyword*>(T0);
}
auto ClassKeyword = expectToken<class ClassKeyword>();
auto Name = expectToken<Identifier>();
std::vector<VarTypeExpression*> TypeVars;
for (;;) {
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::BlockStart) {
break;
}
TypeVars.push_back(parseVarTypeExpression());
}
auto BlockStart = expectToken<class BlockStart>();
std::vector<Node*> Elements;
for (;;) {
auto T2 = Tokens.peek();
if (T2->is<BlockEnd>()) {
Tokens.get();
break;
}
Elements.push_back(parseClassElement());
}
expectToken(NodeKind::LineFoldEnd);
return new ClassDeclaration(
PubKeyword,
ClassKeyword,
Name,
TypeVars,
BlockStart,
Elements
);
}
Node* Parser::parseClassElement() {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
case NodeKind::LetKeyword:
return parseLetDeclaration();
case NodeKind::TypeKeyword:
// TODO
default:
throw UnexpectedTokenDiagnostic(File, T0, std::vector<NodeKind> { NodeKind::LetKeyword, NodeKind::TypeKeyword });
}
}
Node* Parser::parseSourceElement() {
auto T0 = peekFirstTokenAfterModifiers();
switch (T0->Type) {
case NodeType::LetKeyword:
switch (T0->getKind()) {
case NodeKind::LetKeyword:
return parseLetDeclaration();
case NodeType::IfKeyword:
case NodeKind::IfKeyword:
return parseIfStatement();
case NodeKind::ClassKeyword:
return parseClassDeclaration();
case NodeKind::InstanceKeyword:
return parseInstanceDeclaration();
default:
return parseExpressionStatement();
}
@ -400,7 +609,7 @@ after_params:
std::vector<Node*> Elements;
for (;;) {
auto T0 = Tokens.peek();
if (T0->Type == NodeType::EndOfFile) {
if (T0->is<EndOfFile>()) {
break;
}
Elements.push_back(parseSourceElement());

View file

@ -3,6 +3,8 @@
#include "zen/config.hpp"
#include "llvm/Support/Casting.h"
#include "bolt/Text.hpp"
#include "bolt/Integer.hpp"
#include "bolt/CST.hpp"
@ -57,16 +59,18 @@ namespace bolt {
return Chr - 48;
}
std::unordered_map<ByteString, NodeType> Keywords = {
{ "pub", NodeType::PubKeyword },
{ "let", NodeType::LetKeyword },
{ "mut", NodeType::MutKeyword },
{ "return", NodeType::ReturnKeyword },
{ "type", NodeType::TypeKeyword },
{ "mod", NodeType::ModKeyword },
{ "if", NodeType::IfKeyword },
{ "else", NodeType::ElseKeyword },
{ "elif", NodeType::ElifKeyword },
std::unordered_map<ByteString, NodeKind> Keywords = {
{ "pub", NodeKind::PubKeyword },
{ "let", NodeKind::LetKeyword },
{ "mut", NodeKind::MutKeyword },
{ "return", NodeKind::ReturnKeyword },
{ "type", NodeKind::TypeKeyword },
{ "mod", NodeKind::ModKeyword },
{ "if", NodeKind::IfKeyword },
{ "else", NodeKind::ElseKeyword },
{ "elif", NodeKind::ElifKeyword },
{ "class", NodeKind::ClassKeyword },
{ "instance", NodeKind::InstanceKeyword },
};
Scanner::Scanner(TextFile& File, Stream<Char>& Chars):
@ -202,22 +206,26 @@ digit_finish:
auto Match = Keywords.find(Text);
if (Match != Keywords.end()) {
switch (Match->second) {
case NodeType::PubKeyword:
case NodeKind::PubKeyword:
return new PubKeyword(StartLoc);
case NodeType::LetKeyword:
case NodeKind::LetKeyword:
return new LetKeyword(StartLoc);
case NodeType::MutKeyword:
case NodeKind::MutKeyword:
return new MutKeyword(StartLoc);
case NodeType::TypeKeyword:
case NodeKind::TypeKeyword:
return new TypeKeyword(StartLoc);
case NodeType::ReturnKeyword:
case NodeKind::ReturnKeyword:
return new ReturnKeyword(StartLoc);
case NodeType::IfKeyword:
case NodeKind::IfKeyword:
return new IfKeyword(StartLoc);
case NodeType::ElifKeyword:
case NodeKind::ElifKeyword:
return new ElifKeyword(StartLoc);
case NodeType::ElseKeyword:
case NodeKind::ElseKeyword:
return new ElseKeyword(StartLoc);
case NodeKind::ClassKeyword:
return new ClassKeyword(StartLoc);
case NodeKind::InstanceKeyword:
return new InstanceKeyword(StartLoc);
default:
ZEN_UNREACHABLE
}
@ -305,6 +313,8 @@ after_string_contents:
}
if (Text == "->") {
return new RArrow(StartLoc);
} else if (Text == "=>") {
return new RArrowAlt(StartLoc);
} else if (Text == "=") {
return new Equals(StartLoc);
} else if (Text.back() == '=' && Text[Text.size()-2] != '=') {
@ -316,7 +326,7 @@ after_string_contents:
#define BOLT_SIMPLE_TOKEN(ch, name) case ch: return new name(StartLoc);
//BOLT_SIMPLE_TOKEN(',', Comma)
BOLT_SIMPLE_TOKEN(',', Comma)
BOLT_SIMPLE_TOKEN(':', Colon)
BOLT_SIMPLE_TOKEN('(', LParen)
BOLT_SIMPLE_TOKEN(')', RParen)
@ -324,6 +334,7 @@ after_string_contents:
BOLT_SIMPLE_TOKEN(']', RBracket)
BOLT_SIMPLE_TOKEN('{', LBrace)
BOLT_SIMPLE_TOKEN('}', RBrace)
BOLT_SIMPLE_TOKEN('~', Tilde)
default:
throw UnexpectedStringDiagnostic(File, StartLoc, String { static_cast<char>(C0) });
@ -342,7 +353,7 @@ after_string_contents:
auto T0 = Tokens.peek();
if (T0->Type == NodeType::EndOfFile) {
if (llvm::isa<EndOfFile>(T0)) {
if (Frames.size() == 1) {
return T0;
}
@ -366,7 +377,7 @@ after_string_contents:
Locations.pop();
return new LineFoldEnd(T0->getStartLoc());
}
if (T0->Type == NodeType::Dot) {
if (llvm::isa<Dot>(T0)) {
auto T1 = Tokens.peek(1);
if (T1->getStartLine() > T0->getEndLine()) {
Tokens.get();

View file

@ -16,18 +16,18 @@ auto checkExpression(std::string Input) {
Scanner S(T, Chars);
Punctuator PT(S);
Parser P(T, PT);
LanguageConfig Config;
auto SF = P.parseSourceFile();
Checker C(DS);
auto Solution = C.check(SF);
Checker C(Config, DS);
C.check(SF);
return std::make_tuple(
static_cast<ExpressionStatement*>(SF->Elements[0])->Expression,
C,
Solution
C
);
}
TEST(CheckerTest, InfersIntFromIntegerLiteral) {
auto [Expression, Checker, Solution] = checkExpression("1");
ASSERT_EQ(Checker.getType(Expression, Solution), Checker.getIntType());
auto [Expression, Checker] = checkExpression("1");
ASSERT_EQ(Checker.getType(Expression), Checker.getIntType());
}

View file

@ -37,6 +37,7 @@ int main(int argc, const char* argv[]) {
}
ConsoleDiagnostics DE;
LanguageConfig Config;
auto Text = readFile(argv[1]);
TextFile File { argv[1], Text };
@ -56,7 +57,7 @@ int main(int argc, const char* argv[]) {
SF->setParents();
Checker TheChecker { DE };
Checker TheChecker { Config, DE };
TheChecker.check(SF);
return 0;