From db26fd3b183f7c7ef2c3282adc8028b605cbf8b5 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sat, 20 May 2023 23:48:26 +0200 Subject: [PATCH] Add experimental support for type classes and many more enhancements --- .vscode/launch.json | 4 +- .vscode/settings.json | 51 +- .vscode/tasks.json | 5 +- CMakeLists.txt | 14 +- bolt-cst-spec.txt | 167 ------ include/bolt/CST.hpp | 928 +++++++++++++++++++++++----------- include/bolt/CSTVisitor.hpp | 951 +++++++++++++++++++++++++++++++++++ include/bolt/Checker.hpp | 350 +++++++++---- include/bolt/Diagnostics.hpp | 57 ++- include/bolt/Parser.hpp | 27 +- include/bolt/Scanner.hpp | 68 +-- include/bolt/Text.hpp | 3 +- scripts/CST.cc.tply | 66 --- scripts/CST.hpp.tply | 118 ----- scripts/gennodes.py | 848 ------------------------------- src/CST.cc | 586 ++++++--------------- src/Checker.cc | 825 +++++++++++++++++++++++------- src/Diagnostics.cc | 121 +++-- src/IPRGraph.cc | 20 +- src/Parser.cc | 345 ++++++++++--- src/Scanner.cc | 53 +- src/TestChecker.cc | 12 +- src/main.cc | 3 +- 23 files changed, 3176 insertions(+), 2446 deletions(-) delete mode 100644 bolt-cst-spec.txt create mode 100644 include/bolt/CSTVisitor.hpp delete mode 100644 scripts/CST.cc.tply delete mode 100644 scripts/CST.hpp.tply delete mode 100755 scripts/gennodes.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 0e0386686..87e8e6c7b 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,9 +9,9 @@ "request": "launch", "name": "Debug", "program": "${workspaceFolder}/build/bolt", - "args": ["test.bolt"], + "args": [ "test.bolt" ], "cwd": "${workspaceFolder}", "preLaunchTask": "CMake: build" } ] -} \ No newline at end of file +} diff --git a/.vscode/settings.json b/.vscode/settings.json index 5efd846a1..fbb07df58 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -27,6 +27,53 @@ "initializer_list": "cpp", "numeric": "cpp", "ostream": "cpp", - "system_error": "cpp" - } + "system_error": "cpp", + "cctype": "cpp", + "clocale": 
"cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdio": "cpp", + "cstring": "cpp", + "ctime": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "any": "cpp", + "atomic": "cpp", + "strstream": "cpp", + "bit": "cpp", + "bitset": "cpp", + "cinttypes": "cpp", + "codecvt": "cpp", + "compare": "cpp", + "complex": "cpp", + "concepts": "cpp", + "condition_variable": "cpp", + "coroutine": "cpp", + "cstdint": "cpp", + "map": "cpp", + "set": "cpp", + "algorithm": "cpp", + "iterator": "cpp", + "memory_resource": "cpp", + "optional": "cpp", + "ratio": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "utility": "cpp", + "iomanip": "cpp", + "iostream": "cpp", + "mutex": "cpp", + "new": "cpp", + "numbers": "cpp", + "semaphore": "cpp", + "shared_mutex": "cpp", + "stdexcept": "cpp", + "stop_token": "cpp", + "thread": "cpp", + "cfenv": "cpp", + "typeindex": "cpp", + "variant": "cpp", + "__nullptr": "cpp" + }, + "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools" } \ No newline at end of file diff --git a/.vscode/tasks.json b/.vscode/tasks.json index 4edc48b2e..5fb03c01f 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -8,7 +8,10 @@ "targets": [ "all" ], - "group": "build", + "group": { + "kind": "build", + "isDefault": true + }, "problemMatcher": [], "detail": "CMake template build task" } diff --git a/CMakeLists.txt b/CMakeLists.txt index b25cf3ac8..503ab5c25 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,11 +73,11 @@ if (BOLT_ENABLE_TESTS) ) endif() -#add_custom_command( -# OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/include/bolt/CST.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/CST.cc" -# COMMAND scripts/gennodes.py --name=CST ./bolt-cst-spec.txt -Iinclude/ --include-root=bolt --source-root=src/ --namespace=bolt -# DEPENDS scripts/gennodes.py -# MAIN_DEPENDENCY "${CMAKE_CURRENT_SOURCE_DIR}/bolt-cst-spec.txt" -# WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" -#) +# add_custom_command( +# OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/include/bolt/CST.hpp" 
"${CMAKE_CURRENT_SOURCE_DIR}/src/CST.cc" +# COMMAND scripts/gennodes.py --name=CST ./bolt-cst-spec.txt -Iinclude/ --include-root=bolt --source-root=src/ --namespace=bolt +# DEPENDS scripts/gennodes.py +# MAIN_DEPENDENCY "${CMAKE_CURRENT_SOURCE_DIR}/bolt-cst-spec.txt" +# WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}" +# ) diff --git a/bolt-cst-spec.txt b/bolt-cst-spec.txt deleted file mode 100644 index f46f8a47b..000000000 --- a/bolt-cst-spec.txt +++ /dev/null @@ -1,167 +0,0 @@ - -#include -#include - -#include "bolt/Text.hpp" -#include "bolt/Integer.hpp" -#include "bolt/ByteString.hpp" - -external Integer; -external ByteString; -external TextRange; - -// Tokens - -node Token { - TextLoc start_loc; -} - -node Equals : Token {} -node Colon : Token {} -node Dot : Token {} -node LParen : Token {} -node RParen : Token {} -node LBracket : Token {} -node RBracket : Token {} -node LBrace : Token {} -node RBrace : Token {} - -node LetKeyword : Token {} -node MutKeyword : Token {} -node PubKeyword : Token {} -node TypeKeyword : Token {} -node ReturnKeyword : Token {} -node ModKeyword : Token {} -node StructKeyword : Token {} - -node Invalid : Token {} - -node EndOfFile : Token {} -node BlockStart : Token {} -node BlockEnd : Token {} -node LineFoldEnd : Token {} - -node CustomOperator : Token { - ByteString text; -} - -node Identifier : Token { - ByteString text; -} - -node StringLiteral : Token { - ByteString text; -} - -node IntegerLiteral : Token { - Integer value; -} - -node QualifiedName { - List module_path; - Identifier name; -} - -node SourceElement {} - -node LetBodyElement {} - -// Type expressions - -node TypeExpression {} - -node ReferenceTypeExpression : TypeExpression { - QualifiedName name; -} - -// Patterns - -node Pattern {} - -node BindPattern : Pattern { - Identifier name; -} - -// Expresssions - -node Expression {} - -node ReferenceExpression : Expression { - Identifier name; -} - -node ConstantExpression : Expression { - Variant token; -} - -node 
CallExpression : Expression { - Expression function; - List args; -} - -// Statements - -node Statement : LetBodyElement {} - -node ExpressionStatement : Statement, SourceElement { - Expression expression; -} - -node ReturnStatement : Statement { - ReturnKeyword return_keyword; - Expression expression; -} - -// Other nodes - -node TypeAssert { - Colon colon; - TypeExpression type_expression; -} - -node Param { - Pattern pattern; - TypeAssert type_assert; -} - -// Declarations - -node LetBody {} - -node LetBlockBody : LetBody { - BlockStart block_start; - List elements; -} - -node LetExprBody : LetBody { - Equals equals; - Expression expression; -} - -node LetDeclaration : SourceElement, LetBodyElement { - Option pub_keyword; - LetKeyword let_keywod; - Option mut_keyword; - Pattern pattern; - List params; - Option type_assert; - Option body; -} - -node StructDeclField { - Identifier name; - Colon colon; - TypeExpression type_expression; -} - -node StructDecl : SourceElement { - StructKeyword struct_keyword; - Identifier name; - Dot dot; - List fields; -} - -node SourceFile { - List elements; -} - diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index beebf4cb8..e1c52b004 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -1,27 +1,36 @@ #ifndef BOLT_CST_HPP #define BOLT_CST_HPP +#include #include #include #include #include +#include "zen/config.hpp" + #include "bolt/Text.hpp" #include "bolt/Integer.hpp" #include "bolt/ByteString.hpp" namespace bolt { + class Type; + class Token; class SourceFile; class Scope; class Pattern; + class Expression; + class Statement; - enum class NodeType { + enum class NodeKind { Equals, Colon, + Comma, Dot, DotDot, + Tilde, LParen, RParen, LBracket, @@ -29,6 +38,7 @@ namespace bolt { LBrace, RBrace, RArrow, + RArrowAlt, LetKeyword, MutKeyword, PubKeyword, @@ -36,6 +46,8 @@ namespace bolt { ReturnKeyword, ModKeyword, StructKeyword, + ClassKeyword, + InstanceKeyword, ElifKeyword, IfKeyword, ElseKeyword, @@ -50,26 
+62,32 @@ namespace bolt { StringLiteral, IntegerLiteral, QualifiedName, + TypeclassConstraintExpression, + EqualityConstraintExpression, + QualifiedTypeExpression, ReferenceTypeExpression, ArrowTypeExpression, + VarTypeExpression, BindPattern, ReferenceExpression, NestedExpression, ConstantExpression, CallExpression, InfixExpression, - UnaryExpression, + PrefixExpression, ExpressionStatement, ReturnStatement, IfStatement, IfStatementPart, TypeAssert, - Param, + Parameter, LetBlockBody, LetExprBody, LetDeclaration, - StructDeclField, - StructDecl, + StructDeclarationField, + StructDeclaration, + ClassDeclaration, + InstanceDeclaration, SourceFile, }; @@ -78,9 +96,14 @@ namespace bolt { ByteString Name; }; + template + NodeKind getNodeType(); + class Node { - unsigned RefCount = 0; + unsigned RefCount = 1; + + const NodeKind Kind; public: @@ -97,17 +120,40 @@ namespace bolt { } } - virtual void setParents() = 0; - + void setParents(); + virtual Token* getFirstToken() = 0; virtual Token* getLastToken() = 0; + inline NodeKind getKind() const noexcept { + return Kind; + } + + template + bool is() const noexcept { + return Kind == getNodeType(); + } + + template<> + bool is() const noexcept { + return Kind == NodeKind::ReferenceExpression + || Kind == NodeKind::ConstantExpression + || Kind == NodeKind::PrefixExpression + || Kind == NodeKind::InfixExpression + || Kind == NodeKind::CallExpression + || Kind == NodeKind::NestedExpression; + } + + template + T* as() { + ZEN_ASSERT(is()); + return static_cast(this); + } + TextRange getRange(); - const NodeType Type; - - inline Node(NodeType Type): - Type(Type) {} + inline Node(NodeKind Type): + Kind(Type) {} SourceFile* getSourceFile(); @@ -142,13 +188,11 @@ namespace bolt { public: - Token(NodeType Type, TextLoc StartLoc): Node(Type), StartLoc(StartLoc) {} + Token(NodeKind Type, TextLoc StartLoc): Node(Type), StartLoc(StartLoc) {} virtual std::string getText() const = 0; - void setParents() override; - - inline Token* 
getFirstToken() override { + inline Token* getFirstToken() override { return this; } @@ -178,319 +222,439 @@ namespace bolt { return getEndLoc().Column; } - ~Token(); - }; class Equals : public Token { public: - Equals(TextLoc StartLoc): - Token(NodeType::Equals, StartLoc) {} + inline Equals(TextLoc StartLoc): + Token(NodeKind::Equals, StartLoc) {} std::string getText() const override; - ~Equals(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Equals; + } }; class Colon : public Token { public: - Colon(TextLoc StartLoc): - Token(NodeType::Colon, StartLoc) {} + inline Colon(TextLoc StartLoc): + Token(NodeKind::Colon, StartLoc) {} std::string getText() const override; - ~Colon(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Colon; + } }; - class RArrow : public Token { + class Comma : public Token { public: - RArrow(TextLoc StartLoc): - Token(NodeType::RArrow, StartLoc) {} + inline Comma(TextLoc StartLoc): + Token(NodeKind::Comma, StartLoc) {} std::string getText() const override; - ~RArrow(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Comma; + } }; class Dot : public Token { public: - Dot(TextLoc StartLoc): - Token(NodeType::Dot, StartLoc) {} + inline Dot(TextLoc StartLoc): + Token(NodeKind::Dot, StartLoc) {} std::string getText() const override; - ~Dot(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Dot; + } }; class DotDot : public Token { public: - DotDot(TextLoc StartLoc): - Token(NodeType::DotDot, StartLoc) {} + inline DotDot(TextLoc StartLoc): + Token(NodeKind::DotDot, StartLoc) {} std::string getText() const override; - ~DotDot(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::DotDot; + } + + }; + + class Tilde : public Token { + public: + + inline Tilde(TextLoc StartLoc): + Token(NodeKind::Tilde, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == 
NodeKind::Tilde; + } }; class LParen : public Token { public: - LParen(TextLoc StartLoc): - Token(NodeType::LParen, StartLoc) {} + inline LParen(TextLoc StartLoc): + Token(NodeKind::LParen, StartLoc) {} std::string getText() const override; - ~LParen(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LParen; + } }; class RParen : public Token { public: - RParen(TextLoc StartLoc): - Token(NodeType::RParen, StartLoc) {} + inline RParen(TextLoc StartLoc): + Token(NodeKind::RParen, StartLoc) {} std::string getText() const override; - ~RParen(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RParen; + } }; class LBracket : public Token { public: - LBracket(TextLoc StartLoc): - Token(NodeType::LBracket, StartLoc) {} + inline LBracket(TextLoc StartLoc): + Token(NodeKind::LBracket, StartLoc) {} std::string getText() const override; - ~LBracket(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LBracket; + } }; class RBracket : public Token { public: - RBracket(TextLoc StartLoc): - Token(NodeType::RBracket, StartLoc) {} + inline RBracket(TextLoc StartLoc): + Token(NodeKind::RBracket, StartLoc) {} std::string getText() const override; - ~RBracket(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RBracket; + } }; class LBrace : public Token { public: - LBrace(TextLoc StartLoc): - Token(NodeType::LBrace, StartLoc) {} + inline LBrace(TextLoc StartLoc): + Token(NodeKind::LBrace, StartLoc) {} std::string getText() const override; - ~LBrace(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LBrace; + } }; class RBrace : public Token { public: - RBrace(TextLoc StartLoc): - Token(NodeType::RBrace, StartLoc) {} + inline RBrace(TextLoc StartLoc): + Token(NodeKind::RBrace, StartLoc) {} std::string getText() const override; - ~RBrace(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RBrace; + } + + }; + + class RArrow : public 
Token { + public: + + inline RArrow(TextLoc StartLoc): + Token(NodeKind::RArrow, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RArrow; + } + + }; + + class RArrowAlt : public Token { + public: + + inline RArrowAlt(TextLoc StartLoc): + Token(NodeKind::RArrowAlt, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RArrowAlt; + } }; class LetKeyword : public Token { public: - LetKeyword(TextLoc StartLoc): - Token(NodeType::LetKeyword, StartLoc) {} + inline LetKeyword(TextLoc StartLoc): + Token(NodeKind::LetKeyword, StartLoc) {} std::string getText() const override; - ~LetKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LetKeyword; + } }; class MutKeyword : public Token { public: - MutKeyword(TextLoc StartLoc): - Token(NodeType::MutKeyword, StartLoc) {} + inline MutKeyword(TextLoc StartLoc): + Token(NodeKind::MutKeyword, StartLoc) {} std::string getText() const override; - ~MutKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::MutKeyword; + } }; class PubKeyword : public Token { public: - PubKeyword(TextLoc StartLoc): - Token(NodeType::PubKeyword, StartLoc) {} + inline PubKeyword(TextLoc StartLoc): + Token(NodeKind::PubKeyword, StartLoc) {} std::string getText() const override; - ~PubKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::PubKeyword; + } }; class TypeKeyword : public Token { public: - TypeKeyword(TextLoc StartLoc): - Token(NodeType::TypeKeyword, StartLoc) {} + inline TypeKeyword(TextLoc StartLoc): + Token(NodeKind::TypeKeyword, StartLoc) {} std::string getText() const override; - ~TypeKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::TypeKeyword; + } }; class ReturnKeyword : public Token { public: - ReturnKeyword(TextLoc StartLoc): - 
Token(NodeType::ReturnKeyword, StartLoc) {} + inline ReturnKeyword(TextLoc StartLoc): + Token(NodeKind::ReturnKeyword, StartLoc) {} std::string getText() const override; - ~ReturnKeyword(); - - }; - - class ElseKeyword : public Token { - public: - - ElseKeyword(TextLoc StartLoc): - Token(NodeType::ElseKeyword, StartLoc) {} - - std::string getText() const override; - - ~ElseKeyword(); - - }; - - class ElifKeyword : public Token { - public: - - ElifKeyword(TextLoc StartLoc): - Token(NodeType::ElifKeyword, StartLoc) {} - - std::string getText() const override; - - ~ElifKeyword(); - - }; - - class IfKeyword : public Token { - public: - - IfKeyword(TextLoc StartLoc): - Token(NodeType::IfKeyword, StartLoc) {} - - std::string getText() const override; - - ~IfKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ReturnKeyword; + } }; class ModKeyword : public Token { public: - ModKeyword(TextLoc StartLoc): - Token(NodeType::ModKeyword, StartLoc) {} + inline ModKeyword(TextLoc StartLoc): + Token(NodeKind::ModKeyword, StartLoc) {} std::string getText() const override; - ~ModKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ModKeyword; + } }; class StructKeyword : public Token { public: - StructKeyword(TextLoc StartLoc): - Token(NodeType::StructKeyword, StartLoc) {} + inline StructKeyword(TextLoc StartLoc): + Token(NodeKind::StructKeyword, StartLoc) {} std::string getText() const override; - ~StructKeyword(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::StructKeyword; + } + + }; + + class ClassKeyword : public Token { + public: + + inline ClassKeyword(TextLoc StartLoc): + Token(NodeKind::ClassKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ClassKeyword; + } + + }; + + class InstanceKeyword : public Token { + public: + + inline InstanceKeyword(TextLoc StartLoc): + 
Token(NodeKind::InstanceKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::InstanceKeyword; + } + + }; + + class ElifKeyword : public Token { + public: + + inline ElifKeyword(TextLoc StartLoc): + Token(NodeKind::ElifKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ElifKeyword; + } + + }; + + class IfKeyword : public Token { + public: + + inline IfKeyword(TextLoc StartLoc): + Token(NodeKind::IfKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::IfKeyword; + } + + }; + + class ElseKeyword : public Token { + public: + + inline ElseKeyword(TextLoc StartLoc): + Token(NodeKind::ElseKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ElseKeyword; + } }; class Invalid : public Token { public: - Invalid(TextLoc StartLoc): - Token(NodeType::Invalid, StartLoc) {} + inline Invalid(TextLoc StartLoc): + Token(NodeKind::Invalid, StartLoc) {} std::string getText() const override; - ~Invalid(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Invalid; + } }; class EndOfFile : public Token { public: - EndOfFile(TextLoc StartLoc): - Token(NodeType::EndOfFile, StartLoc) {} + inline EndOfFile(TextLoc StartLoc): + Token(NodeKind::EndOfFile, StartLoc) {} std::string getText() const override; - ~EndOfFile(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::EndOfFile; + } }; class BlockStart : public Token { public: - BlockStart(TextLoc StartLoc): - Token(NodeType::BlockStart, StartLoc) {} + inline BlockStart(TextLoc StartLoc): + Token(NodeKind::BlockStart, StartLoc) {} std::string getText() const override; - ~BlockStart(); + static bool classof(const Node* N) { + return N->getKind() == 
NodeKind::BlockStart; + } }; class BlockEnd : public Token { public: - BlockEnd(TextLoc StartLoc): - Token(NodeType::BlockEnd, StartLoc) {} + inline BlockEnd(TextLoc StartLoc): + Token(NodeKind::BlockEnd, StartLoc) {} std::string getText() const override; - ~BlockEnd(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::BlockEnd; + } }; class LineFoldEnd : public Token { public: - LineFoldEnd(TextLoc StartLoc): - Token(NodeType::LineFoldEnd, StartLoc) {} + inline LineFoldEnd(TextLoc StartLoc): + Token(NodeKind::LineFoldEnd, StartLoc) {} std::string getText() const override; - ~LineFoldEnd(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LineFoldEnd; + } }; @@ -500,11 +664,13 @@ namespace bolt { ByteString Text; CustomOperator(ByteString Text, TextLoc StartLoc): - Token(NodeType::CustomOperator, StartLoc), Text(Text) {} + Token(NodeKind::CustomOperator, StartLoc), Text(Text) {} std::string getText() const override; - ~CustomOperator(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::CustomOperator; + } }; @@ -514,11 +680,13 @@ namespace bolt { ByteString Text; Assignment(ByteString Text, TextLoc StartLoc): - Token(NodeType::Assignment, StartLoc), Text(Text) {} + Token(NodeKind::Assignment, StartLoc), Text(Text) {} std::string getText() const override; - ~Assignment(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Assignment; + } }; @@ -528,11 +696,15 @@ namespace bolt { ByteString Text; Identifier(ByteString Text, TextLoc StartLoc): - Token(NodeType::Identifier, StartLoc), Text(Text) {} + Token(NodeKind::Identifier, StartLoc), Text(Text) {} std::string getText() const override; - ~Identifier(); + bool isTypeVar() const; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Identifier; + } }; @@ -542,11 +714,13 @@ namespace bolt { ByteString Text; StringLiteral(ByteString Text, TextLoc StartLoc): - Token(NodeType::StringLiteral, StartLoc), 
Text(Text) {} + Token(NodeKind::StringLiteral, StartLoc), Text(Text) {} std::string getText() const override; - ~StringLiteral(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::StringLiteral; + } }; @@ -556,14 +730,15 @@ namespace bolt { Integer Value; IntegerLiteral(Integer Value, TextLoc StartLoc): - Token(NodeType::IntegerLiteral, StartLoc), Value(Value) {} + Token(NodeKind::IntegerLiteral, StartLoc), Value(Value) {} std::string getText() const override; - ~IntegerLiteral(); + static bool classof(const Node* N) { + return N->getKind() == NodeKind::IntegerLiteral; + } }; - class QualifiedName : public Node { public: @@ -573,27 +748,125 @@ namespace bolt { QualifiedName( std::vector ModulePath, Identifier* Name - ): Node(NodeType::QualifiedName), + ): Node(NodeKind::QualifiedName), ModulePath(ModulePath), Name(Name) {} Token* getFirstToken() override; Token* getLastToken() override; - void setParents() override; - SymbolPath getSymbolPath() const; - ~QualifiedName(); - }; - class TypeExpression : public Node { + class TypedNode : public Node { + protected: + + Type* Ty; + + inline TypedNode(NodeKind Kind): + Node(Kind) {} + public: - TypeExpression(NodeType Type): Node(Type) {} + inline void setType(Type* Ty2) { + Ty = Ty2; + } - ~TypeExpression(); + inline Type* getType() const noexcept { + ZEN_ASSERT(Ty != nullptr); + return Ty; + } + + }; + + class TypeExpression : public TypedNode { + protected: + + TypeExpression(NodeKind Kind): + TypedNode(Kind) {} + + }; + + class ConstraintExpression : public Node { + public: + + inline ConstraintExpression(NodeKind Kind): + Node(Kind) {} + + }; + + class VarTypeExpression; + + class TypeclassConstraintExpression : public ConstraintExpression { + public: + + Identifier* Name; + std::vector TEs; + + TypeclassConstraintExpression( + Identifier* Name, + std::vector TEs + ): ConstraintExpression(NodeKind::TypeclassConstraintExpression), + Name(Name), + TEs(TEs) {} + + Token* getFirstToken() override; 
+ Token* getLastToken() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::TypeclassConstraintExpression; + } + + }; + + class EqualityConstraintExpression : public ConstraintExpression { + public: + + TypeExpression* Left; + class Tilde* Tilde; + TypeExpression* Right; + + inline EqualityConstraintExpression( + TypeExpression* Left, + class Tilde* Tilde, + TypeExpression* Right + ): ConstraintExpression(NodeKind::EqualityConstraintExpression), + Left(Left), + Tilde(Tilde), + Right(Right) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::EqualityConstraintExpression; + } + + }; + + class QualifiedTypeExpression : public TypeExpression { + public: + + std::vector> Constraints; + class RArrowAlt* RArrowAlt; + TypeExpression* TE; + + QualifiedTypeExpression( + std::vector> Constraints, + class RArrowAlt* RArrowAlt, + TypeExpression* TE + ): TypeExpression(NodeKind::QualifiedTypeExpression), + Constraints(Constraints), + RArrowAlt(RArrowAlt), + TE(TE) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::QualifiedTypeExpression; + } }; @@ -604,16 +877,12 @@ namespace bolt { ReferenceTypeExpression( QualifiedName* Name - ): TypeExpression(NodeType::ReferenceTypeExpression), + ): TypeExpression(NodeKind::ReferenceTypeExpression), Name(Name) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~ReferenceTypeExpression(); - }; class ArrowTypeExpression : public TypeExpression { @@ -625,25 +894,33 @@ namespace bolt { inline ArrowTypeExpression( std::vector ParamTypes, TypeExpression* ReturnType - ): TypeExpression(NodeType::ArrowTypeExpression), + ): TypeExpression(NodeKind::ArrowTypeExpression), ParamTypes(ParamTypes), ReturnType(ReturnType) {} - void setParents() override; - Token* getFirstToken() 
override; Token* getLastToken() override; - ~ArrowTypeExpression(); + }; + + class VarTypeExpression : public TypeExpression { + public: + + Identifier* Name; + + inline VarTypeExpression(Identifier* Name): + TypeExpression(NodeKind::VarTypeExpression), Name(Name) {} + + Token* getFirstToken() override; + Token* getLastToken() override; }; class Pattern : public Node { - public: + protected: - Pattern(NodeType Type): Node(Type) {} - - ~Pattern(); + inline Pattern(NodeKind Type): + Node(Type) {} }; @@ -654,24 +931,19 @@ namespace bolt { BindPattern( Identifier* Name - ): Pattern(NodeType::BindPattern), + ): Pattern(NodeKind::BindPattern), Name(Name) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~BindPattern(); - }; - class Expression : public Node { - public: + class Expression : public TypedNode { + protected: - Expression(NodeType Type): Node(Type) {} - - ~Expression(); + inline Expression(NodeKind Kind): + TypedNode(Kind) {} }; @@ -682,60 +954,48 @@ namespace bolt { ReferenceExpression( QualifiedName* Name - ): Expression(NodeType::ReferenceExpression), + ): Expression(NodeKind::ReferenceExpression), Name(Name) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~ReferenceExpression(); - }; class NestedExpression : public Expression { public: - LParen* LParen; + class LParen* LParen; Expression* Inner; - RParen* RParen; + class RParen* RParen; inline NestedExpression( class LParen* LParen, Expression* Inner, class RParen* RParen - ): Expression(NodeType::NestedExpression), + ): Expression(NodeKind::NestedExpression), LParen(LParen), Inner(Inner), RParen(RParen) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~NestedExpression(); - }; class ConstantExpression : public Expression { public: - Token* Token; + class Token* Token; ConstantExpression( class Token* Token - ): 
Expression(NodeType::ConstantExpression), + ): Expression(NodeKind::ConstantExpression), Token(Token) {} - void setParents() override; - - class Token* getFirstToken() override; + class Token* getFirstToken() override; class Token* getLastToken() override; - ~ConstantExpression(); - }; class CallExpression : public Expression { @@ -747,17 +1007,13 @@ namespace bolt { CallExpression( Expression* Function, std::vector Args - ): Expression(NodeType::CallExpression), + ): Expression(NodeKind::CallExpression), Function(Function), Args(Args) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~CallExpression(); - }; class InfixExpression : public Expression { @@ -768,66 +1024,53 @@ namespace bolt { Expression* RHS; InfixExpression(Expression* LHS, Token* Operator, Expression* RHS): - Expression(NodeType::InfixExpression), + Expression(NodeKind::InfixExpression), LHS(LHS), Operator(Operator), RHS(RHS) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~InfixExpression(); - }; - class UnaryExpression : public Expression { + class PrefixExpression : public Expression { public: Token* Operator; Expression* Argument; - UnaryExpression( + PrefixExpression( Token* Operator, Expression* Argument - ): Expression(NodeType::UnaryExpression), + ): Expression(NodeKind::PrefixExpression), Operator(Operator), Argument(Argument) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~UnaryExpression(); - }; class Statement : public Node { - public: + protected: - Statement(NodeType Type): Node(Type) {} - - ~Statement(); + inline Statement(NodeKind Type): + Node(Type) {} }; class ExpressionStatement : public Statement { public: - Expression* Expression; + class Expression* Expression; ExpressionStatement(class Expression* Expression): - Statement(NodeType::ExpressionStatement), Expression(Expression) {} - - void setParents() override; + 
Statement(NodeKind::ExpressionStatement), Expression(Expression) {} Token* getFirstToken() override; Token* getLastToken() override; - ~ExpressionStatement(); - }; class IfStatementPart : public Node { @@ -835,7 +1078,7 @@ namespace bolt { Token* Keyword; Expression* Test; - BlockStart* BlockStart; + class BlockStart* BlockStart; std::vector Elements; inline IfStatementPart( @@ -843,19 +1086,15 @@ namespace bolt { Expression* Test, class BlockStart* BlockStart, std::vector Elements - ): Node(NodeType::IfStatementPart), + ): Node(NodeKind::IfStatementPart), Keyword(Keyword), Test(Test), BlockStart(BlockStart), Elements(Elements) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~IfStatementPart(); - }; class IfStatement : public Statement { @@ -864,129 +1103,108 @@ namespace bolt { std::vector Parts; inline IfStatement(std::vector Parts): - Statement(NodeType::IfStatement), Parts(Parts) {} - - void setParents() override; + Statement(NodeKind::IfStatement), Parts(Parts) {} Token* getFirstToken() override; Token* getLastToken() override; - ~IfStatement(); - }; class ReturnStatement : public Statement { public: - ReturnKeyword* ReturnKeyword; - Expression* Expression; + class ReturnKeyword* ReturnKeyword; + class Expression* Expression; ReturnStatement( class ReturnKeyword* ReturnKeyword, class Expression* Expression - ): Statement(NodeType::ReturnStatement), + ): Statement(NodeKind::ReturnStatement), ReturnKeyword(ReturnKeyword), Expression(Expression) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~ReturnStatement(); - }; class TypeAssert : public Node { public: - Colon* Colon; - TypeExpression* TypeExpression; + class Colon* Colon; + class TypeExpression* TypeExpression; TypeAssert( class Colon* Colon, class TypeExpression* TypeExpression - ): Node(NodeType::TypeAssert), + ): Node(NodeKind::TypeAssert), Colon(Colon), TypeExpression(TypeExpression) {} - void 
setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~TypeAssert(); - }; - class Param : public Node { + class Parameter : public Node { public: - Param(Pattern* Pattern, TypeAssert* TypeAssert): Node(NodeType::Param), Pattern(Pattern), TypeAssert(TypeAssert) {} + Parameter( + class Pattern* Pattern, + class TypeAssert* TypeAssert + ): Node(NodeKind::Parameter), + Pattern(Pattern), + TypeAssert(TypeAssert) {} - Pattern* Pattern; - TypeAssert* TypeAssert; - - void setParents() override; + class Pattern* Pattern; + class TypeAssert* TypeAssert; Token* getFirstToken() override; Token* getLastToken() override; - ~Param(); - }; class LetBody : public Node { public: - LetBody(NodeType Type): Node(Type) {} - - ~LetBody(); + LetBody(NodeKind Type): Node(Type) {} }; class LetBlockBody : public LetBody { public: - BlockStart* BlockStart; + class BlockStart* BlockStart; std::vector Elements; LetBlockBody( class BlockStart* BlockStart, std::vector Elements - ): LetBody(NodeType::LetBlockBody), + ): LetBody(NodeKind::LetBlockBody), BlockStart(BlockStart), Elements(Elements) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~LetBlockBody(); - }; class LetExprBody : public LetBody { public: - Equals* Equals; - Expression* Expression; + class Equals* Equals; + class Expression* Expression; LetExprBody( class Equals* Equals, class Expression* Expression - ): LetBody(NodeType::LetExprBody), + ): LetBody(NodeKind::LetExprBody), Equals(Equals), Expression(Expression) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~LetExprBody(); - }; class Type; @@ -1001,12 +1219,12 @@ namespace bolt { InferContext* Ctx; class Type* Ty; - PubKeyword* PubKeyword; - LetKeyword* LetKeyword; - MutKeyword* MutKeyword; - Pattern* Pattern; - std::vector Params; - TypeAssert* TypeAssert; + class PubKeyword* PubKeyword; + class LetKeyword* LetKeyword; + class 
MutKeyword* MutKeyword; + class Pattern* Pattern; + std::vector Params; + class TypeAssert* TypeAssert; LetBody* Body; LetDeclaration( @@ -1014,10 +1232,10 @@ namespace bolt { class LetKeyword* LetKeywod, class MutKeyword* MutKeyword, class Pattern* Pattern, - std::vector Params, + std::vector Params, class TypeAssert* TypeAssert, LetBody* Body - ): Node(NodeType::LetDeclaration), + ): Node(NodeKind::LetDeclaration), PubKeyword(PubKeyword), LetKeyword(LetKeywod), MutKeyword(MutKeyword), @@ -1033,69 +1251,122 @@ namespace bolt { return TheScope; } - void setParents() override; + Token* getFirstToken() override; + Token* getLastToken() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LetDeclaration; + } + + }; + + class InstanceDeclaration : public Node { + public: + + class InstanceKeyword* InstanceKeyword; + Identifier* Name; + std::vector TypeExps; + class BlockStart* BlockStart; + std::vector Elements; + + InstanceDeclaration( + class InstanceKeyword* InstanceKeyword, + Identifier* Name, + std::vector TypeExps, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeKind::InstanceDeclaration), + InstanceKeyword(InstanceKeyword), + Name(Name), + TypeExps(TypeExps), + BlockStart(BlockStart), + Elements(Elements) {} Token* getFirstToken() override; Token* getLastToken() override; - ~LetDeclaration(); + }; + + class ClassDeclaration : public Node { + public: + + class PubKeyword* PubKeyword; + class ClassKeyword* ClassKeyword; + Identifier* Name; + std::vector TypeVars; + class BlockStart* BlockStart; + std::vector Elements; + + ClassDeclaration( + class PubKeyword* PubKeyword, + class ClassKeyword* ClassKeyword, + Identifier* Name, + std::vector TypeVars, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeKind::ClassDeclaration), + PubKeyword(PubKeyword), + ClassKeyword(ClassKeyword), + Name(Name), + TypeVars(TypeVars), + BlockStart(BlockStart), + Elements(Elements) {} + + Token* getFirstToken() 
override; + Token* getLastToken() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ClassDeclaration; + } }; - class StructDeclField : public Node { + class StructDeclarationField : public Node { public: - StructDeclField( + StructDeclarationField( Identifier* Name, - Colon* Colon, - TypeExpression* TypeExpression - ): Node(NodeType::StructDeclField), + class Colon* Colon, + class TypeExpression* TypeExpression + ): Node(NodeKind::StructDeclarationField), Name(Name), Colon(Colon), TypeExpression(TypeExpression) {} Identifier* Name; - Colon* Colon; - TypeExpression* TypeExpression; - - void setParents() override; + class Colon* Colon; + class TypeExpression* TypeExpression; Token* getFirstToken() override; Token* getLastToken() override; - ~StructDeclField(); - }; - class StructDecl : public Node { + class StructDeclaration : public Node { public: - PubKeyword* PubKeyword; - StructKeyword* StructKeyword; + class PubKeyword* PubKeyword; + class StructKeyword* StructKeyword; Identifier* Name; - BlockStart* BlockStart; - std::vector Fields; + class BlockStart* BlockStart; + std::vector Fields; - StructDecl( + StructDeclaration( class PubKeyword* PubKeyword, class StructKeyword* StructKeyword, Identifier* Name, class BlockStart* BlockStart, - std::vector Fields - ): Node(NodeType::StructDecl), + std::vector Fields + ): Node(NodeKind::StructDeclaration), PubKeyword(PubKeyword), StructKeyword(StructKeyword), Name(Name), BlockStart(BlockStart), Fields(Fields) {} - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; - ~StructDecl(); - }; class SourceFile : public Node { @@ -1103,20 +1374,18 @@ namespace bolt { Scope* TheScope = nullptr; public: - + TextFile& File; std::vector Elements; SourceFile(TextFile& File, std::vector Elements): - Node(NodeType::SourceFile), File(File), Elements(Elements) {} + Node(NodeKind::SourceFile), File(File), Elements(Elements) {} inline TextFile& getTextFile() { 
return File; } - void setParents() override; - Token* getFirstToken() override; Token* getLastToken() override; @@ -1127,10 +1396,69 @@ namespace bolt { return TheScope; } - ~SourceFile(); - }; + template<> inline NodeKind getNodeType() { return NodeKind::Equals; } + template<> inline NodeKind getNodeType() { return NodeKind::Colon; } + template<> inline NodeKind getNodeType() { return NodeKind::Dot; } + template<> inline NodeKind getNodeType() { return NodeKind::DotDot; } + template<> inline NodeKind getNodeType() { return NodeKind::Tilde; } + template<> inline NodeKind getNodeType() { return NodeKind::LParen; } + template<> inline NodeKind getNodeType() { return NodeKind::RParen; } + template<> inline NodeKind getNodeType() { return NodeKind::LBracket; } + template<> inline NodeKind getNodeType() { return NodeKind::RBracket; } + template<> inline NodeKind getNodeType() { return NodeKind::LBrace; } + template<> inline NodeKind getNodeType() { return NodeKind::RBrace; } + template<> inline NodeKind getNodeType() { return NodeKind::RArrow; } + template<> inline NodeKind getNodeType() { return NodeKind::RArrowAlt; } + template<> inline NodeKind getNodeType() { return NodeKind::LetKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::MutKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::PubKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::TypeKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::ReturnKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::ModKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::StructKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::ClassKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::InstanceKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::ElifKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::IfKeyword; } + template<> 
inline NodeKind getNodeType() { return NodeKind::ElseKeyword; } + template<> inline NodeKind getNodeType() { return NodeKind::Invalid; } + template<> inline NodeKind getNodeType() { return NodeKind::EndOfFile; } + template<> inline NodeKind getNodeType() { return NodeKind::BlockStart; } + template<> inline NodeKind getNodeType() { return NodeKind::BlockEnd; } + template<> inline NodeKind getNodeType() { return NodeKind::LineFoldEnd; } + template<> inline NodeKind getNodeType() { return NodeKind::CustomOperator; } + template<> inline NodeKind getNodeType() { return NodeKind::Assignment; } + template<> inline NodeKind getNodeType() { return NodeKind::Identifier; } + template<> inline NodeKind getNodeType() { return NodeKind::StringLiteral; } + template<> inline NodeKind getNodeType() { return NodeKind::IntegerLiteral; } + template<> inline NodeKind getNodeType() { return NodeKind::QualifiedName; } + template<> inline NodeKind getNodeType() { return NodeKind::QualifiedTypeExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::ReferenceTypeExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::ArrowTypeExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::BindPattern; } + template<> inline NodeKind getNodeType() { return NodeKind::ReferenceExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::NestedExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::ConstantExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::CallExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::InfixExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::PrefixExpression; } + template<> inline NodeKind getNodeType() { return NodeKind::ExpressionStatement; } + template<> inline NodeKind getNodeType() { return NodeKind::ReturnStatement; } + template<> inline NodeKind getNodeType() { return NodeKind::IfStatement; } + template<> 
inline NodeKind getNodeType() { return NodeKind::IfStatementPart; } + template<> inline NodeKind getNodeType() { return NodeKind::TypeAssert; } + template<> inline NodeKind getNodeType() { return NodeKind::Parameter; } + template<> inline NodeKind getNodeType() { return NodeKind::LetBlockBody; } + template<> inline NodeKind getNodeType() { return NodeKind::LetExprBody; } + template<> inline NodeKind getNodeType() { return NodeKind::LetDeclaration; } + template<> inline NodeKind getNodeType() { return NodeKind::StructDeclarationField; } + template<> inline NodeKind getNodeType() { return NodeKind::StructDeclaration; } + template<> inline NodeKind getNodeType() { return NodeKind::ClassDeclaration; } + template<> inline NodeKind getNodeType() { return NodeKind::InstanceDeclaration; } + template<> inline NodeKind getNodeType() { return NodeKind::SourceFile; } + } #endif diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp new file mode 100644 index 000000000..a2b2c572b --- /dev/null +++ b/include/bolt/CSTVisitor.hpp @@ -0,0 +1,951 @@ + +#pragma once + +#include "bolt/CST.hpp" + +namespace bolt { + + template + class CSTVisitor { + public: + + void visit(Node* N) { + switch (N->getKind()) { + case NodeKind::Equals: + return static_cast(this)->visitEquals(static_cast(N)); + case NodeKind::Colon: + return static_cast(this)->visitColon(static_cast(N)); + case NodeKind::Comma: + return static_cast(this)->visitComma(static_cast(N)); + case NodeKind::Dot: + return static_cast(this)->visitDot(static_cast(N)); + case NodeKind::DotDot: + return static_cast(this)->visitDotDot(static_cast(N)); + case NodeKind::Tilde: + return static_cast(this)->visitTilde(static_cast(N)); + case NodeKind::LParen: + return static_cast(this)->visitLParen(static_cast(N)); + case NodeKind::RParen: + return static_cast(this)->visitRParen(static_cast(N)); + case NodeKind::LBracket: + return static_cast(this)->visitLBracket(static_cast(N)); + case NodeKind::RBracket: + return 
static_cast(this)->visitRBracket(static_cast(N)); + case NodeKind::LBrace: + return static_cast(this)->visitLBrace(static_cast(N)); + case NodeKind::RBrace: + return static_cast(this)->visitRBrace(static_cast(N)); + case NodeKind::RArrow: + return static_cast(this)->visitRArrow(static_cast(N)); + case NodeKind::RArrowAlt: + return static_cast(this)->visitRArrowAlt(static_cast(N)); + case NodeKind::LetKeyword: + return static_cast(this)->visitLetKeyword(static_cast(N)); + case NodeKind::MutKeyword: + return static_cast(this)->visitMutKeyword(static_cast(N)); + case NodeKind::PubKeyword: + return static_cast(this)->visitPubKeyword(static_cast(N)); + case NodeKind::TypeKeyword: + return static_cast(this)->visitTypeKeyword(static_cast(N)); + case NodeKind::ReturnKeyword: + return static_cast(this)->visitReturnKeyword(static_cast(N)); + case NodeKind::ModKeyword: + return static_cast(this)->visitModKeyword(static_cast(N)); + case NodeKind::StructKeyword: + return static_cast(this)->visitStructKeyword(static_cast(N)); + case NodeKind::ClassKeyword: + return static_cast(this)->visitClassKeyword(static_cast(N)); + case NodeKind::InstanceKeyword: + return static_cast(this)->visitInstanceKeyword(static_cast(N)); + case NodeKind::ElifKeyword: + return static_cast(this)->visitElifKeyword(static_cast(N)); + case NodeKind::IfKeyword: + return static_cast(this)->visitIfKeyword(static_cast(N)); + case NodeKind::ElseKeyword: + return static_cast(this)->visitElseKeyword(static_cast(N)); + case NodeKind::Invalid: + return static_cast(this)->visitInvalid(static_cast(N)); + case NodeKind::EndOfFile: + return static_cast(this)->visitEndOfFile(static_cast(N)); + case NodeKind::BlockStart: + return static_cast(this)->visitBlockStart(static_cast(N)); + case NodeKind::BlockEnd: + return static_cast(this)->visitBlockEnd(static_cast(N)); + case NodeKind::LineFoldEnd: + return static_cast(this)->visitLineFoldEnd(static_cast(N)); + case NodeKind::CustomOperator: + return 
static_cast(this)->visitCustomOperator(static_cast(N)); + case NodeKind::Assignment: + return static_cast(this)->visitAssignment(static_cast(N)); + case NodeKind::Identifier: + return static_cast(this)->visitIdentifier(static_cast(N)); + case NodeKind::StringLiteral: + return static_cast(this)->visitStringLiteral(static_cast(N)); + case NodeKind::IntegerLiteral: + return static_cast(this)->visitIntegerLiteral(static_cast(N)); + case NodeKind::QualifiedName: + return static_cast(this)->visitQualifiedName(static_cast(N)); + case NodeKind::TypeclassConstraintExpression: + return static_cast(this)->visitTypeclassConstraintExpression(static_cast(N)); + case NodeKind::EqualityConstraintExpression: + return static_cast(this)->visitEqualityConstraintExpression(static_cast(N)); + case NodeKind::QualifiedTypeExpression: + return static_cast(this)->visitQualifiedTypeExpression(static_cast(N)); + case NodeKind::ReferenceTypeExpression: + return static_cast(this)->visitReferenceTypeExpression(static_cast(N)); + case NodeKind::ArrowTypeExpression: + return static_cast(this)->visitArrowTypeExpression(static_cast(N)); + case NodeKind::VarTypeExpression: + return static_cast(this)->visitVarTypeExpression(static_cast(N)); + case NodeKind::BindPattern: + return static_cast(this)->visitBindPattern(static_cast(N)); + case NodeKind::ReferenceExpression: + return static_cast(this)->visitReferenceExpression(static_cast(N)); + case NodeKind::NestedExpression: + return static_cast(this)->visitNestedExpression(static_cast(N)); + case NodeKind::ConstantExpression: + return static_cast(this)->visitConstantExpression(static_cast(N)); + case NodeKind::CallExpression: + return static_cast(this)->visitCallExpression(static_cast(N)); + case NodeKind::InfixExpression: + return static_cast(this)->visitInfixExpression(static_cast(N)); + case NodeKind::PrefixExpression: + return static_cast(this)->visitPrefixExpression(static_cast(N)); + case NodeKind::ExpressionStatement: + return 
static_cast(this)->visitExpressionStatement(static_cast(N)); + case NodeKind::ReturnStatement: + return static_cast(this)->visitReturnStatement(static_cast(N)); + case NodeKind::IfStatement: + return static_cast(this)->visitIfStatement(static_cast(N)); + case NodeKind::IfStatementPart: + return static_cast(this)->visitIfStatementPart(static_cast(N)); + case NodeKind::TypeAssert: + return static_cast(this)->visitTypeAssert(static_cast(N)); + case NodeKind::Parameter: + return static_cast(this)->visitParameter(static_cast(N)); + case NodeKind::LetBlockBody: + return static_cast(this)->visitLetBlockBody(static_cast(N)); + case NodeKind::LetExprBody: + return static_cast(this)->visitLetExprBody(static_cast(N)); + case NodeKind::LetDeclaration: + return static_cast(this)->visitLetDeclaration(static_cast(N)); + case NodeKind::StructDeclarationField: + return static_cast(this)->visitStructDeclarationField(static_cast(N)); + case NodeKind::StructDeclaration: + return static_cast(this)->visitStructDeclaration(static_cast(N)); + case NodeKind::ClassDeclaration: + return static_cast(this)->visitClassDeclaration(static_cast(N)); + case NodeKind::InstanceDeclaration: + return static_cast(this)->visitInstanceDeclaration(static_cast(N)); + case NodeKind::SourceFile: + return static_cast(this)->visitSourceFile(static_cast(N)); + } + } + + protected: + + void visitNode(Node* N) { + visitEachChild(N); + } + + void visitToken(Token* N) { + visitNode(N); + } + + void visitEquals(Equals* N) { + visitToken(N); + } + + void visitColon(Colon* N) { + visitToken(N); + } + + void visitComma(Comma* N) { + visitToken(N); + } + + void visitDot(Dot* N) { + visitToken(N); + } + + void visitDotDot(DotDot* N) { + visitToken(N); + } + + void visitTilde(Tilde* N) { + visitToken(N); + } + + void visitLParen(LParen* N) { + visitToken(N); + } + + void visitRParen(RParen* N) { + visitToken(N); + } + + void visitLBracket(LBracket* N) { + visitToken(N); + } + + void visitRBracket(RBracket* N) { + 
visitToken(N); + } + + void visitLBrace(LBrace* N) { + visitToken(N); + } + + void visitRBrace(RBrace* N) { + visitToken(N); + } + + void visitRArrow(RArrow* N) { + visitToken(N); + } + + void visitRArrowAlt(RArrowAlt* N) { + visitToken(N); + } + + void visitLetKeyword(LetKeyword* N) { + visitToken(N); + } + + void visitMutKeyword(MutKeyword* N) { + visitToken(N); + } + + void visitPubKeyword(PubKeyword* N) { + visitToken(N); + } + + void visitTypeKeyword(TypeKeyword* N) { + visitToken(N); + } + + void visitReturnKeyword(ReturnKeyword* N) { + visitToken(N); + } + + void visitModKeyword(ModKeyword* N) { + visitToken(N); + } + + void visitStructKeyword(StructKeyword* N) { + visitToken(N); + } + + void visitClassKeyword(ClassKeyword* N) { + visitToken(N); + } + + void visitInstanceKeyword(InstanceKeyword* N) { + visitToken(N); + } + + void visitElifKeyword(ElifKeyword* N) { + visitToken(N); + } + + void visitIfKeyword(IfKeyword* N) { + visitToken(N); + } + + void visitElseKeyword(ElseKeyword* N) { + visitToken(N); + } + + void visitInvalid(Invalid* N) { + visitToken(N); + } + + void visitEndOfFile(EndOfFile* N) { + visitToken(N); + } + + void visitBlockStart(BlockStart* N) { + visitToken(N); + } + + void visitBlockEnd(BlockEnd* N) { + visitToken(N); + } + + void visitLineFoldEnd(LineFoldEnd* N) { + visitToken(N); + } + + void visitCustomOperator(CustomOperator* N) { + visitToken(N); + } + + void visitAssignment(Assignment* N) { + visitToken(N); + } + + void visitIdentifier(Identifier* N) { + visitToken(N); + } + + void visitStringLiteral(StringLiteral* N) { + visitToken(N); + } + + void visitIntegerLiteral(IntegerLiteral* N) { + visitToken(N); + } + + void visitQualifiedName(QualifiedName* N) { + visitNode(N); + } + + void visitConstraintExpression(ConstraintExpression* N) { + visitNode(N); + } + + void visitTypeclassConstraintExpression(TypeclassConstraintExpression* N) { + visitConstraintExpression(N); + } + + void 
visitEqualityConstraintExpression(EqualityConstraintExpression* N) { + visitConstraintExpression(N); + } + + void visitTypeExpression(TypeExpression* N) { + visitNode(N); + } + + void visitQualifiedTypeExpression(QualifiedTypeExpression* N) { + visitTypeExpression(N); + } + + void visitReferenceTypeExpression(ReferenceTypeExpression* N) { + visitTypeExpression(N); + } + + void visitArrowTypeExpression(ArrowTypeExpression* N) { + visitTypeExpression(N); + } + + void visitVarTypeExpression(VarTypeExpression* N) { + visitTypeExpression(N); + } + + void visitPattern(Pattern* N) { + visitNode(N); + } + + void visitBindPattern(BindPattern* N) { + visitPattern(N); + } + + void visitExpression(Expression* N) { + visitNode(N); + } + + void visitReferenceExpression(ReferenceExpression* N) { + visitExpression(N); + } + + void visitNestedExpression(NestedExpression* N) { + visitExpression(N); + } + + void visitConstantExpression(ConstantExpression* N) { + visitExpression(N); + } + + void visitCallExpression(CallExpression* N) { + visitExpression(N); + } + + void visitInfixExpression(InfixExpression* N) { + visitExpression(N); + } + + void visitPrefixExpression(PrefixExpression* N) { + visitExpression(N); + } + + void visitStatement(Statement* N) { + visitNode(N); + } + + void visitExpressionStatement(ExpressionStatement* N) { + visitStatement(N); + } + + void visitReturnStatement(ReturnStatement* N) { + visitStatement(N); + } + + void visitIfStatement(IfStatement* N) { + visitStatement(N); + } + + void visitIfStatementPart(IfStatementPart* N) { + visitNode(N); + } + + void visitTypeAssert(TypeAssert* N) { + visitNode(N); + } + + void visitParameter(Parameter* N) { + visitNode(N); + } + + void visitLetBody(LetBody* N) { + visitNode(N); + } + + void visitLetBlockBody(LetBlockBody* N) { + visitLetBody(N); + } + + void visitLetExprBody(LetExprBody* N) { + visitLetBody(N); + } + + void visitLetDeclaration(LetDeclaration* N) { + visitNode(N); + } + + void 
visitStructDeclarationField(StructDeclarationField* N) { + visitNode(N); + } + + void visitStructDeclaration(StructDeclaration* N) { + visitNode(N); + } + + void visitClassDeclaration(ClassDeclaration* N) { + visitNode(N); + } + + void visitInstanceDeclaration(InstanceDeclaration* N) { + visitNode(N); + } + + void visitSourceFile(SourceFile* N) { + visitNode(N); + } + + public: + + void visitEachChild(Node* N) { + switch (N->getKind()) { + case NodeKind::Equals: + visitEachChild(static_cast(N)); + break; + case NodeKind::Colon: + visitEachChild(static_cast(N)); + break; + case NodeKind::Comma: + visitEachChild(static_cast(N)); + break; + case NodeKind::Dot: + visitEachChild(static_cast(N)); + break; + case NodeKind::DotDot: + visitEachChild(static_cast(N)); + break; + case NodeKind::Tilde: + visitEachChild(static_cast(N)); + break; + case NodeKind::LParen: + visitEachChild(static_cast(N)); + break; + case NodeKind::RParen: + visitEachChild(static_cast(N)); + break; + case NodeKind::LBracket: + visitEachChild(static_cast(N)); + break; + case NodeKind::RBracket: + visitEachChild(static_cast(N)); + break; + case NodeKind::LBrace: + visitEachChild(static_cast(N)); + break; + case NodeKind::RBrace: + visitEachChild(static_cast(N)); + break; + case NodeKind::RArrow: + visitEachChild(static_cast(N)); + break; + case NodeKind::RArrowAlt: + visitEachChild(static_cast(N)); + break; + case NodeKind::LetKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::MutKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::PubKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::TypeKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::ReturnKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::ModKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::StructKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::ClassKeyword: + visitEachChild(static_cast(N)); + break; + case 
NodeKind::InstanceKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::ElifKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::IfKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::ElseKeyword: + visitEachChild(static_cast(N)); + break; + case NodeKind::Invalid: + visitEachChild(static_cast(N)); + break; + case NodeKind::EndOfFile: + visitEachChild(static_cast(N)); + break; + case NodeKind::BlockStart: + visitEachChild(static_cast(N)); + break; + case NodeKind::BlockEnd: + visitEachChild(static_cast(N)); + break; + case NodeKind::LineFoldEnd: + visitEachChild(static_cast(N)); + break; + case NodeKind::CustomOperator: + visitEachChild(static_cast(N)); + break; + case NodeKind::Assignment: + visitEachChild(static_cast(N)); + break; + case NodeKind::Identifier: + visitEachChild(static_cast(N)); + break; + case NodeKind::StringLiteral: + visitEachChild(static_cast(N)); + break; + case NodeKind::IntegerLiteral: + visitEachChild(static_cast(N)); + break; + case NodeKind::QualifiedName: + visitEachChild(static_cast(N)); + break; + case NodeKind::TypeclassConstraintExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::EqualityConstraintExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::QualifiedTypeExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::ReferenceTypeExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::ArrowTypeExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::VarTypeExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::BindPattern: + visitEachChild(static_cast(N)); + break; + case NodeKind::ReferenceExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::NestedExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::ConstantExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::CallExpression: + visitEachChild(static_cast(N)); + break; + case 
NodeKind::InfixExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::PrefixExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::ExpressionStatement: + visitEachChild(static_cast(N)); + break; + case NodeKind::ReturnStatement: + visitEachChild(static_cast(N)); + break; + case NodeKind::IfStatement: + visitEachChild(static_cast(N)); + break; + case NodeKind::IfStatementPart: + visitEachChild(static_cast(N)); + break; + case NodeKind::TypeAssert: + visitEachChild(static_cast(N)); + break; + case NodeKind::Parameter: + visitEachChild(static_cast(N)); + break; + case NodeKind::LetBlockBody: + visitEachChild(static_cast(N)); + break; + case NodeKind::LetExprBody: + visitEachChild(static_cast(N)); + break; + case NodeKind::LetDeclaration: + visitEachChild(static_cast(N)); + break; + case NodeKind::StructDeclaration: + visitEachChild(static_cast(N)); + break; + case NodeKind::StructDeclarationField: + visitEachChild(static_cast(N)); + break; + case NodeKind::ClassDeclaration: + visitEachChild(static_cast(N)); + break; + case NodeKind::InstanceDeclaration: + visitEachChild(static_cast(N)); + break; + case NodeKind::SourceFile: + visitEachChild(static_cast(N)); + break; + default: + ZEN_UNREACHABLE + } + } + +#define BOLT_VISIT(node) static_cast(this)->visit(node) + + void visitEachChild(Equals* N) { + } + + void visitEachChild(Colon* N) { + } + + void visitEachChild(Comma* N) { + } + + void visitEachChild(Dot* N) { + } + + void visitEachChild(DotDot* N) { + } + + void visitEachChild(Tilde* N) { + } + + void visitEachChild(LParen* N) { + } + + void visitEachChild(RParen* N) { + } + + void visitEachChild(LBracket* N) { + } + + void visitEachChild(RBracket* N) { + } + + void visitEachChild(LBrace* N) { + } + + void visitEachChild(RBrace* N) { + } + + void visitEachChild(RArrow* N) { + } + + void visitEachChild(RArrowAlt* N) { + } + + void visitEachChild(LetKeyword* N) { + } + + void visitEachChild(MutKeyword* N) { + } + + void 
visitEachChild(PubKeyword* N) { + } + + void visitEachChild(TypeKeyword* N) { + } + + void visitEachChild(ReturnKeyword* N) { + } + + void visitEachChild(ModKeyword* N) { + } + + void visitEachChild(StructKeyword* N) { + } + + void visitEachChild(ClassKeyword* N) { + } + + void visitEachChild(InstanceKeyword* N) { + } + + void visitEachChild(ElifKeyword* N) { + } + + void visitEachChild(IfKeyword* N) { + } + + void visitEachChild(ElseKeyword* N) { + } + + void visitEachChild(Invalid* N) { + } + + void visitEachChild(EndOfFile* N) { + } + + void visitEachChild(BlockStart* N) { + } + + void visitEachChild(BlockEnd* N) { + } + + void visitEachChild(LineFoldEnd* N) { + } + + void visitEachChild(CustomOperator* N) { + } + + void visitEachChild(Assignment* N) { + } + + void visitEachChild(Identifier* N) { + } + + void visitEachChild(StringLiteral* N) { + } + + void visitEachChild(IntegerLiteral* N) { + } + + void visitEachChild(QualifiedName* N) { + for (auto Name: N->ModulePath) { + BOLT_VISIT(Name); + } + BOLT_VISIT(N->Name); + } + + void visitEachChild(TypeclassConstraintExpression* N) { + BOLT_VISIT(N->Name); + for (auto TE: N->TEs) { + BOLT_VISIT(TE); + } + } + + void visitEachChild(EqualityConstraintExpression* N) { + BOLT_VISIT(N->Left); + BOLT_VISIT(N->Tilde); + BOLT_VISIT(N->Right); + } + + void visitEachChild(QualifiedTypeExpression* N) { + for (auto [CE, Comma]: N->Constraints) { + BOLT_VISIT(CE); + if (Comma) { + BOLT_VISIT(Comma); + } + } + BOLT_VISIT(N->RArrowAlt); + BOLT_VISIT(N->TE); + } + + void visitEachChild(ReferenceTypeExpression* N) { + BOLT_VISIT(N->Name); + } + + void visitEachChild(ArrowTypeExpression* N) { + for (auto PT: N->ParamTypes) { + BOLT_VISIT(PT); + } + BOLT_VISIT(N->ReturnType); + } + + void visitEachChild(VarTypeExpression* N) { + BOLT_VISIT(N->Name); + } + + void visitEachChild(BindPattern* N) { + BOLT_VISIT(N->Name); + } + + void visitEachChild(ReferenceExpression* N) { + BOLT_VISIT(N->Name); + } + + void 
visitEachChild(NestedExpression* N) { + BOLT_VISIT(N->LParen); + BOLT_VISIT(N->Inner); + BOLT_VISIT(N->RParen); + } + + void visitEachChild(ConstantExpression* N) { + BOLT_VISIT(N->Token); + } + + void visitEachChild(CallExpression* N) { + BOLT_VISIT(N->Function); + for (auto Arg: N->Args) { + BOLT_VISIT(Arg); + } + } + + void visitEachChild(InfixExpression* N) { + BOLT_VISIT(N->LHS); + BOLT_VISIT(N->Operator); + BOLT_VISIT(N->RHS); + } + + void visitEachChild(PrefixExpression* N) { + BOLT_VISIT(N->Operator); + BOLT_VISIT(N->Argument); + } + + void visitEachChild(ExpressionStatement* N) { + BOLT_VISIT(N->Expression); + } + + void visitEachChild(ReturnStatement* N) { + BOLT_VISIT(N->ReturnKeyword); + BOLT_VISIT(N->Expression); + } + + void visitEachChild(IfStatement* N) { + for (auto Part: N->Parts) { + BOLT_VISIT(Part); + } + } + + void visitEachChild(IfStatementPart* N) { + BOLT_VISIT(N->Keyword); + if (N->Test != nullptr) { + BOLT_VISIT(N->Test); + } + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + void visitEachChild(TypeAssert* N) { + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TypeExpression); + } + + void visitEachChild(Parameter* N) { + BOLT_VISIT(N->Pattern); + if (N->TypeAssert != nullptr) { + BOLT_VISIT(N->TypeAssert); + } + } + + void visitEachChild(LetBlockBody* N) { + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + void visitEachChild(LetExprBody* N) { + BOLT_VISIT(N->Equals); + BOLT_VISIT(N->Expression); + } + + void visitEachChild(LetDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + BOLT_VISIT(N->LetKeyword); + if (N->MutKeyword) { + BOLT_VISIT(N->MutKeyword); + } + BOLT_VISIT(N->Pattern); + for (auto Param: N->Params) { + BOLT_VISIT(Param); + } + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); + } + if (N->Body) { + BOLT_VISIT(N->Body); + } + } + + void visitEachChild(StructDeclarationField* N) { + BOLT_VISIT(N->Name); + 
BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TypeExpression); + } + + void visitEachChild(StructDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + BOLT_VISIT(N->StructKeyword); + BOLT_VISIT(N->Name); + BOLT_VISIT(N->BlockStart); /* block-start token; the struct keyword is already visited before Name */ + for (auto Field: N->Fields) { + BOLT_VISIT(Field); + } + } + + void visitEachChild(ClassDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + BOLT_VISIT(N->ClassKeyword); + BOLT_VISIT(N->Name); + for (auto Name: N->TypeVars) { + BOLT_VISIT(Name); + } + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + void visitEachChild(InstanceDeclaration* N) { + BOLT_VISIT(N->InstanceKeyword); + BOLT_VISIT(N->Name); + for (auto TE: N->TypeExps) { + BOLT_VISIT(TE); + } + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + void visitEachChild(SourceFile* N) { + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + }; + +} diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index 76250904f..03c54f074 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -5,6 +5,7 @@ #include "bolt/ByteString.hpp" #include "bolt/CST.hpp" +#include "bolt/Diagnostics.hpp" #include #include @@ -14,6 +15,30 @@ namespace bolt { + class LanguageConfig { /* options are stored as individual bits of Flags */ + + enum ConfigFlags { + ConfigFlags_TypeVarsRequireForall = 1 << 0, + }; + + unsigned Flags = 0; /* every option starts disabled instead of holding an indeterminate value */ + + public: + + void setTypeVarsRequireForall(bool Enable) { + if (Enable) { + Flags |= ConfigFlags_TypeVarsRequireForall; + } else { + Flags &= ~ConfigFlags_TypeVarsRequireForall; /* AND with the complement clears only this bit; OR-ing it set all the other bits */ + } + } + + bool typeVarsRequireForall() const noexcept { + return Flags & ConfigFlags_TypeVarsRequireForall; + } + + }; + class DiagnosticEngine; class Node; @@ -23,11 +48,12 @@ namespace bolt { using TVSub = std::unordered_map; using TVSet = std::unordered_set; + using TypeclassContext = std::unordered_set; + enum class TypeKind : unsigned char { Var, Con, Arrow, - Any, Tuple,
}; @@ -70,15 +96,45 @@ namespace bolt { inline TCon(const size_t Id, std::vector Args, ByteString DisplayName): Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {} + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Con; + } + + }; + + enum class VarKind { + Rigid, + Unification, }; class TVar : public Type { public: const size_t Id; + VarKind VK; - inline TVar(size_t Id): - Type(TypeKind::Var), Id(Id) {} + TypeclassContext Contexts; + + inline TVar(size_t Id, VarKind VK): + Type(TypeKind::Var), Id(Id), VK(VK) {} + + inline VarKind getVarKind() const noexcept { + return VK; + } + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Var; + } + + }; + + class TVarRigid : public TVar { + public: + + ByteString Name; + + inline TVarRigid(size_t Id, ByteString Name): + TVar(Id, VarKind::Rigid), Name(Name) {} }; @@ -95,6 +151,10 @@ namespace bolt { ParamTypes(ParamTypes), ReturnType(ReturnType) {} + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Arrow; + } + }; class TTuple : public Type { @@ -105,13 +165,9 @@ namespace bolt { inline TTuple(std::vector ElementTypes): Type(TypeKind::Tuple), ElementTypes(ElementTypes) {} - }; - - class TAny : public Type { - public: - - inline TAny(): - Type(TypeKind::Any) {} + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::Tuple; + } }; @@ -126,26 +182,6 @@ namespace bolt { using ConstraintSet = std::vector; - class Forall { - public: - - TVSet* TVs; - ConstraintSet* Constraints; - Type* Type; - - inline Forall(class Type* Type): - TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {} - - inline Forall( - TVSet& TVs, - ConstraintSet& Constraints, - class Type* Type - ): TVs(&TVs), - Constraints(&Constraints), - Type(Type) {} - - }; - enum class SchemeKind : unsigned char { Forall, }; @@ -154,61 +190,102 @@ namespace bolt { const SchemeKind Kind; - union { - Forall F; - }; + protected: + + inline Scheme(SchemeKind 
Kind): + Kind(Kind) {} public: - inline Scheme(Forall F): - Kind(SchemeKind::Forall), F(F) {} - - inline Scheme(const Scheme& Other): - Kind(Other.Kind) { - switch (Kind) { - case SchemeKind::Forall: - F = Other.F; - break; - } - } - - - inline Scheme(Scheme&& Other): - Kind(std::move(Other.Kind)) { - switch (Kind) { - case SchemeKind::Forall: - F = std::move(Other.F); - break; - } - } - - template - T& as(); - - template<> - Forall& as() { - ZEN_ASSERT(Kind == SchemeKind::Forall); - return F; - } - inline SchemeKind getKind() const noexcept { return Kind; } - ~Scheme() { - switch (Kind) { - case SchemeKind::Forall: - F.~Forall(); - break; - } + virtual ~Scheme() {} + + }; + + class Forall : public Scheme { + public: + + TVSet* TVs; + ConstraintSet* Constraints; + class Type* Type; + + inline Forall(class Type* Type): + Scheme(SchemeKind::Forall), TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {} + + inline Forall( + TVSet* TVs, + ConstraintSet* Constraints, + class Type* Type + ): Scheme(SchemeKind::Forall), + TVs(TVs), + Constraints(Constraints), + Type(Type) {} + + static bool classof(const Scheme* Scm) { + return Scm->getKind() == SchemeKind::Forall; } }; - using TypeEnv = std::unordered_map; +/* class Scheme { */ + +/* const SchemeKind Kind; */ + +/* public: */ + +/* inline Scheme(Forall F): */ +/* Kind(SchemeKind::Forall), F(F) {} */ + +/* inline Scheme(const Scheme& Other): */ +/* Kind(Other.Kind) { */ +/* switch (Kind) { */ +/* case SchemeKind::Forall: */ +/* F = Other.F; */ +/* break; */ +/* } */ +/* } */ + + +/* inline Scheme(Scheme&& Other): */ +/* Kind(std::move(Other.Kind)) { */ +/* switch (Kind) { */ +/* case SchemeKind::Forall: */ +/* F = std::move(Other.F); */ +/* break; */ +/* } */ +/* } */ + +/* inline SchemeKind getKind() const noexcept { */ +/* return Kind; */ +/* } */ + +/* template */ +/* T& as(); */ + +/* template<> */ +/* Forall& as() { */ +/* ZEN_ASSERT(Kind == SchemeKind::Forall); */ +/* return F; */ +/* } */ + +/* ~Scheme() { 
*/ +/* switch (Kind) { */ +/* case SchemeKind::Forall: */ +/* F.~Forall(); */ +/* break; */ +/* } */ +/* } */ + +/* }; */ + + using TypeEnv = std::unordered_map; enum class ConstraintKind { Equal, + Class, Many, Empty, }; @@ -249,8 +326,8 @@ namespace bolt { ConstraintSet& Elements; - inline CMany(ConstraintSet& Constraints): - Constraint(ConstraintKind::Many), Elements(Constraints) {} + inline CMany(ConstraintSet& Elements): + Constraint(ConstraintKind::Many), Elements(Elements) {} }; @@ -262,32 +339,76 @@ namespace bolt { }; - class InferContext { + class CClass : public Constraint { public: - TVSet TVs; - ConstraintSet Constraints; - TypeEnv Env; - Type* ReturnType; + ByteString Name; + std::vector Types; - InferContext* Parent; + inline CClass(ByteString Name, std::vector Types): + Constraint(ConstraintKind::Class), Name(Name), Types(Types) {} + + }; + + enum { + /** + * Indicates that the typing environment of the current context will not + * hold on to any bindings. + * + * Concretely, bindings that are assigned fall through to the parent + * context, where this process is repeated until an environment is found + * that is not pervious. + */ + InferContextFlags_PerviousEnv = 1 << 0, + }; + + using InferContextFlagsMask = unsigned; + + class InferContext { + + InferContextFlagsMask Flags = 0; + + public: + + /** + * A heap-allocated list of type variables that eventually will become part of a Forall scheme. + */ + TVSet* TVs; + + /** + * A heap-allocated list of constraints that eventually will become part of a Forall scheme. 
+ */ + ConstraintSet* Constraints; + + TypeEnv Env; + + Type* ReturnType = nullptr; + std::vector Classes; + + inline void setIsEnvPervious(bool Enable) noexcept { + if (Enable) { + Flags |= InferContextFlags_PerviousEnv; + } else { + Flags &= ~InferContextFlags_PerviousEnv; + } + } + + inline bool isEnvPervious() const noexcept { + return Flags & InferContextFlags_PerviousEnv; + } //inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType): // Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {} - inline InferContext(InferContext* Parent = nullptr): - Parent(Parent), ReturnType(nullptr) {} - }; class Checker { + const LanguageConfig& Config; DiagnosticEngine& DE; - size_t nextConTypeId = 0; - size_t nextTypeVarId = 0; - - std::unordered_map Mapping; + size_t NextConTypeId = 0; + size_t NextTypeVarId = 0; std::unordered_map CallGraph; @@ -295,44 +416,83 @@ namespace bolt { Type* IntType; Type* StringType; + TVSub Solution; + std::vector Contexts; + /** + * Holds the current inferred type class contexts in a given LetDeclaration body. 
+ */ + // std::vector TCCs; + + InferContext& getContext(); + void addConstraint(Constraint* Constraint); + void addClass(TypeclassSignature Sig); void forwardDeclare(Node* Node); Type* inferExpression(Expression* Expression); Type* inferTypeExpression(TypeExpression* TE); - void inferBindings(Pattern* Pattern, Type* T, ConstraintSet& Constraints, TVSet& Tvs); + void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs); + void inferBindings(Pattern* Pattern, Type* T); void infer(Node* node); + Constraint* convertToConstraint(ConstraintExpression* C); + TCon* createPrimConType(); - TVar* createTypeVar(); + TVarRigid* createRigidVar(ByteString Name); + InferContext* createInferContext(); - void addBinding(ByteString Name, Scheme Scm); + void addBinding(ByteString Name, Scheme* Scm); + Scheme* lookup(ByteString Name); + + /** + * Looks up a type/variable and ensures that it is a monomorphic type. + * + * This method is mainly syntactic sugar to make it clear in the code when a + * monomorphic type is expected. + * + * Note that if the type is not monomorphic the program will abort with a + * stack trace. It wil **not** print a user-friendly error message. + * + * \returns If the type/variable could not be found `nullptr` is returned. + * Otherwise, a [Type] is returned. + */ Type* lookupMono(ByteString Name); InferContext* lookupCall(Node* Source, SymbolPath Path); + /** + * Get the return type for the current context. If none could be found, the program will abort. 
+ */ Type* getReturnType(); - Scheme* lookup(ByteString Name); + Type* instantiate(Scheme* S, Node* Source); - Type* instantiate(Scheme& S, Node* Source); + /* void addToTypeclassContexts(Node* N, std::vector& Contexts); */ - bool unify(Type* A, Type* B, TVSub& Solution); + std::unordered_map> InstanceMap; + std::vector findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source); + void propagateClasses(TypeclassContext& Classes, Type* Ty, Node* Source); + void propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source); + void checkTypeclassSigs(Node* N); + + bool unify(Type* A, Type* B, Node* Source); + + void solveCEqual(CEqual* C); void solve(Constraint* Constraint, TVSub& Solution); public: - Checker(DiagnosticEngine& DE); + Checker(const LanguageConfig& Config, DiagnosticEngine& DE); - TVSub check(SourceFile* SF); + void check(SourceFile* SF); inline Type* getBoolType() { return BoolType; @@ -346,7 +506,7 @@ namespace bolt { return IntType; } - Type* getType(Node* Node, const TVSub& Solution); + Type* getType(TypedNode* Node); }; diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index 59d357ced..e11a1c773 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -13,12 +13,30 @@ namespace bolt { class Type; + class TCon; + class TVar; + + using TypeclassId = ByteString; + + struct TypeclassSignature { + + using TypeclassId = ByteString; + TypeclassId Id; + std::vector Params; + + bool operator<(const TypeclassSignature& Other) const; + bool operator==(const TypeclassSignature& Other) const; + + }; enum class DiagnosticKind : unsigned char { UnexpectedToken, UnexpectedString, BindingNotFound, UnificationError, + TypeclassMissing, + InstanceNotFound, + ClassNotFound, }; class Diagnostic : std::runtime_error { @@ -31,7 +49,7 @@ namespace bolt { public: - DiagnosticKind getKind() const noexcept { + inline DiagnosticKind getKind() const noexcept { return Kind; } @@ -42,9 +60,9 @@ namespace bolt { TextFile& File; 
Token* Actual; - std::vector Expected; + std::vector Expected; - inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector Expected): + inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector Expected): Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {} }; @@ -84,6 +102,39 @@ namespace bolt { }; + class TypeclassMissingDiagnostic : public Diagnostic { + public: + + TypeclassSignature Sig; + LetDeclaration* Decl; + + inline TypeclassMissingDiagnostic(TypeclassSignature Sig, LetDeclaration* Decl): + Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} + + }; + + class InstanceNotFoundDiagnostic : public Diagnostic { + public: + + ByteString TypeclassName; + TCon* Ty; + Node* Source; + + inline InstanceNotFoundDiagnostic(ByteString TypeclassName, TCon* Ty, Node* Source): + Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {} + + }; + + class ClassNotFoundDiagnostic : public Diagnostic { + public: + + ByteString Name; + + inline ClassNotFoundDiagnostic(ByteString Name): + Diagnostic(DiagnosticKind::ClassNotFound), Name(Name) {} + + }; + class DiagnosticEngine { protected: diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index 50a0ecc75..2131a4e53 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -2,9 +2,10 @@ #pragma once #include -#include +#include #include "bolt/CST.hpp" +#include "bolt/Stream.hpp" namespace bolt { @@ -68,14 +69,24 @@ namespace bolt { Token* peekFirstTokenAfterModifiers(); - Token* expectToken(NodeType Ty); + Token* expectToken(NodeKind Ty); + + template + T* expectToken() { + return static_cast(expectToken(getNodeType())); + } Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence); - TypeExpression* parsePrimitiveTypeExpression(); - Expression* parsePrimitiveExpression(); + ConstraintExpression* parseConstraintExpression(); + + TypeExpression* 
parsePrimitiveTypeExpression(); + TypeExpression* parseQualifiedTypeExpression(); + TypeExpression* parseArrowTypeExpression(); + VarTypeExpression* parseVarTypeExpression(); + public: Parser(TextFile& File, Stream& S); @@ -86,7 +97,7 @@ namespace bolt { Pattern* parsePattern(); - Param* parseParam(); + Parameter* parseParam(); ReferenceExpression* parseReferenceExpression(); @@ -106,6 +117,12 @@ namespace bolt { LetDeclaration* parseLetDeclaration(); + Node* parseClassElement(); + + ClassDeclaration* parseClassDeclaration(); + + InstanceDeclaration* parseInstanceDeclaration(); + Node* parseSourceElement(); SourceFile* parseSourceFile(); diff --git a/include/bolt/Scanner.hpp b/include/bolt/Scanner.hpp index 435a5c251..40fce9a4b 100644 --- a/include/bolt/Scanner.hpp +++ b/include/bolt/Scanner.hpp @@ -8,78 +8,12 @@ #include "bolt/Text.hpp" #include "bolt/String.hpp" +#include "bolt/Stream.hpp" namespace bolt { class Token; - template - class Stream { - public: - - virtual T get() = 0; - virtual T peek(std::size_t Offset = 0) = 0; - - virtual ~Stream() {} - - }; - - template - class VectorStream : public Stream { - public: - - using value_type = T; - - ContainerT& Data; - value_type Sentry; - std::size_t Offset; - - VectorStream(ContainerT& Data, value_type Sentry, std::size_t Offset = 0): - Data(Data), Sentry(Sentry), Offset(Offset) {} - - value_type get() override { - return Offset < Data.size() ? Data[Offset++] : Sentry; - } - - value_type peek(std::size_t Offset2) override { - auto I = Offset + Offset2; - return I < Data.size() ? 
Data[I] : Sentry; - } - - }; - - template - class BufferedStream : public Stream { - - std::deque Buffer; - - protected: - - virtual T read() = 0; - - public: - - using value_type = T; - - value_type get() override { - if (Buffer.empty()) { - return read(); - } else { - auto Keep = Buffer.front(); - Buffer.pop_front(); - return Keep; - } - } - - value_type peek(std::size_t Offset = 0) override { - while (Buffer.size() <= Offset) { - Buffer.push_back(read()); - } - return Buffer[Offset]; - } - - }; - class Scanner : public BufferedStream { TextFile& File; diff --git a/include/bolt/Text.hpp b/include/bolt/Text.hpp index a92a438bf..5a7921f95 100644 --- a/include/bolt/Text.hpp +++ b/include/bolt/Text.hpp @@ -36,8 +36,7 @@ namespace bolt { }; - class TextRange { - public: + struct TextRange { TextLoc Start; TextLoc End; }; diff --git a/scripts/CST.cc.tply b/scripts/CST.cc.tply deleted file mode 100644 index a308eb6bb..000000000 --- a/scripts/CST.cc.tply +++ /dev/null @@ -1,66 +0,0 @@ -{! -root_node_name = root_node.name - -def gen_cpp_unref(expr, ty): - if isinstance(ty, NodeType): - return f'{expr}->unref();\n' - elif isinstance(ty, ListType): - dtor = gen_cpp_unref('Element', ty.element_type) - if dtor: - out = '' - out += f'for (auto& Element: {expr})' - out += '{\n' - out += dtor - out += '}\n' - return out - elif isinstance(ty, OptionalType): - if is_type_optional_by_default(ty.element_type): - element_expr = expr - else: - element_expr = f'(*{expr})' - dtor = gen_cpp_unref(element_expr, ty.element_type) - if dtor: - out = '' - out += 'if (' - out += expr - out += ') {\n' - out += dtor - out += '}\n' - return out - elif isinstance(ty, RawType): - pass # field should be destroyed by class - else: - raise RuntimeError(f'unexpected {ty}') - -!} -#include "{{include_path}}/{{name}}.hpp" - -{% for namespace in namespaces %} -namespace {{namespace}} { {! 
indent() !} -{% endfor %} - -{{root_node_name}}:~{{root_node_name}}() {} - -SourceFile* {root_node.name}::getSourceFile() { - auto CurrNode = this; - for (;;) { - if (CurrNode->Type == NodeType::SourceFile) { - return static_cast(this); - } - CurrNode = CurrNode->Parent; - ZEN_ASSERT(CurrNode != nullptr); - } -} - -{% for node in nodes %} - {{node.name}}::~{{node.name}}() { - {% for name, ty in node.fields %} - {{gen_cpp_unref(name, ty)} - {% endfor %} - } -{% endfor %} - -{% for namespace in namespaces %} -} {! dedent() !} -{% endfor %} - diff --git a/scripts/CST.hpp.tply b/scripts/CST.hpp.tply deleted file mode 100644 index 38398846a..000000000 --- a/scripts/CST.hpp.tply +++ /dev/null @@ -1,118 +0,0 @@ -{! -macro_prefix = '_'.join(namespaces).upper() + '_' -variant_name = root_node_name + 'Type' -!} - -#pragma once - -{% for namespace in namespaces %} -namespace {{namespace}} { {! indent() !} -{% endfor %} - -class {{base_node.name}}; - -class {{root_node_name}} { - - unsigned RefCount = 0; - - {{root_node_name}}* Parent = nullptr; - -public: - - inline void ref() { - ++RefCount; - } - - inline void unref() { - --RefCount; - if (RefCount == 0) { - delete this; - } - } - - const {{variant_name}} Type; - - inline {{root_node_name}}({{variant_name}}Type): - Type(Type) {} - - {{base_node.name}}* get{{base_node.name}}(); - - virtual void setParents(); - - virtual ~Node(); - -}; - -{% for node in nodes %} -{! 
- -def gen_cpp_ctor_params(out, node): - visited = set() - queue = deque([ node ]) - is_leaf = not graph.has_children(node.name) - first = True - if not is_leaf: - out.write(f"{cpp_root_node_name}Type Type") - first = False - while queue: - node = queue.popleft() - if node.name in visited: - return - visited.add(node.name) - for member in node.members: - if first: - first = False - else: - out.write(', ') - out.write(gen_cpp_type_expr(member.type_expr.type)) - out.write(' ') - out.write(camel_case(member.name)) - for parent in node.parents: - queue.append(types[parent]) - -def gen_cpp_ctor_args(out, orig_node: NodeDecl): - first = True - is_leaf = not graph.has_children(orig_node.name) - if orig_node.parents: - for parent in orig_node.parents: - if first: - first = False - else: - out.write(', ') - node = types[parent] - refs = '' - if is_leaf: - refs += f"{cpp_root_node_name}Type::{orig_node.name}" - else: - refs += 'Type' - for member in node.members: - refs += f", {camel_case(member.name)}" - out.write(f"{prefix}{node.name}({refs})") - else: - if is_leaf: - out.write(f"{cpp_root_node_name}({cpp_root_node_name}Type::{orig_node.name})") - else: - out.write(f"{cpp_root_node_name}(Type)") - first = False - for member in orig_node.members: - if first: - first = False - else: - out.write(', ') - out.write(f"{camel_case(member.name)}({camel_case(member.name)})") - -!} -class {{node.name}} : public {{node.parent.name}} { - - {{node.name}}( - {{cpp_ctor_params}} - ): {{node.parent.name}}({{variant_name}}::{{node.name}}{{cpp_ctor_args}} {} - - ~{{node.name}}(); -}; -{% endfor %} - -{% for namespace in namespaces %} -} {! 
dedent() !} -{% endfor %} - diff --git a/scripts/gennodes.py b/scripts/gennodes.py deleted file mode 100755 index c363cbb67..000000000 --- a/scripts/gennodes.py +++ /dev/null @@ -1,848 +0,0 @@ -#!/usr/bin/env python3 - -from os import wait -import re -from collections import deque -from pathlib import Path -import argparse -from typing import List, Optional -from sweetener.record import Record -import templaty - -here = Path(__file__).parent.resolve() - -EOF = '\uFFFF' - -END_OF_FILE = 0 -IDENTIFIER = 1 -SEMI = 2 -EXTERNAL = 3 -NODE = 4 -LBRACE = 5 -RBRACE = 6 -LESSTHAN = 7 -GREATERTHAN = 8 -COLON = 9 -LPAREN = 10 -RPAREN = 11 -VBAR = 12 -COMMA = 13 -HASH = 14 -STRING = 15 - -RE_WHTITESPACE = re.compile(r"[\n\r\t ]") -RE_IDENT_START = re.compile(r"[a-zA-Z_]") -RE_IDENT_PART = re.compile(r"[a-zA-Z_0-9]") - -KEYWORDS = { - 'external': EXTERNAL, - 'node': NODE, - } - -def escape_char(ch): - code = ord(ch) - if code >= 32 and code < 126: - return ch - if code <= 127: - return f"\\x{code:02X}" - return f"\\u{code:04X}" - -def camel_case(ident: str) -> str: - out = ident[0].upper() - i = 1 - while i < len(ident): - ch = ident[i] - i += 1 - if ch == '_': - c1 = ident[i] - i += 1 - out += c1.upper() - else: - out += ch - return out - -class ScanError(RuntimeError): - - def __init__(self, file, position, actual): - super().__init__(f"{file.name}:{position.line}:{position.column}: unexpected character '{escape_char(actual)}'") - self.file = file - self.position = position - self.actual = actual - -TOKEN_TYPE_TO_STRING = { - LPAREN: '(', - RPAREN: ')', - LBRACE: '{', - RBRACE: '}', - LESSTHAN: '<', - GREATERTHAN: '>', - NODE: 'node', - EXTERNAL: 'external', - SEMI: ';', - COLON: ':', - COMMA: ',', - VBAR: '|', - HASH: '#', - } - -class Token: - - def __init__(self, type, position=None, value=None): - self.type = type - self.start_pos = position - self.value = value - - @property - def text(self): - if self.type in TOKEN_TYPE_TO_STRING: - return TOKEN_TYPE_TO_STRING[self.type] 
- if self.type == IDENTIFIER: - return self.value - if self.type == STRING: - return f'"{self.value}"' - if self.type == END_OF_FILE: - return '' - return '(unknown token)' - -class TextFile: - - def __init__(self, filename, text=None): - self.name = filename - self._cached_text = text - - @property - def text(self): - if self._cached_text is None: - with open(self.name, 'r') as f: - self._cached_text = f.read() - return self._cached_text - -class TextPos: - - def __init__(self, line=1, column=1): - self.line = line - self.column = column - - def clone(self): - return TextPos(self.line, self.column) - - def advance(self, text): - for ch in text: - if ch == '\n': - self.line += 1 - self.column = 1 - else: - self.column += 1 - -class Scanner: - - def __init__(self, text, text_offset=0, filename=None): - self._text = text - self._text_offset = text_offset - self.file = TextFile(filename, text) - self._curr_pos = TextPos() - - def _peek_char(self, offset=1): - i = self._text_offset + offset - 1 - return self._text[i] if i < len(self._text) else EOF - - def _get_char(self): - if self._text_offset == len(self._text): - return EOF - i = self._text_offset - self._text_offset += 1 - ch = self._text[i] - self._curr_pos.advance(ch) - return ch - - def _take_while(self, pred): - out = '' - while True: - ch = self._peek_char() - if not pred(ch): - break - self._get_char() - out += ch - return out - - def scan(self): - - while True: - c0 = self._peek_char() - c1 = self._peek_char(2) - if c0 == '/' and c1 == '/': - self._get_char() - self._get_char() - while True: - c3 = self._get_char() - if c3 == '\n' or c3 == EOF: - break - continue - if RE_WHTITESPACE.match(c0): - self._get_char() - continue - break - - if c0 == EOF: - return Token(END_OF_FILE, self._curr_pos.clone()) - - start_pos = self._curr_pos.clone() - self._get_char() - - if c0 == ';': return Token(SEMI, start_pos) - if c0 == '{': return Token(LBRACE, start_pos) - if c0 == '}': return Token(RBRACE, start_pos) - if c0 
== '(': return Token(LPAREN, start_pos) - if c0 == ')': return Token(RPAREN, start_pos) - if c0 == '<': return Token(LESSTHAN, start_pos) - if c0 == '>': return Token(GREATERTHAN, start_pos) - if c0 == ':': return Token(COLON, start_pos) - if c0 == '|': return Token(VBAR, start_pos) - if c0 == ',': return Token(COMMA, start_pos) - if c0 == '#': return Token(HASH, start_pos) - - if c0 == '"': - text = '' - while True: - c1 = self._get_char() - if c1 == '"': - break - text += c1 - return Token(STRING, start_pos, text) - - if RE_IDENT_START.match(c0): - name = c0 + self._take_while(lambda ch: RE_IDENT_PART.match(ch)) - return Token(KEYWORDS[name], start_pos) \ - if name in KEYWORDS \ - else Token(IDENTIFIER, start_pos, name) - - raise ScanError(self.file, start_pos, c0) - -class Type(Record): - pass - -class ListType(Type): - element_type: Type - -class OptionalType(Type): - element_type: Type - -class NodeType(Type): - name: str - -class VariantType(Type): - types: List[Type] - -class RawType(Type): - text: str - -class AST(Record): - pass - -class Directive(AST): - pass - -INCLUDEMODE_LOCAL = 0 -INCLUDEMODE_SYSTEM = 1 - -class IncludeDiretive(Directive): - path: str - mode: int - - def __str__(self): - if self.mode == INCLUDEMODE_LOCAL: - return f"#include \"{self.path}\"\n" - if self.mode == INCLUDEMODE_SYSTEM: - return f"#include <{self.path}>\n" - -class TypeExpr(AST): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.type = None - -class RefTypeExpr(TypeExpr): - name: str - args: List[TypeExpr] - -class UnionTypeExpr(TypeExpr): - types: List[TypeExpr] - -class External(AST): - name: str - -class NodeDeclField(AST): - name: str - type_expr: TypeExpr - -class NodeDecl(AST): - name: str - parents: List[str] - members: List[NodeDeclField] - -def pretty_token(token): - if token.type == END_OF_FILE: - return 'end-of-file' - return f"'{token.text}'" - -def pretty_token_type(token_type): - if token_type in TOKEN_TYPE_TO_STRING: - return 
f"'{TOKEN_TYPE_TO_STRING[token_type]}'" - if token_type == IDENTIFIER: - return 'an identfier' - if token_type == STRING: - return 'a string literal' - if token_type == END_OF_FILE: - return 'end-of-file' - return f"(unknown token type {token_type})" - -def pretty_alternatives(elements): - try: - out = next(elements) - except StopIteration: - return 'nothing' - try: - prev_element = next(elements) - except StopIteration: - return out - while True: - try: - element = next(elements) - except StopIteration: - break - out += ', ' + prev_element - prev_element = element - return out + ' or ' + prev_element - -class ParseError(RuntimeError): - - def __init__(self, file, actual, expected): - super().__init__(f"{file.name}:{actual.start_pos.line}:{actual.start_pos.column}: got {pretty_token(actual)} but expected {pretty_alternatives(pretty_token_type(tt) for tt in expected)}") - self.actual = actual - self.expected = expected - -class Parser: - - def __init__(self, scanner): - self._scanner = scanner - self._token_buffer = deque() - - def _peek_token(self, offset=1): - while len(self._token_buffer) < offset: - self._token_buffer.append(self._scanner.scan()) - return self._token_buffer[offset-1] - - def _get_token(self): - if self._token_buffer: - return self._token_buffer.popleft() - return self._scanner.scan() - - def _expect_token(self, expected_token_type): - t0 = self._get_token() - if t0.type != expected_token_type: - raise ParseError(self._scanner.file, t0, [ expected_token_type ]) - return t0 - - def _parse_prim_type_expr(self): - t0 = self._get_token() - if t0.type == LPAREN: - result = self.parse_type_expr() - self._expect_token(RPAREN) - return result - if t0.type == IDENTIFIER: - t1 = self._peek_token() - args = [] - if t1.type == LESSTHAN: - self._get_token() - while True: - t2 = self._peek_token() - if t2.type == GREATERTHAN: - self._get_token() - break - args.append(self.parse_type_expr()) - t3 = self._get_token() - if t3.type == GREATERTHAN: - break - if 
t3.type != COMMA: - raise ParseError(self._scanner.file, t3, [ COMMA, GREATERTHAN ]) - return RefTypeExpr(t0.value, args) - raise ParseError(self._scanner.file, t0, [ LPAREN, IDENTIFIER ]) - - def parse_type_expr(self): - return self._parse_prim_type_expr() - - def parse_member(self): - type_expr = self.parse_type_expr() - name = self._expect_token(IDENTIFIER) - self._expect_token(SEMI) - return NodeDeclField(name.value, type_expr) - - def parse_toplevel(self): - t0 = self._get_token() - if t0.type == EXTERNAL: - name = self._expect_token(IDENTIFIER) - self._expect_token(SEMI) - return External(name.value) - if t0.type == NODE: - name = self._expect_token(IDENTIFIER).value - parents = [] - t1 = self._peek_token() - if t1.type == COLON: - self._get_token() - while True: - parent = self._expect_token(IDENTIFIER).value - parents.append(parent) - t2 = self._peek_token() - if t2.type == COMMA: - self._get_token() - continue - if t2.type == LBRACE: - break - raise ParseError(self._scanner.file, t2, [ COMMA, LBRACE ]) - self._expect_token(LBRACE) - members = [] - while True: - t2 = self._peek_token() - if t2.type == RBRACE: - self._get_token() - break - member = self.parse_member() - members.append(member) - return NodeDecl(name, parents, members) - if t0.type == HASH: - name = self._expect_token(IDENTIFIER) - if name.value == 'include': - t1 = self._get_token() - if t1.type == LESSTHAN: - assert(not self._token_buffer) - path = self._scanner._take_while(lambda ch: ch != '>') - self._scanner._get_char() - mode = INCLUDEMODE_SYSTEM - elif t1.type == STRING: - mode = INCLUDEMODE_LOCAL - path = t1.value - else: - raise ParseError(self._scanner.file, t1, [ STRING, LESSTHAN ]) - return IncludeDiretive(path, mode) - raise RuntimeError(f"invalid preprocessor directive '{name.value}'") - raise ParseError(self._scanner.file, t0, [ EXTERNAL, NODE, HASH ]) - - def parse_grammar(self): - elements = [] - while True: - t0 = self._peek_token() - if t0.type == END_OF_FILE: - break - 
element = self.parse_toplevel() - elements.append(element) - return elements - -class Writer: - - def __init__(self, text='', path=None): - self.path = path - self.text = text - self._at_blank_line = True - self._indentation = ' ' - self._indent_level = 0 - - def indent(self, count=1): - self._indent_level += count - - def dedent(self, count=1): - self._indent_level -= count - - def write(self, chunk): - for ch in chunk: - if ch == '}': - self.dedent() - if ch == '\n': - self._at_blank_line = True - elif self._at_blank_line and not RE_WHTITESPACE.match(ch): - self.text += self._indentation * self._indent_level - self._at_blank_line = False - self.text += ch - if ch == '{': - self.indent() - - def save(self, dest_dir): - dest_path = dest_dir / self.path - print(f'Writing file {dest_path} ...') - with open(dest_path, 'w') as f: - f.write(self.text) - -class DiGraph: - - def __init__(self): - self._out_edges = dict() - self._in_edges = dict() - - def add_edge(self, a, b): - if a not in self._out_edges: - self._out_edges[a] = set() - self._out_edges[a].add(b) - if b not in self._in_edges: - self._in_edges[b] = set() - self._in_edges[b].add(a) - - def get_children(self, node): - if node not in self._out_edges: - return - for child in self._out_edges[node]: - yield child - - def has_children(self, node): - return node in self._out_edges - - def is_child_of(self, a, b): - stack = [ b ] - visited = set() - while stack: - node = stack.pop() - if node in visited: - break - visited.add(node) - if node == a: - return True - for child in self.get_children(node): - stack.append(child) - return False - - def get_ancestors(self, node): - if node not in self._in_edges: - return - for parent in self._in_edges[node]: - yield parent - - def get_common_ancestor(self, nodes): - out = nodes[0] - parents = [] - for node in nodes[1:]: - if not self.is_child_of(node, out): - for parent in self.get_ancestors(node): - parents.append(parent) - if not parents: - return out - parents.append(out) 
- return self.get_common_ancestor(parents) - -def main(): - - parser = argparse.ArgumentParser() - - parser.add_argument('file', nargs=1, help='The specification file to generate C++ code for') - parser.add_argument('--namespace', default='', help='What C++ namespace to put generated code under') - parser.add_argument('--name', default='AST', help='How to name the generated tree') - parser.add_argument('-I', default='.', help='What path will be used to include generated header files') - parser.add_argument('--include-root', default='.', help='Where the headers live inside the include directroy') - parser.add_argument('--enable-serde', action='store_true', help='Also write (de)serialization logic') - parser.add_argument('--source-root', default='.', help='Where to store generated souce files') - parser.add_argument('--node-name', default='Node', help='How the root node of the hierachy should be called') - parser.add_argument('--node-prefix', default='', help='String to prepend to the names of node types') - parser.add_argument('--out-dir', default='.', help='Place the endire folder structure inside this folder') - parser.add_argument('--dry-run', action='store_true', help='Do not write generated code to the file system') - - args = parser.parse_args() - - filename = args.file[0] - prefix = args.node_prefix - cpp_root_node_name = prefix + args.node_name - include_dir = Path(args.I) - include_path = Path(args.include_root or '.') - full_include_path = include_dir / include_path - source_path = Path(args.source_root) - namespace = args.namespace.split('::') - out_dir = Path(args.out_dir) - out_name = args.name - write_serde = args.enable_serde - - with open(filename, 'r') as f: - text = f.read() - - scanner = Scanner(text, filename=filename) - parser = Parser(scanner) - elements = parser.parse_grammar() - - types = dict() - nodes = list() - leaf_nodes = list() - graph = DiGraph() - parent_to_children = dict() - - for element in elements: - if isinstance(element, 
External) \ - or isinstance(element, NodeDecl): - types[element.name] = element - if isinstance(element, NodeDecl): - nodes.append(element) - for parent in element.parents: - graph.add_edge(parent, element.name) - if parent not in parent_to_children: - parent_to_children[parent] = set() - children = parent_to_children[parent] - children.add(element) - - for node in nodes: - if node.name not in parent_to_children: - leaf_nodes.append(node) - - def is_null_type_expr(type_expr): - return isinstance(type_expr, RefTypeExpr) and type_expr.name == 'null' - - def is_node(name): - if name in types: - return isinstance(types[name], NodeDecl) - if name in parent_to_children: - return True - return False - - def get_all_variant_elements(type_expr): - types = list() - def loop(ty): - if isinstance(ty, RefTypeExpr) and ty.name == 'Variant': - for arg in ty.args: - loop(arg) - else: - types.append(ty) - loop(type_expr) - return types - - def infer_type(type_expr): - if isinstance(type_expr, RefTypeExpr): - if type_expr.name == 'Option': - assert(len(type_expr.args) == 1) - return OptionalType(infer_type(type_expr.args[0])) - if type_expr.name == 'List': - assert(len(type_expr.args) == 1) - return ListType(infer_type(type_expr.args[0])) - if type_expr.name == 'Variant': - types = get_all_variant_elements(type_expr) - has_null = False - if any(is_null_type_expr(ty) for ty in types): - has_null = True - types = list(ty for ty in types if not is_null_type_expr(ty)) - if all(isinstance(ty, RefTypeExpr) and is_node(ty.name) for ty in types): - node_name = graph.get_common_ancestor(list(t.name for t in types)) - return NodeType(node_name) - if len(types) == 1: - out = infer_type(types[0]) - else: - out = VariantType(infer_type(ty) for ty in types) - return OptionalType(out) if has_null else out - if is_node(type_expr.name): - assert(len(type_expr.args) == 0) - return NodeType(type_expr.name) - assert(len(type_expr.args) == 0) - return RawType(type_expr.name) - raise 
RuntimeError(f"unhandled type expression {type_expr}") - - for node in nodes: - for member in node.members: - member.type_expr.type = infer_type(member.type_expr) - - def is_type_optional_by_default(ty): - return isinstance(ty, NodeType) - - def gen_cpp_type_expr(ty): - if isinstance(ty, NodeType): - return prefix + ty.name + "*" - if isinstance(ty, ListType): - return f"std::vector<{gen_cpp_type_expr(ty.element_type)}>" - if isinstance(ty, NodeType): - return ty.name + '*' - if isinstance(ty, OptionalType): - cpp_expr = gen_cpp_type_expr(ty.element_type) - if is_type_optional_by_default(ty.element_type): - return cpp_expr - return f"std::optional<{cpp_expr}>" - if isinstance(ty, VariantType): - return f"std::variant<{','.join(gen_cpp_type_expr(t) for t in ty.element_types)}>" - if isinstance(ty, RawType): - return ty.text - raise RuntimeError(f"unhandled Type {ty}") - - def gen_cpp_dtor(expr, ty): - if isinstance(ty, NodeType): - return f'{expr}->unref();\n' - elif isinstance(ty, ListType): - dtor = gen_cpp_dtor('Element', ty.element_type) - if dtor: - out = '' - out += f'for (auto& Element: {expr})' - out += '{\n' - out += dtor - out += '}\n' - return out - elif isinstance(ty, OptionalType): - if is_type_optional_by_default(ty.element_type): - element_expr = expr - else: - element_expr = f'(*{expr})' - dtor = gen_cpp_dtor(element_expr, ty.element_type) - if dtor: - out = '' - out += 'if (' - out += expr - out += ') {\n' - out += dtor - out += '}\n' - return out - elif isinstance(ty, RawType): - pass # field should be destroyed by class - else: - raise RuntimeError(f'unexpected {ty}') - - def gen_cpp_ctor_params(out, node): - visited = set() - queue = deque([ node ]) - is_leaf = not graph.has_children(node.name) - first = True - if not is_leaf: - out.write(f"{cpp_root_node_name}Type Type") - first = False - while queue: - node = queue.popleft() - if node.name in visited: - return - visited.add(node.name) - for member in node.members: - if first: - first = False - 
else: - out.write(', ') - out.write(gen_cpp_type_expr(member.type_expr.type)) - out.write(' ') - out.write(camel_case(member.name)) - for parent in node.parents: - queue.append(types[parent]) - - def gen_cpp_ctor_args(out, orig_node: NodeDecl): - first = True - is_leaf = not graph.has_children(orig_node.name) - if orig_node.parents: - for parent in orig_node.parents: - if first: - first = False - else: - out.write(', ') - node = types[parent] - refs = '' - if is_leaf: - refs += f"{cpp_root_node_name}Type::{orig_node.name}" - else: - refs += 'Type' - for member in node.members: - refs += f", {camel_case(member.name)}" - out.write(f"{prefix}{node.name}({refs})") - else: - if is_leaf: - out.write(f"{cpp_root_node_name}({cpp_root_node_name}Type::{orig_node.name})") - else: - out.write(f"{cpp_root_node_name}(Type)") - first = False - for member in orig_node.members: - if first: - first = False - else: - out.write(', ') - out.write(f"{camel_case(member.name)}({camel_case(member.name)})") - - node_hdr = templaty.execute(here / 'CST.hpp.tply', ctx={ - 'namespaces': namespace, - 'nodes': nodes, - 'root_node_name': args.node_name - }) - - node_hdr = Writer(path=full_include_path / (out_name + '.hpp')) - node_src = Writer(path=source_path / (out_name + '.cc')) - - # Generating the header file - - if write_serde: - node_hdr.write('void encode(Encoder& encoder) const;\n\n') - node_hdr.write('virtual void encode_fields(Encoder& encoder) const = 0;\n'); - #node_hdr.write('virtual void decode_fields(Decoder& decoder) = 0;\n\n'); - - for element in elements: - if isinstance(element, NodeDecl): - node = element - is_leaf = not list(graph.get_children(node.name)) - cpp_node_name = prefix + node.name - node_hdr.write("class ") - node_hdr.write(cpp_node_name) - node_hdr.write(" : ") - if node.parents: - node_hdr.write(', '.join('public ' + prefix + parent for parent in node.parents)) - else: - node_hdr.write('public ' + cpp_root_node_name) - node_hdr.write(" {\n\n") - 
node_hdr.write('public:\n\n') - - node_hdr.write(cpp_node_name + '(') - gen_cpp_ctor_params(node_hdr, node) - node_hdr.write('): ') - gen_cpp_ctor_args(node_hdr, node) - node_hdr.write(' {}\n\n') - - if node.members: - for member in node.members: - node_hdr.write(gen_cpp_type_expr(member.type_expr.type)) - node_hdr.write(" "); - node_hdr.write(camel_case(member.name)) - node_hdr.write(";\n"); - node_hdr.write('\n') - if write_serde and is_leaf: - node_hdr.write('void encode_fields(Encoder& encoder) const override;\n'); - #node_hdr.write('void decode_fields(Decoder& decoder) override;\n\n'); - - # Generating the source file - - node_src.write(f"""#include "{include_path / (out_name + '.hpp')}"\n\n""") - - for name in namespace: - node_src.write(f"namespace {name} {{\n\n") - - node_src.write(f"""{cpp_root_node_name}::~{cpp_root_node_name}() {{ }}\n\n""") - - if write_serde: - node_src.write(f""" -void {cpp_root_node_name}::encode(Encoder& encoder) const {{ -encoder.start_encode_struct("{cpp_root_node_name}"); -encode_fields(encoder); -encoder.end_encode_struct(); -}} - -""") - - for node in nodes: - is_leaf = not list(graph.get_children(node.name)) - cpp_node_name = prefix + node.name - - if write_serde and is_leaf: - node_src.write(f'void {cpp_node_name}::encode_fields(Encoder& encoder) const {{\n') - for member in node.members: - node_src.write(f'encoder.encode_field("{member.name}", {member.name});\n') - node_src.write('}\n\n') - - node_src.write(f'{cpp_node_name}::~{cpp_node_name}() {{\n') - for member in node.members: - dtor = gen_cpp_dtor(camel_case(member.name), member.type_expr.type) - if dtor: - node_src.write(dtor) - node_src.write('}\n\n') - - for _ in namespace: - node_src.write("}\n\n") - - if args.dry_run: - print('# ' + str(node_hdr.path)) - print(node_hdr.text) - print('# ' + str(node_src.path)) - print(node_src.text) - else: - out_dir.mkdir(exist_ok=True, parents=True) - node_hdr.save(out_dir) - node_src.save(out_dir) - -if __name__ == '__main__': - 
main() diff --git a/src/CST.cc b/src/CST.cc index 70ee455c4..2ba5e066a 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -2,6 +2,7 @@ #include "zen/config.hpp" #include "bolt/CST.hpp" +#include "bolt/CSTVisitor.hpp" namespace bolt { @@ -11,23 +12,34 @@ namespace bolt { } void Scope::scan(Node* X) { - switch (X->Type) { - case NodeType::ExpressionStatement: - case NodeType::ReturnStatement: - case NodeType::IfStatement: + switch (X->getKind()) { + case NodeKind::ExpressionStatement: + case NodeKind::ReturnStatement: + case NodeKind::IfStatement: break; - case NodeType::SourceFile: + case NodeKind::SourceFile: { - auto Y = static_cast(X); - for (auto Element: Y->Elements) { + auto File = static_cast(X); + for (auto Element: File->Elements) { scan(Element); } break; } - case NodeType::LetDeclaration: + case NodeKind::ClassDeclaration: { - auto Y = static_cast(X); - addBindings(Y->Pattern, Y); + auto Decl = static_cast(X); + for (auto Element: Decl->Elements) { + scan(Element); + } + break; + } + case NodeKind::InstanceDeclaration: + // FIXME is this right? 
+ break; + case NodeKind::LetDeclaration: + { + auto Decl = static_cast(X); + addBindings(Decl->Pattern, Decl); break; } default: @@ -36,8 +48,8 @@ namespace bolt { } void Scope::addBindings(Pattern* X, Node* ToInsert) { - switch (X->Type) { - case NodeType::BindPattern: + switch (X->getKind()) { + case NodeKind::BindPattern: { auto Y = static_cast(X); Mapping.emplace(Y->Name->Text, ToInsert); @@ -49,6 +61,7 @@ namespace bolt { } Node* Scope::lookup(SymbolPath Path) { + ZEN_ASSERT(Path.Modules.empty()); auto Curr = this; do { auto Match = Curr->Mapping.find(Path.Name); @@ -70,7 +83,7 @@ namespace bolt { SourceFile* Node::getSourceFile() { auto CurrNode = this; for (;;) { - if (CurrNode->Type == NodeType::SourceFile) { + if (CurrNode->Kind == NodeKind::SourceFile) { return static_cast(CurrNode); } CurrNode = CurrNode->Parent; @@ -95,435 +108,49 @@ namespace bolt { return EndLoc; } - void Token::setParents() { - } + void Node::setParents() { - void QualifiedName::setParents() { - for (auto Name: ModulePath) { - Name->Parent = this; - } - Name->Parent = this; - } + struct SetParentsVisitor : public CSTVisitor { - void ReferenceTypeExpression::setParents() { - Name->Parent = this; - Name->setParents(); - } - - void ArrowTypeExpression::setParents() { - for (auto ParamType: ParamTypes) { - ParamType->Parent = this; - ParamType->setParents(); - } - ReturnType->Parent = this; - ReturnType->setParents(); - } + std::vector Parents { nullptr }; - void BindPattern::setParents() { - Name->Parent = this; - } + void visit(Node* N) { + N->Parent = Parents.back(); + Parents.push_back(N); + visitEachChild(N); + Parents.pop_back(); + } - void ReferenceExpression::setParents() { - Name->Parent = this; - } + }; - void NestedExpression::setParents() { - LParen->Parent = this; - Inner->Parent = this; - Inner->setParents(); - RParen->Parent = this; - } + SetParentsVisitor V; + V.visit(this); - void ConstantExpression::setParents() { - Token->Parent = this; - } - - void 
CallExpression::setParents() { - Function->Parent = this; - Function->setParents(); - for (auto Arg: Args) { - Arg->Parent = this; - Arg->setParents(); - } - } - - void InfixExpression::setParents() { - LHS->Parent = this; - LHS->setParents(); - Operator->Parent = this; - RHS->Parent = this; - RHS->setParents(); - } - - void UnaryExpression::setParents() { - Operator->Parent = this; - Argument->Parent = this; - Argument->setParents(); - } - - void ExpressionStatement::setParents() { - Expression->Parent = this; - Expression->setParents(); - } - - void ReturnStatement::setParents() { - ReturnKeyword->Parent = this; - Expression->Parent = this; - Expression->setParents(); - } - - void IfStatementPart::setParents() { - Keyword->Parent = this; - if (Test) { - Test->Parent = this; - Test->setParents(); - } - BlockStart->Parent = this; - for (auto Element: Elements) { - Element->Parent = this; - Element->setParents(); - } - } - - void IfStatement::setParents() { - for (auto Part: Parts) { - Part->Parent = this; - Part->setParents(); - } - } - - void TypeAssert::setParents() { - Colon->Parent = this; - TypeExpression->Parent = this; - TypeExpression->setParents(); - } - - void LetBlockBody::setParents() { - BlockStart->Parent = this; - for (auto Element: Elements) { - Element->Parent = this; - Element->setParents(); - } - } - - void LetExprBody::setParents() { - Equals->Parent = this; - Expression->Parent = this; - Expression->setParents(); - } - - void Param::setParents() { - Pattern->Parent = this; - Pattern->setParents(); - if (TypeAssert) { - TypeAssert->Parent = this; - TypeAssert->setParents(); - } - } - - void LetDeclaration::setParents() { - if (PubKeyword) { - PubKeyword->Parent = this; - } - LetKeyword->Parent = this; - if (MutKeyword) { - MutKeyword->Parent = this; - } - Pattern->Parent = this; - Pattern->setParents(); - for (auto Param: Params) { - Param->Parent = this; - Param->setParents(); - } - if (TypeAssert) { - TypeAssert->Parent = this; - 
TypeAssert->setParents(); - } - if (Body) { - Body->Parent = this; - Body->setParents(); - } - } - - void StructDeclField::setParents() { - Name->Parent = this; - Colon->Parent = this; - TypeExpression->Parent = this; - TypeExpression->setParents(); - } - - void StructDecl::setParents() { - StructKeyword->Parent = this; - Name->Parent = this; - BlockStart->Parent = this; - for (auto Field: Fields) { - Field->Parent = this; - Field->setParents(); - } - } - - void SourceFile::setParents() { - for (auto Element: Elements) { - Element->Parent = this; - Element->setParents(); - } } Node::~Node() { + + struct UnrefVisitor : public CSTVisitor { + + void visit(Node* N) { + N->unref(); + visitEachChild(N); + } + + }; + + UnrefVisitor V; + V.visitEachChild(this); + } - Token::~Token() { - } - - Equals::~Equals() { - } - - Colon::~Colon() { - } - - RArrow::~RArrow() { - } - - Dot::~Dot() { - } - - DotDot::~DotDot() { - } - - LParen::~LParen() { - } - - RParen::~RParen() { - } - - LBracket::~LBracket() { - } - - RBracket::~RBracket() { - } - - LBrace::~LBrace() { - } - - RBrace::~RBrace() { - } - - LetKeyword::~LetKeyword() { - } - - MutKeyword::~MutKeyword() { - } - - PubKeyword::~PubKeyword() { - } - - TypeKeyword::~TypeKeyword() { - } - - ReturnKeyword::~ReturnKeyword() { - } - - IfKeyword::~IfKeyword() { - } - - ElifKeyword::~ElifKeyword() { - } - - ElseKeyword::~ElseKeyword() { - } - - ModKeyword::~ModKeyword() { - } - - StructKeyword::~StructKeyword() { - } - - Invalid::~Invalid() { - } - - EndOfFile::~EndOfFile() { - } - - BlockStart::~BlockStart() { - } - - BlockEnd::~BlockEnd() { - } - - LineFoldEnd::~LineFoldEnd() { - } - - CustomOperator::~CustomOperator() { - } - - Assignment::~Assignment() { - } - - Identifier::~Identifier() { - } - - StringLiteral::~StringLiteral() { - } - - IntegerLiteral::~IntegerLiteral() { - } - - QualifiedName::~QualifiedName() { - for (auto& Element: ModulePath){ - Element->unref(); - } - Name->unref(); - } - - 
TypeExpression::~TypeExpression() { - } - - ReferenceTypeExpression::~ReferenceTypeExpression() { - Name->unref(); - } - - ArrowTypeExpression::~ArrowTypeExpression() { - for (auto ParamType: ParamTypes) { - ParamType->unref(); - } - ReturnType->unref(); - } - - Pattern::~Pattern() { - } - - BindPattern::~BindPattern() { - Name->unref(); - } - - Expression::~Expression() { - } - - ReferenceExpression::~ReferenceExpression() { - Name->unref(); - } - - NestedExpression::~NestedExpression() { - LParen->unref(); - Inner->unref(); - RParen->unref(); - } - - ConstantExpression::~ConstantExpression() { - Token->unref(); - } - - CallExpression::~CallExpression() { - Function->unref(); - for (auto& Element: Args){ - Element->unref(); - } - } - - InfixExpression::~InfixExpression() { - LHS->unref(); - Operator->unref(); - RHS->unref(); - } - - UnaryExpression::~UnaryExpression() { - Operator->unref(); - Argument->unref(); - } - - Statement::~Statement() { - } - - ExpressionStatement::~ExpressionStatement() { - Expression->unref(); - } - - ReturnStatement::~ReturnStatement() { - ReturnKeyword->unref(); - Expression->unref(); - } - - IfStatementPart::~IfStatementPart() { - Keyword->unref(); - if (Test) { - Test->unref(); - } - BlockStart->unref(); - for (auto Element: Elements) { - Element->unref(); - } - } - - IfStatement::~IfStatement() { - for (auto Part: Parts) { - Part->unref(); - } - } - - TypeAssert::~TypeAssert() { - Colon->unref(); - TypeExpression->unref(); - } - - Param::~Param() { - Pattern->unref(); - TypeAssert->unref(); - } - - LetBody::~LetBody() { - } - - LetBlockBody::~LetBlockBody() { - BlockStart->unref(); - for (auto& Element: Elements){ - Element->unref(); - } - } - - LetExprBody::~LetExprBody() { - Equals->unref(); - Expression->unref(); - } - - LetDeclaration::~LetDeclaration() { - if (PubKeyword) { - PubKeyword->unref(); - } - LetKeyword->unref(); - if (MutKeyword) { - MutKeyword->unref(); - } - Pattern->unref(); - for (auto& Element: Params){ - 
Element->unref(); - } - if (TypeAssert) { - TypeAssert->unref(); - } - if (Body) { - Body->unref(); - } - } - - StructDeclField::~StructDeclField() { - Name->unref(); - Colon->unref(); - TypeExpression->unref(); - } - - StructDecl::~StructDecl() { - StructKeyword->unref(); - Name->unref(); - BlockStart->unref(); - for (auto& Element: Fields){ - Element->unref(); - } - } - - SourceFile::~SourceFile() { - for (auto& Element: Elements){ - Element->unref(); + bool Identifier::isTypeVar() const { + for (auto C: Text) { + if (!((C >= 97 && C <= 122) || C == '_')) { + return false; + } } + return true; } Token* QualifiedName::getFirstToken() { @@ -537,6 +164,36 @@ namespace bolt { return Name; } + Token* TypeclassConstraintExpression::getFirstToken() { + return Name; + } + + Token* TypeclassConstraintExpression::getLastToken() { + if (!TEs.empty()) { + return TEs.back()->getLastToken(); + } + return Name; + } + + Token* EqualityConstraintExpression::getFirstToken() { + return Left->getFirstToken(); + } + + Token* EqualityConstraintExpression::getLastToken() { + return Left->getLastToken(); + } + + Token* QualifiedTypeExpression::getFirstToken() { + if (!Constraints.empty()) { + return std::get<0>(Constraints.front())->getFirstToken(); + } + return TE->getFirstToken(); + } + + Token* QualifiedTypeExpression::getLastToken() { + return TE->getLastToken(); + } + Token* ReferenceTypeExpression::getFirstToken() { return Name->getFirstToken(); } @@ -556,6 +213,14 @@ namespace bolt { return ReturnType->getLastToken(); } + Token* VarTypeExpression::getLastToken() { + return Name; + } + + Token* VarTypeExpression::getFirstToken() { + return Name; + } + Token* BindPattern::getFirstToken() { return Name; } @@ -607,11 +272,11 @@ namespace bolt { return RHS->getLastToken(); } - Token* UnaryExpression::getFirstToken() { + Token* PrefixExpression::getFirstToken() { return Operator; } - Token* UnaryExpression::getLastToken() { + Token* PrefixExpression::getLastToken() { return 
Argument->getLastToken(); } @@ -663,11 +328,11 @@ namespace bolt { return TypeExpression->getLastToken(); } - Token* Param::getFirstToken() { + Token* Parameter::getFirstToken() { return Pattern->getFirstToken(); } - Token* Param::getLastToken() { + Token* Parameter::getLastToken() { if (TypeAssert) { return TypeAssert->getLastToken(); } @@ -713,28 +378,53 @@ namespace bolt { return Pattern->getLastToken(); } - Token* StructDeclField::getFirstToken() { + Token* StructDeclarationField::getFirstToken() { return Name; } - Token* StructDeclField::getLastToken() { + Token* StructDeclarationField::getLastToken() { return TypeExpression->getLastToken(); } - Token* StructDecl::getFirstToken() { + Token* StructDeclaration::getFirstToken() { if (PubKeyword) { return PubKeyword; } return StructKeyword; } - Token* StructDecl::getLastToken() { + Token* StructDeclaration::getLastToken() { if (Fields.size()) { Fields.back()->getLastToken(); } return BlockStart; } + Token* InstanceDeclaration::getFirstToken() { + return InstanceKeyword; + } + + Token* InstanceDeclaration::getLastToken() { + if (!Elements.empty()) { + return Elements.back()->getLastToken(); + } + return BlockStart; + } + + Token* ClassDeclaration::getFirstToken() { + if (PubKeyword != nullptr) { + return PubKeyword; + } + return ClassKeyword; + } + + Token* ClassDeclaration::getLastToken() { + if (!Elements.empty()) { + return Elements.back()->getLastToken(); + } + return BlockStart; + } + Token* SourceFile::getFirstToken() { if (Elements.size()) { return Elements.front()->getFirstToken(); @@ -757,10 +447,18 @@ namespace bolt { return ":"; } + std::string Comma::getText() const { + return ","; + } + std::string RArrow::getText() const { return "->"; } + std::string RArrowAlt::getText() const { + return "=>"; + } + std::string Dot::getText() const { return "."; } @@ -873,6 +571,18 @@ namespace bolt { return ".."; } + std::string Tilde::getText() const { + return "~"; + } + + std::string ClassKeyword::getText() const 
{ + return "class"; + } + + std::string InstanceKeyword::getText() const { + return "instance"; + } + SymbolPath QualifiedName::getSymbolPath() const { std::vector ModuleNames; for (auto Ident: ModulePath) { diff --git a/src/Checker.cc b/src/Checker.cc index 76ee51d25..c0266f880 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -1,9 +1,21 @@ +// TODO Add list of CST variable names to TVar and unify them so that e.g. the typeclass checker may pick one when displaying a diagnostic + +// TODO make sure that if we have Eq Int, Eq a ~ Eq Int such that an instance binding eq has the correct type + +// TODO make unficiation work like union-find in find() + +#include +#include #include -#include "bolt/Diagnostics.hpp" -#include "zen/config.hpp" +#include "llvm/Support/Casting.h" +#include "zen/config.hpp" +#include "zen/range.hpp" + +#include "bolt/CSTVisitor.hpp" +#include "bolt/Diagnostics.hpp" #include "bolt/CST.hpp" #include "bolt/Checker.hpp" @@ -11,6 +23,21 @@ namespace bolt { std::string describe(const Type* Ty); + bool TypeclassSignature::operator<(const TypeclassSignature& Other) const { + if (Id < Other.Id) { + return true; + } + ZEN_ASSERT(Params.size() == 1); + ZEN_ASSERT(Other.Params.size() == 1); + return Params[0]->Id < Other.Params[0]->Id; + } + + bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { + ZEN_ASSERT(Params.size() == 1); + ZEN_ASSERT(Other.Params.size() == 1); + return Id == Other.Id && Params[0]->Id == Other.Params[0]->Id; + } + void Type::addTypeVars(TVSet& TVs) { switch (Kind) { case TypeKind::Var: @@ -41,8 +68,6 @@ namespace bolt { } break; } - case TypeKind::Any: - break; } } @@ -80,8 +105,6 @@ namespace bolt { } return false; } - case TypeKind::Any: - return false; } } @@ -111,8 +134,6 @@ namespace bolt { } return Changed ? 
new TArrow(NewParamTypes, NewRetTy) : this; } - case TypeKind::Any: - return this; case TypeKind::Con: { auto Con = static_cast(this); @@ -146,6 +167,15 @@ namespace bolt { Constraint* Constraint::substitute(const TVSub &Sub) { switch (Kind) { + case ConstraintKind::Class: + { + auto Class = static_cast(this); + std::vector NewTypes; + for (auto Ty: Class->Types) { + NewTypes.push_back(Ty->substitute(Sub)); + } + return new CClass(Class->Name, NewTypes); + } case ConstraintKind::Equal: { auto Equal = static_cast(this); @@ -165,11 +195,11 @@ namespace bolt { } } - Checker::Checker(DiagnosticEngine& DE): - DE(DE) { - BoolType = new TCon(nextConTypeId++, {}, "Bool"); - IntType = new TCon(nextConTypeId++, {}, "Int"); - StringType = new TCon(nextConTypeId++, {}, "String"); + Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): + Config(Config), DE(DE) { + BoolType = new TCon(NextConTypeId++, {}, "Bool"); + IntType = new TCon(NextConTypeId++, {}, "Int"); + StringType = new TCon(NextConTypeId++, {}, "String"); } Scheme* Checker::lookup(ByteString Name) { @@ -177,7 +207,7 @@ namespace bolt { auto Curr = *Iter; auto Match = Curr->Env.find(Name); if (Match != Curr->Env.end()) { - return &Match->second; + return Match->second; } } return nullptr; @@ -188,13 +218,20 @@ namespace bolt { if (Scm == nullptr) { return nullptr; } - auto& F = Scm->as(); - ZEN_ASSERT(F.TVs == nullptr || F.TVs->empty()); - return F.Type; + auto F = static_cast(Scm); + ZEN_ASSERT(F->TVs == nullptr || F->TVs->empty()); + return F->Type; } - void Checker::addBinding(ByteString Name, Scheme S) { - Contexts.back()->Env.emplace(Name, S); + void Checker::addBinding(ByteString Name, Scheme* Scm) { + for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) { + auto& Ctx = **Iter; + if (!Ctx.isEnvPervious()) { + Ctx.Env.emplace(Name, Scm); + return; + } + } + ZEN_UNREACHABLE } Type* Checker::getReturnType() { @@ -212,19 +249,43 @@ namespace bolt { return false; } + InferContext& 
Checker::getContext() { + ZEN_ASSERT(!Contexts.empty()); + return *Contexts.back(); + } + void Checker::addConstraint(Constraint* C) { switch (C->getKind()) { + case ConstraintKind::Class: + { + Contexts.back()->Constraints->push_back(C); + break; + } case ConstraintKind::Equal: { auto Y = static_cast(C); - for (auto Iter = Contexts.rbegin(); Iter != Contexts.rend(); Iter++) { - auto& Ctx = **Iter; - if (hasTypeVar(Ctx.TVs, Y->Left) || hasTypeVar(Ctx.TVs, Y->Right)) { - Ctx.Constraints.push_back(C); - return; + std::size_t MaxLevel = 0; + for (std::size_t I = Contexts.size(); I-- > 0; ) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { + MaxLevel = I; + break; } } - Contexts.front()->Constraints.push_back(C); + std::size_t MinLevel = MaxLevel; + for (std::size_t I = 0; I < Contexts.size(); I++) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { + MinLevel = I; + break; + } + } + if (MaxLevel == MinLevel) { + solveCEqual(Y); + } else { + Contexts[MaxLevel]->Constraints->push_back(C); + } + // Contexts.front()->Constraints->push_back(C); //auto I = std::max(Y->Left->MaxDepth, Y->Right->MaxDepth); //ZEN_ASSERT(I < Contexts.size()); //auto Ctx = Contexts[I]; @@ -244,16 +305,20 @@ namespace bolt { } } + void Checker::addClass(TypeclassSignature Sig) { + getContext().Classes.push_back(Sig); + } + void Checker::forwardDeclare(Node* X) { - switch (X->Type) { + switch (X->getKind()) { - case NodeType::ExpressionStatement: - case NodeType::ReturnStatement: - case NodeType::IfStatement: + case NodeKind::ExpressionStatement: + case NodeKind::ReturnStatement: + case NodeKind::IfStatement: break; - case NodeType::SourceFile: + case NodeKind::SourceFile: { auto File = static_cast(X); for (auto Element: File->Elements) { @@ -262,15 +327,60 @@ namespace bolt { break; } - case NodeType::LetDeclaration: + case NodeKind::ClassDeclaration: + { + auto Class = static_cast(X); + for 
(auto TE: Class->TypeVars) { + auto TV = createRigidVar(TE->Name->Text); + TV->Contexts.emplace(Class->Name->Text); + TE->setType(TV); + } + for (auto Element: Class->Elements) { + forwardDeclare(Element); + } + break; + } + + case NodeKind::InstanceDeclaration: + { + auto Decl = static_cast(X); + auto Match = InstanceMap.find(Decl->Name->Text); + if (Match == InstanceMap.end()) { + InstanceMap.emplace(Decl->Name->Text, std::vector { Decl }); + } else { + Match->second.push_back(Decl); + } + auto Ctx = createInferContext(); + Contexts.push_back(Ctx); + for (auto Element: Decl->Elements) { + forwardDeclare(Element); + } + Contexts.pop_back(); + break; + } + + case NodeKind::LetDeclaration: { auto Let = static_cast(X); - auto NewCtx = new InferContext(); + auto NewCtx = createInferContext(); Let->Ctx = NewCtx; Contexts.push_back(NewCtx); + // If declaring a let-declaration inside a type class declaration, + // we need to mark that the let-declaration requires this class. + // This marking is set on the rigid type variables of the class, which + // are then added to this local type environment. 
+ if (llvm::isa(Let->Parent)) { + auto Decl = static_cast(Let->Parent); + for (auto TE: Decl->TypeVars) { + auto TV = llvm::cast(TE->getType()); + NewCtx->Env.emplace(TE->Name->Text, new Forall(TV)); + NewCtx->TVs->emplace(TV); + } + } + Type* Ty; if (Let->TypeAssert) { Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); @@ -280,10 +390,10 @@ namespace bolt { Let->Ty = Ty; if (Let->Body) { - switch (Let->Body->Type) { - case NodeType::LetExprBody: + switch (Let->Body->getKind()) { + case NodeKind::LetExprBody: break; - case NodeType::LetBlockBody: + case NodeKind::LetBlockBody: { auto Block = static_cast(Let->Body); NewCtx->ReturnType = createTypeVar(); @@ -301,7 +411,6 @@ namespace bolt { inferBindings(Let->Pattern, Ty, NewCtx->Constraints, NewCtx->TVs); - break; } @@ -312,22 +421,47 @@ namespace bolt { } - void Checker::infer(Node* X) { + void Checker::infer(Node* N) { - switch (X->Type) { + switch (N->getKind()) { - case NodeType::SourceFile: + case NodeKind::SourceFile: { - auto File = static_cast(X); + auto File = static_cast(N); for (auto Element: File->Elements) { infer(Element); } break; } - case NodeType::IfStatement: + case NodeKind::ClassDeclaration: { - auto IfStmt = static_cast(X); + auto Decl = static_cast(N); + for (auto Element: Decl->Elements) { + infer(Element); + } + break; + } + + case NodeKind::InstanceDeclaration: + { + auto Decl = static_cast(N); + + // Needed to set the associated Type on the CST node + for (auto TE: Decl->TypeExps) { + inferTypeExpression(TE); + } + + for (auto Element: Decl->Elements) { + infer(Element); + } + + break; + } + + case NodeKind::IfStatement: + { + auto IfStmt = static_cast(N); for (auto Part: IfStmt->Parts) { if (Part->Test != nullptr) { addConstraint(new CEqual { BoolType, inferExpression(Part->Test), Part->Test }); @@ -339,36 +473,34 @@ namespace bolt { break; } - case NodeType::LetDeclaration: + case NodeKind::LetDeclaration: { - auto LetDecl = static_cast(X); + auto Decl = static_cast(N); - auto 
NewCtx = LetDecl->Ctx; + auto NewCtx = Decl->Ctx; Contexts.push_back(NewCtx); std::vector ParamTypes; Type* RetType; - for (auto Param: LetDecl->Params) { + for (auto Param: Decl->Params) { // TODO incorporate Param->TypeAssert or make it a kind of pattern TVar* TV = createTypeVar(); - TVSet NoTVs; - ConstraintSet NoConstraints; - inferBindings(Param->Pattern, TV, NoConstraints, NoTVs); + inferBindings(Param->Pattern, TV); ParamTypes.push_back(TV); } - if (LetDecl->Body) { - switch (LetDecl->Body->Type) { - case NodeType::LetExprBody: + if (Decl->Body) { + switch (Decl->Body->getKind()) { + case NodeKind::LetExprBody: { - auto Expr = static_cast(LetDecl->Body); + auto Expr = static_cast(Decl->Body); RetType = inferExpression(Expr->Expression); break; } - case NodeType::LetBlockBody: + case NodeKind::LetBlockBody: { - auto Block = static_cast(LetDecl->Body); + auto Block = static_cast(Decl->Body); RetType = createTypeVar(); for (auto Element: Block->Elements) { infer(Element); @@ -382,29 +514,35 @@ namespace bolt { RetType = createTypeVar(); } - addConstraint(new CEqual { LetDecl->Ty, new TArrow(ParamTypes, RetType), X }); + if (ParamTypes.empty()) { + // Declaration is a plain (typed) variable + addConstraint(new CEqual { Decl->Ty, RetType, N }); + } else { + // Declaration is a function + addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N }); + } Contexts.pop_back(); break; } - case NodeType::ReturnStatement: + case NodeKind::ReturnStatement: { - auto RetStmt = static_cast(X); + auto RetStmt = static_cast(N); Type* ReturnType; if (RetStmt->Expression) { ReturnType = inferExpression(RetStmt->Expression); } else { ReturnType = new TTuple({}); } - addConstraint(new CEqual { ReturnType, getReturnType(), X }); + addConstraint(new CEqual { ReturnType, getReturnType(), N }); break; } - case NodeType::ExpressionStatement: + case NodeKind::ExpressionStatement: { - auto ExprStmt = static_cast(X); + auto ExprStmt = static_cast(N); 
inferExpression(ExprStmt->Expression); break; } @@ -416,26 +554,41 @@ namespace bolt { } - TVar* Checker::createTypeVar() { - auto TV = new TVar(nextTypeVarId++); - Contexts.back()->TVs.emplace(TV); + TVarRigid* Checker::createRigidVar(ByteString Name) { + auto TV = new TVarRigid(NextTypeVarId++, Name); + Contexts.back()->TVs->emplace(TV); return TV; } - Type* Checker::instantiate(Scheme& S, Node* Source) { + TVar* Checker::createTypeVar() { + auto TV = new TVar(NextTypeVarId++, VarKind::Unification); + Contexts.back()->TVs->emplace(TV); + return TV; + } - switch (S.getKind()) { + InferContext* Checker::createInferContext() { + auto Ctx = new InferContext; + Ctx->TVs = new TVSet; + Ctx->Constraints = new ConstraintSet; + return Ctx; + } + + Type* Checker::instantiate(Scheme* Scm, Node* Source) { + + switch (Scm->getKind()) { case SchemeKind::Forall: { - auto& F = S.as(); + auto F = static_cast(Scm); TVSub Sub; - for (auto TV: *F.TVs) { - Sub[TV] = createTypeVar(); + for (auto TV: *F->TVs) { + auto Fresh = createTypeVar(); + Fresh->Contexts = TV->Contexts; + Sub[TV] = Fresh; } - for (auto Constraint: *F.Constraints) { + for (auto Constraint: *F->Constraints) { auto NewConstraint = Constraint->substitute(Sub); @@ -448,42 +601,88 @@ namespace bolt { addConstraint(NewConstraint); } - // FIXME substitute should always clone if we set MaxDepth - auto NewType = F.Type->substitute(Sub); - //NewType->MaxDepth = std::max(static_cast(Contexts.size()-1), F.Type->MaxDepth); - return NewType; + return F->Type->substitute(Sub); } } } - Type* Checker::inferTypeExpression(TypeExpression* X) { - - switch (X->Type) { - - case NodeType::ReferenceTypeExpression: + Constraint* Checker::convertToConstraint(ConstraintExpression* C) { + switch (C->getKind()) { + case NodeKind::TypeclassConstraintExpression: { - auto RefTE = static_cast(X); + auto D = static_cast(C); + std::vector Types; + for (auto TE: D->TEs) { + Types.push_back(inferTypeExpression(TE)); + } + return new 
CClass(D->Name->Text, Types); + } + case NodeKind::EqualityConstraintExpression: + { + auto D = static_cast(C); + return new CEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); + } + default: + ZEN_UNREACHABLE + } + } + + Type* Checker::inferTypeExpression(TypeExpression* N) { + + switch (N->getKind()) { + + case NodeKind::ReferenceTypeExpression: + { + auto RefTE = static_cast(N); auto Ty = lookupMono(RefTE->Name->Name->Text); if (Ty == nullptr) { - DE.add(RefTE->Name->Name->Text, RefTE->Name->Name); - return new TAny(); + if (!RefTE->Name->Name->isTypeVar() || Config.typeVarsRequireForall()) { + DE.add(RefTE->Name->Name->Text, RefTE->Name->Name); + } + Ty = createTypeVar(); } - Mapping[X] = Ty; + N->setType(Ty); return Ty; } - case NodeType::ArrowTypeExpression: + case NodeKind::VarTypeExpression: { - auto ArrowTE = static_cast(X); + auto VarTE = static_cast(N); + auto Ty = lookupMono(VarTE->Name->Text); + if (Ty == nullptr) { + if (Config.typeVarsRequireForall()) { + DE.add(VarTE->Name->Text, VarTE->Name); + } + Ty = createRigidVar(VarTE->Name->Text); + addBinding(VarTE->Name->Text, new Forall(Ty)); + } + N->setType(Ty); + return Ty; + } + + case NodeKind::ArrowTypeExpression: + { + auto ArrowTE = static_cast(N); std::vector ParamTypes; for (auto ParamType: ArrowTE->ParamTypes) { ParamTypes.push_back(inferTypeExpression(ParamType)); } auto ReturnType = inferTypeExpression(ArrowTE->ReturnType); auto Ty = new TArrow(ParamTypes, ReturnType); - Mapping[X] = Ty; + N->setType(Ty); + return Ty; + } + + case NodeKind::QualifiedTypeExpression: + { + auto QTE = static_cast(N); + for (auto [C, Comma]: QTE->Constraints) { + addConstraint(convertToConstraint(C)); + } + auto Ty = inferTypeExpression(QTE->TE); + N->setType(Ty); return Ty; } @@ -495,28 +694,28 @@ namespace bolt { Type* Checker::inferExpression(Expression* X) { - switch (X->Type) { + switch (X->getKind()) { - case NodeType::ConstantExpression: + case NodeKind::ConstantExpression: { auto Const 
= static_cast(X); Type* Ty = nullptr; - switch (Const->Token->Type) { - case NodeType::IntegerLiteral: + switch (Const->Token->getKind()) { + case NodeKind::IntegerLiteral: Ty = lookupMono("Int"); break; - case NodeType::StringLiteral: + case NodeKind::StringLiteral: Ty = lookupMono("String"); break; default: ZEN_UNREACHABLE } ZEN_ASSERT(Ty != nullptr); - Mapping[X] = Ty; + X->setType(Ty); return Ty; } - case NodeType::ReferenceExpression: + case NodeKind::ReferenceExpression: { auto Ref = static_cast(X); ZEN_ASSERT(Ref->Name->ModulePath.empty()); @@ -529,14 +728,14 @@ namespace bolt { auto Scm = lookup(Ref->Name->Name->Text); if (Scm == nullptr) { DE.add(Ref->Name->Name->Text, Ref->Name); - return new TAny(); + return createTypeVar(); } - auto Ty = instantiate(*Scm, X); - Mapping[X] = Ty; + auto Ty = instantiate(Scm, X); + X->setType(Ty); return Ty; } - case NodeType::CallExpression: + case NodeKind::CallExpression: { auto Call = static_cast(X); auto OpTy = inferExpression(Call->Function); @@ -546,29 +745,29 @@ namespace bolt { ArgTypes.push_back(inferExpression(Arg)); } addConstraint(new CEqual { OpTy, new TArrow(ArgTypes, RetType), X }); - Mapping[X] = RetType; + X->setType(RetType); return RetType; } - case NodeType::InfixExpression: + case NodeKind::InfixExpression: { auto Infix = static_cast(X); auto Scm = lookup(Infix->Operator->getText()); if (Scm == nullptr) { DE.add(Infix->Operator->getText(), Infix->Operator); - return new TAny(); + return createTypeVar(); } - auto OpTy = instantiate(*Scm, Infix->Operator); + auto OpTy = instantiate(Scm, Infix->Operator); auto RetTy = createTypeVar(); std::vector ArgTys; ArgTys.push_back(inferExpression(Infix->LHS)); ArgTys.push_back(inferExpression(Infix->RHS)); addConstraint(new CEqual { new TArrow(ArgTys, RetTy), OpTy, X }); - Mapping[X] = RetTy; + X->setType(RetTy); return RetTy; } - case NodeType::NestedExpression: + case NodeKind::NestedExpression: { auto Nested = static_cast(X); return 
inferExpression(Nested->Inner); @@ -581,41 +780,144 @@ namespace bolt { } - void Checker::inferBindings(Pattern* Pattern, Type* Type, ConstraintSet& Constraints, TVSet& TVs) { + void Checker::inferBindings( + Pattern* Pattern, + Type* Type, + ConstraintSet* Constraints, + TVSet* TVs + ) { - switch (Pattern->Type) { + switch (Pattern->getKind()) { - case NodeType::BindPattern: - addBinding(static_cast(Pattern)->Name->Text, Forall(TVs, Constraints, Type)); + case NodeKind::BindPattern: + { + addBinding(static_cast(Pattern)->Name->Text, new Forall(TVs, Constraints, Type)); break; + } - default: + default + : ZEN_UNREACHABLE } + } - TVSub Checker::check(SourceFile *SF) { - Contexts.push_back(new InferContext {}); - ConstraintSet NoConstraints; - addBinding("String", Forall(StringType)); - addBinding("Int", Forall(IntType)); - addBinding("Bool", Forall(BoolType)); - addBinding("True", Forall(BoolType)); - addBinding("False", Forall(BoolType)); + void Checker::inferBindings(Pattern* Pattern, Type* Type) { + inferBindings(Pattern, Type, new ConstraintSet, new TVSet); + } + + void collectTypeclasses(LetDeclaration* Decl, std::vector& Out) { + if (llvm::isa(Decl->Parent)) { + auto Class = llvm::cast(Decl->Parent); + std::vector Tys; + for (auto TE: Class->TypeVars) { + Tys.push_back(llvm::cast(TE->getType())); + } + Out.push_back(TypeclassSignature { Class->Name->Text, Tys }); + } + if (Decl->TypeAssert != nullptr) { + if (llvm::isa(Decl->TypeAssert->TypeExpression)) { + auto QTE = static_cast(Decl->TypeAssert->TypeExpression); + for (auto [C, Comma]: QTE->Constraints) { + if (llvm::isa(C)) { + auto TCE = static_cast(C); + std::vector Tys; + for (auto TE: TCE->TEs) { + auto TV = TE->getType(); + ZEN_ASSERT(llvm::isa(TV)); + Tys.push_back(static_cast(TV)); + } + Out.push_back(TypeclassSignature { TCE->Name->Text, Tys }); + } + } + } + } + } + + void Checker::checkTypeclassSigs(Node* N) { + + struct LetVisitor : CSTVisitor { + + Checker& C; + + void 
visitLetDeclaration(LetDeclaration* Decl) { + + std::vector Expected; + collectTypeclasses(Decl, Expected); + std::sort(Expected.begin(), Expected.end()); + Expected.erase(std::unique(Expected.begin(), Expected.end()), Expected.end()); + + std::vector Actual; + for (auto Ty: *Decl->Ctx->TVs) { + auto S = Ty->substitute(C.Solution); + if (llvm::isa(S)) { + auto TV = static_cast(S); + for (auto Class: TV->Contexts) { + Actual.push_back(TypeclassSignature { Class, { TV } }); + } + } + } + std::sort(Actual.begin(), Actual.end()); + Actual.erase(std::unique(Actual.begin(), Actual.end()), Actual.end()); + + auto It1 = Actual.begin(); + auto It2 = Expected.begin(); + + for (; It1 != Actual.end() || It2 != Expected.end() ;) { + if (It1 == Actual.end()) { + // TODO Maybe issue a warning that a type class went unused + break; + } + if (It2 == Expected.end()) { + for (; It1 != Actual.end(); It1++) { + C.DE.add(*It1, Decl); + } + break; + } + if (*It1 < *It2) { + // FIXME It1->Ty needs to be unified with potential candidate It2->Ty + C.DE.add(*It1, Decl); + It1++; + continue; + } + if (*It2 < *It1) { + // DE.add(It2->Name, Decl); + It2++; + continue; + } + It1++; + It2++; + } + + } + + }; + + LetVisitor V { {}, *this }; + V.visit(N); + + } + + void Checker::check(SourceFile *SF) { + auto RootContext = createInferContext(); + Contexts.push_back(RootContext); + addBinding("String", new Forall(StringType)); + addBinding("Int", new Forall(IntType)); + addBinding("Bool", new Forall(BoolType)); + addBinding("True", new Forall(BoolType)); + addBinding("False", new Forall(BoolType)); auto A = createTypeVar(); - TVSet SingleA { A }; - addBinding("==", Forall(SingleA, NoConstraints, new TArrow({ A, A }, BoolType))); - addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType))); - addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType))); - addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType))); - addBinding("/", Forall(new TArrow({ IntType, IntType }, 
IntType))); + addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, new TArrow({ A, A }, BoolType))); + addBinding("+", new Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("-", new Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("*", new Forall(new TArrow({ IntType, IntType }, IntType))); + addBinding("/", new Forall(new TArrow({ IntType, IntType }, IntType))); forwardDeclare(SF); infer(SF); - TVSub Solution; - solve(new CMany(Contexts.front()->Constraints), Solution); Contexts.pop_back(); - return Solution; + solve(new CMany(*RootContext->Constraints), Solution); + checkTypeclassSigs(SF); } void Checker::solve(Constraint* Constraint, TVSub& Solution) { @@ -631,6 +933,12 @@ namespace bolt { switch (Constraint->getKind()) { + case ConstraintKind::Class: + { + // TODO + break; + } + case ConstraintKind::Empty: break; @@ -645,11 +953,7 @@ namespace bolt { case ConstraintKind::Equal: { - auto Equal = static_cast(Constraint); - std::cerr << describe(Equal->Left) << " ~ " << describe(Equal->Right) << std::endl; - if (!unify(Equal->Left, Equal->Right, Solution)) { - DE.add(Equal->Left->substitute(Solution), Equal->Right->substitute(Solution), Equal->Source); - } + solveCEqual(static_cast(Constraint)); break; } @@ -659,69 +963,218 @@ namespace bolt { } - bool Checker::unify(Type* A, Type* B, TVSub& Solution) { - - while (A->getKind() == TypeKind::Var) { - auto Match = Solution.find(static_cast(A)); - if (Match == Solution.end()) { - break; - } - A = Match->second; - } - - while (B->getKind() == TypeKind::Var) { - auto Match = Solution.find(static_cast(B)); - if (Match == Solution.end()) { - break; - } - B = Match->second; - } - - if (A->getKind() == TypeKind::Var) { - auto TV = static_cast(A); - if (B->hasTypeVar(TV)) { - // TODO occurs check + bool assignableTo(Type* A, Type* B) { + if (llvm::isa(A) && llvm::isa(B)) { + auto Con1 = llvm::cast(A); + auto Con2 = llvm::cast(B); + if (Con1->Id != Con2-> Id) { return false; } - 
Solution[TV] = B; - return true; - } - - if (B->getKind() == TypeKind::Var) { - return unify(B, A, Solution); - } - - if (A->getKind() == TypeKind::Any || B->getKind() == TypeKind::Any) { - return true; - } - - if (A->getKind() == TypeKind::Arrow && B->getKind() == TypeKind::Arrow) { - auto Arr1 = static_cast(A); - auto Arr2 = static_cast(B); - if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { - return false; - } - auto Count = Arr1->ParamTypes.size(); - for (std::size_t I = 0; I < Count; I++) { - if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { + ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); + for (auto [T1, T2]: zen::zip(Con1->Args, Con2->Args)) { + if (!assignableTo(T1, T2)) { return false; } } - return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); + return true; + } + ZEN_UNREACHABLE + } + + std::vector Checker::findInstanceContext(TCon* Ty, TypeclassId& Class, Node* Source) { + auto Match = InstanceMap.find(Class); + std::vector S; + if (Match != InstanceMap.end()) { + for (auto Instance: Match->second) { + if (assignableTo(Ty, Instance->TypeExps[0]->getType())) { + std::vector S; + for (auto Arg: Ty->Args) { + TypeclassContext Classes; + // TODO + S.push_back(Classes); + } + return S; + } + } + } + DE.add(Class, Ty, Source); + for (auto Arg: Ty->Args) { + S.push_back({}); + } + return S; + } + + void Checker::propagateClasses(std::unordered_set& Classes, Type* Ty, Node* Source) { + if (llvm::isa(Ty)) { + auto TV = llvm::cast(Ty); + for (auto Class: Classes) { + TV->Contexts.emplace(Class); + } + } else if (llvm::isa(Ty)) { + for (auto Class: Classes) { + propagateClassTycon(Class, llvm::cast(Ty), Source); + } + } else { + ZEN_UNREACHABLE + // DE.add(Ty); + } + }; + + void Checker::propagateClassTycon(TypeclassId& Class, TCon* Ty, Node* Source) { + auto S = findInstanceContext(Ty, Class, Source); + for (auto [Classes, Arg]: zen::zip(S, Ty->Args)) { + propagateClasses(Classes, Arg, Source); + } + }; + + class ArrowCursor { 
+ + std::stack> Path; + + public: + + ArrowCursor(TArrow* Arr) { + Path.push({ Arr, 0 }); } - if (A->getKind() == TypeKind::Arrow) { + Type* next() { + while (!Path.empty()) { + auto& [Arr, I] = Path.top(); + Type* Ty; + if (I == -1) { + Path.pop(); + continue; + } + if (I == Arr->ParamTypes.size()) { + I = -1; + Ty = Arr->ReturnType; + } else { + Ty = Arr->ParamTypes[I]; + I++; + } + if (llvm::isa(Ty)) { + Path.push({ static_cast(Ty), 0 }); + } else { + return Ty; + } + } + return nullptr; + } + + }; + + void Checker::solveCEqual(CEqual* C) { + /* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */ + if (!unify(C->Left, C->Right, C->Source)) { + DE.add(C->Left->substitute(Solution), C->Right->substitute(Solution), C->Source); + } + } + + bool Checker::unify(Type* A, Type* B, Node* Source) { + + auto find = [&](auto Ty) { + while (Ty->getKind() == TypeKind::Var) { + auto Match = Solution.find(static_cast(Ty)); + if (Match == Solution.end()) { + break; + } + Ty = Match->second; + } + return Ty; + }; + + A = find(A); + B = find(B); + + if (llvm::isa(A) && llvm::isa(B)) { + auto Var1 = static_cast(A); + auto Var2 = static_cast(B); + if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Rigid) { + if (Var1->Id != Var2->Id) { + return false; + } + return true; + } + TVar* Dest; + TVar* From; + if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { + Dest = Var1; + From = Var2; + } else { + // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification + // Either way, Var1 is a good candidate for being unified away + Dest = Var2; + From = Var1; + } + Solution[From] = Dest; + propagateClasses(From->Contexts, Dest, Source); + return true; + } + + if (llvm::isa(A)) { + auto TV = static_cast(A); + if (TV->getVarKind() == VarKind::Rigid) { + return false; + } + // Occurs check + if (B->hasTypeVar(TV)) { + // NOTE Just like GHC, we just display an error 
message indicating that + // A cannot match B, e.g. a cannot match [a]. It looks much better + // than obsure references to an occurs check + return false; + } + Solution[TV] = B; + if (!TV->Contexts.empty()) { + propagateClasses(TV->Contexts, B, Source); + } + return true; + } + + if (llvm::isa(B)) { + return unify(B, A, Source); + } + + if (llvm::isa(A) && llvm::isa(B)) { + auto C1 = ArrowCursor(static_cast(A)); + auto C2 = ArrowCursor(static_cast(B)); + for (;;) { + auto T1 = C1.next(); + auto T2 = C2.next(); + if (T1 == nullptr && T2 == nullptr) { + break; + } + if (T1 == nullptr || T2 == nullptr) { + return false; + } + if (!unify(T1, T2, Source)) { + return false; + } + } + return true; + /* if (Arr1->ParamTypes.size() != Arr2->ParamTypes.size()) { */ + /* return false; */ + /* } */ + /* auto Count = Arr1->ParamTypes.size(); */ + /* for (std::size_t I = 0; I < Count; I++) { */ + /* if (!unify(Arr1->ParamTypes[I], Arr2->ParamTypes[I], Solution)) { */ + /* return false; */ + /* } */ + /* } */ + /* return unify(Arr1->ReturnType, Arr2->ReturnType, Solution); */ + } + + if (llvm::isa(A)) { auto Arr = static_cast(A); if (Arr->ParamTypes.empty()) { - return unify(Arr->ReturnType, B, Solution); + return unify(Arr->ReturnType, B, Source); } } - if (B->getKind() == TypeKind::Arrow) { - return unify(B, A, Solution); + if (llvm::isa(B)) { + return unify(B, A, Source); } - if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) { + if (llvm::isa(A) && llvm::isa(B)) { auto Tuple1 = static_cast(A); auto Tuple2 = static_cast(B); if (Tuple1->ElementTypes.size() != Tuple2->ElementTypes.size()) { @@ -730,14 +1183,14 @@ namespace bolt { auto Count = Tuple1->ElementTypes.size(); bool Success = true; for (size_t I = 0; I < Count; I++) { - if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], Solution)) { + if (!unify(Tuple1->ElementTypes[I], Tuple2->ElementTypes[I], Source)) { Success = false; } } return Success; } - if (A->getKind() == TypeKind::Con && 
B->getKind() == TypeKind::Con) { + if (llvm::isa(A) && llvm::isa(B)) { auto Con1 = static_cast(A); auto Con2 = static_cast(B); if (Con1->Id != Con2->Id) { @@ -746,7 +1199,7 @@ namespace bolt { ZEN_ASSERT(Con1->Args.size() == Con2->Args.size()); auto Count = Con1->Args.size(); for (std::size_t I = 0; I < Count; I++) { - if (!unify(Con1->Args[I], Con2->Args[I], Solution)) { + if (!unify(Con1->Args[I], Con2->Args[I], Source)) { return false; } } @@ -765,12 +1218,8 @@ namespace bolt { return Match->second; } - Type* Checker::getType(Node *Node, const TVSub &Solution) { - auto Match = Mapping.find(Node); - if (Match == Mapping.end()) { - return nullptr; - } - return Match->second->substitute(Solution); + Type* Checker::getType(TypedNode *Node) { + return Node->getType()->substitute(Solution); } } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 32edbca61..c625aeafa 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -44,51 +44,61 @@ namespace bolt { Diagnostic::Diagnostic(DiagnosticKind Kind): std::runtime_error("a compiler error occurred without being caught"), Kind(Kind) {} - static std::string describe(NodeType Type) { + static std::string describe(NodeKind Type) { switch (Type) { - case NodeType::Identifier: + case NodeKind::Identifier: return "an identifier"; - case NodeType::CustomOperator: + case NodeKind::CustomOperator: return "an operator"; - case NodeType::IntegerLiteral: + case NodeKind::IntegerLiteral: return "an integer literal"; - case NodeType::EndOfFile: + case NodeKind::EndOfFile: return "end-of-file"; - case NodeType::BlockStart: + case NodeKind::BlockStart: return "the start of a new indented block"; - case NodeType::BlockEnd: + case NodeKind::BlockEnd: return "the end of the current indented block"; - case NodeType::LineFoldEnd: + case NodeKind::LineFoldEnd: return "the end of the current line-fold"; - case NodeType::LParen: + case NodeKind::LParen: return "'('"; - case NodeType::RParen: + case NodeKind::RParen: return "')'"; - 
case NodeType::LBrace: + case NodeKind::LBrace: return "'['"; - case NodeType::RBrace: + case NodeKind::RBrace: return "']'"; - case NodeType::LBracket: + case NodeKind::LBracket: return "'{'"; - case NodeType::RBracket: + case NodeKind::RBracket: return "'}'"; - case NodeType::Colon: + case NodeKind::Colon: return "':'"; - case NodeType::Equals: + case NodeKind::Comma: + return "','"; + case NodeKind::Equals: return "'='"; - case NodeType::StringLiteral: + case NodeKind::StringLiteral: return "a string literal"; - case NodeType::Dot: + case NodeKind::Dot: return "'.'"; - case NodeType::PubKeyword: + case NodeKind::DotDot: + return "'..'"; + case NodeKind::Tilde: + return "'~'"; + case NodeKind::RArrow: + return "'->'"; + case NodeKind::RArrowAlt: + return "'=>'"; + case NodeKind::PubKeyword: return "'pub'"; - case NodeType::LetKeyword: + case NodeKind::LetKeyword: return "'let'"; - case NodeType::MutKeyword: + case NodeKind::MutKeyword: return "'mut'"; - case NodeType::ReturnKeyword: + case NodeKind::ReturnKeyword: return "'return'"; - case NodeType::TypeKeyword: + case NodeKind::TypeKeyword: return "'type'"; default: ZEN_UNREACHABLE @@ -97,10 +107,14 @@ namespace bolt { std::string describe(const Type* Ty) { switch (Ty->getKind()) { - case TypeKind::Any: - return "any"; case TypeKind::Var: - return "a" + std::to_string(static_cast(Ty)->Id); + { + auto TV = static_cast(Ty); + if (TV->getVarKind() == VarKind::Rigid) { + return static_cast(TV)->Name; + } + return "a" + std::to_string(TV->Id); + } case TypeKind::Arrow: { auto Y = static_cast(Ty); @@ -342,7 +356,7 @@ namespace bolt { writeExcerpt(E.Initiator->getSourceFile()->getTextFile(), Range, Range, Color::Red); Out << "\n"; } - break; + return; } case DiagnosticKind::UnexpectedToken: @@ -366,7 +380,7 @@ namespace bolt { default: auto Iter = E.Expected.begin(); Out << describe(*Iter++); - NodeType Prev = *Iter++; + NodeKind Prev = *Iter++; while (Iter != E.Expected.end()) { Out << ", " << describe(Prev); Prev = 
*Iter++; @@ -377,7 +391,7 @@ namespace bolt { Out << " but instead got '" << E.Actual->getText() << "'\n\n"; writeExcerpt(E.File, E.Actual->getRange(), E.Actual->getRange(), Color::Red); Out << "\n"; - break; + return; } case DiagnosticKind::UnexpectedString: @@ -405,7 +419,7 @@ namespace bolt { TextRange Range { E.Location, E.Location + E.Actual }; writeExcerpt(E.File, Range, Range, Color::Red); Out << "\n"; - break; + return; } case DiagnosticKind::UnificationError: @@ -423,11 +437,56 @@ namespace bolt { writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red); Out << "\n"; } - break; + return; + } + + case DiagnosticKind::TypeclassMissing: + { + auto E = static_cast(D); + setForegroundColor(Color::Red); + setBold(true); + Out << "error: "; + resetStyles(); + Out << "the type class " << ANSI_FG_YELLOW << E.Sig.Id; + for (auto TV: E.Sig.Params) { + Out << " " << describe(TV); + } + Out << ANSI_RESET << " is missing from the declaration's type signature\n\n"; + auto Range = E.Decl->getRange(); + writeExcerpt(E.Decl->getSourceFile()->getTextFile(), Range, Range, Color::Yellow); + Out << "\n\n"; + return; + } + + case DiagnosticKind::InstanceNotFound: + { + auto E = static_cast(D); + setForegroundColor(Color::Red); + setBold(true); + Out << "error: "; + resetStyles(); + Out << "a type class instance " << ANSI_FG_YELLOW << E.TypeclassName << " " << describe(E.Ty) << ANSI_RESET " was not found.\n\n"; + auto Range = E.Source->getRange(); + //std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n"; + writeExcerpt(E.Source->getSourceFile()->getTextFile(), Range, Range, Color::Red); + Out << "\n"; + return; + } + + case DiagnosticKind::ClassNotFound: + { + auto E = static_cast(D); + setForegroundColor(Color::Red); + setBold(true); + Out << "error: "; + resetStyles(); + Out << "the type class " << ANSI_FG_YELLOW << E.Name << ANSI_RESET " was not found.\n\n"; + return; } } + 
ZEN_UNREACHABLE } } diff --git a/src/IPRGraph.cc b/src/IPRGraph.cc index acda9d6a7..a494b2f47 100644 --- a/src/IPRGraph.cc +++ b/src/IPRGraph.cc @@ -11,9 +11,9 @@ namespace bolt { void IPRGraph::populate(Node* X, Node* Decl) { - switch (X->Type) { + switch (X->getKind()) { - case NodeType::SourceFile: + case NodeKind::SourceFile: { auto Y = static_cast(X); for (auto Element: Y->Elements) { @@ -22,7 +22,7 @@ namespace bolt { break; } - case NodeType::IfStatement: + case NodeKind::IfStatement: { auto Y = static_cast(X); for (auto Part: Y->Parts) { @@ -33,12 +33,12 @@ namespace bolt { break; } - case NodeType::LetDeclaration: + case NodeKind::LetDeclaration: { auto Y = static_cast(X); if (Y->Body) { - switch (Y->Body->Type) { - case NodeType::LetBlockBody: + switch (Y->Body->getKind()) { + case NodeKind::LetBlockBody: { auto Z = static_cast(Y->Body); for (auto Element: Z->Elements) { @@ -46,7 +46,7 @@ namespace bolt { } break; } - case NodeType::LetExprBody: + case NodeKind::LetExprBody: { auto Z = static_cast(Y->Body); populate(Z->Expression, Y); @@ -59,10 +59,10 @@ namespace bolt { break; } - case NodeType::ConstantExpression: + case NodeKind::ConstantExpression: break; - case NodeType::CallExpression: + case NodeKind::CallExpression: { auto Y = static_cast(X); populate(Y->Function, Decl); @@ -72,7 +72,7 @@ namespace bolt { break; } - case NodeType::ReferenceExpression: + case NodeKind::ReferenceExpression: { auto Y = static_cast(X); auto Def = Y->getScope()->lookup(Y->Name->getSymbolPath()); diff --git a/src/Parser.cc b/src/Parser.cc index 630407254..ac4f71023 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -1,10 +1,13 @@ +#include +#include + +#include "llvm/Support/Casting.h" + #include "bolt/CST.hpp" #include "bolt/Scanner.hpp" #include "bolt/Parser.hpp" #include "bolt/Diagnostics.hpp" -#include -#include namespace bolt { @@ -57,9 +60,9 @@ namespace bolt { std::size_t I = 0; for (;;) { auto T0 = Tokens.peek(I++); - switch (T0->Type) { - case 
NodeType::PubKeyword: - case NodeType::MutKeyword: + switch (T0->getKind()) { + case NodeKind::PubKeyword: + case NodeKind::MutKeyword: continue; default: return T0; @@ -70,71 +73,141 @@ namespace bolt { #define BOLT_EXPECT_TOKEN(name) \ { \ auto __Token = Tokens.get(); \ - if (__Token->Type != NodeType::name) { \ - throw UnexpectedTokenDiagnostic(File, __Token, std::vector { NodeType::name }); \ + if (!llvm::isa(__Token)) { \ + throw UnexpectedTokenDiagnostic(File, __Token, std::vector { NodeKind::name }); \ } \ } - Token* Parser::expectToken(NodeType Type) { + Token* Parser::expectToken(NodeKind Kind) { auto T = Tokens.get(); - if (T->Type != Type) { - throw UnexpectedTokenDiagnostic(File, T, std::vector { Type }); \ + if (T->getKind() != Kind) { + throw UnexpectedTokenDiagnostic(File, T, std::vector { Kind }); \ } return T; } Pattern* Parser::parsePattern() { auto T0 = Tokens.peek(); - switch (T0->Type) { - case NodeType::Identifier: + switch (T0->getKind()) { + case NodeKind::Identifier: Tokens.get(); return new BindPattern(static_cast(T0)); default: - throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::Identifier }); + throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier }); } } QualifiedName* Parser::parseQualifiedName() { std::vector ModulePath; - auto Name = expectToken(NodeType::Identifier); + auto Name = expectToken(NodeKind::Identifier); for (;;) { auto T1 = Tokens.peek(); - if (T1->Type != NodeType::Dot) { + if (T1->getKind() != NodeKind::Dot) { break; } Tokens.get(); ModulePath.push_back(static_cast(Name)); Name = Tokens.get(); - if (Name->Type != NodeType::Identifier) { - throw UnexpectedTokenDiagnostic(File, Name, std::vector { NodeType::Identifier }); + if (Name->getKind() != NodeKind::Identifier) { + throw UnexpectedTokenDiagnostic(File, Name, std::vector { NodeKind::Identifier }); } } return new QualifiedName(ModulePath, static_cast(Name)); } + TypeExpression* Parser::parseTypeExpression() { + return 
parseQualifiedTypeExpression(); + } + + TypeExpression* Parser::parseQualifiedTypeExpression() { + bool HasConstraints = false; + auto T0 = Tokens.peek(); + if (llvm::isa(T0)) { + std::size_t I = 1; + for (;;) { + auto T0 = Tokens.peek(I++); + switch (T0->getKind()) { + case NodeKind::RArrowAlt: + HasConstraints = true; + goto after_scan; + case NodeKind::Equals: + case NodeKind::BlockStart: + case NodeKind::LineFoldEnd: + case NodeKind::EndOfFile: + goto after_scan; + default: + break; + } + } + } +after_scan: + if (!HasConstraints) { + return parseArrowTypeExpression(); + } + Tokens.get(); + LParen* LParen = static_cast(T0); + std::vector> Constraints; + RParen* RParen; + RArrowAlt* RArrowAlt; + for (;;) { + ConstraintExpression* C; + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::RParen: + Tokens.get(); + RParen = static_cast(T0); + RArrowAlt = expectToken(); + goto after_constraints; + default: + C = parseConstraintExpression(); + break; + } + Comma* Comma = nullptr; + auto T1 = Tokens.get(); + switch (T1->getKind()) { + case NodeKind::Comma: + Constraints.push_back(std::make_tuple(C, static_cast(T1))); + continue; + case NodeKind::RParen: + RArrowAlt = static_cast(T1); + Constraints.push_back(std::make_tuple(C, nullptr)); + RArrowAlt = expectToken(); + goto after_constraints; + default: + throw UnexpectedTokenDiagnostic(File, T1, std::vector { NodeKind::Comma, NodeKind::RArrowAlt }); + } + } +after_constraints: + auto TE = parseArrowTypeExpression(); + return new QualifiedTypeExpression(Constraints, RArrowAlt, TE); + } + TypeExpression* Parser::parsePrimitiveTypeExpression() { auto T0 = Tokens.peek(); - switch (T0->Type) { - case NodeType::Identifier: + switch (T0->getKind()) { + case NodeKind::Identifier: + if (static_cast(T0)->isTypeVar()) { + return parseVarTypeExpression(); + } return new ReferenceTypeExpression(parseQualifiedName()); default: - throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::Identifier }); + throw 
UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier }); } } - TypeExpression* Parser::parseTypeExpression() { + TypeExpression* Parser::parseArrowTypeExpression() { auto RetType = parsePrimitiveTypeExpression(); std::vector ParamTypes; for (;;) { auto T1 = Tokens.peek(); - if (T1->Type != NodeType::RArrow) { + if (T1->getKind() != NodeKind::RArrow) { break; } Tokens.get(); ParamTypes.push_back(RetType); RetType = parsePrimitiveTypeExpression(); } - if (ParamTypes.size()) { + if (!ParamTypes.empty()) { return new ArrowTypeExpression(ParamTypes, RetType); } return RetType; @@ -142,25 +215,25 @@ namespace bolt { Expression* Parser::parsePrimitiveExpression() { auto T0 = Tokens.peek(); - switch (T0->Type) { - case NodeType::Identifier: + switch (T0->getKind()) { + case NodeKind::Identifier: { auto Name = parseQualifiedName(); return new ReferenceExpression(Name); } - case NodeType::LParen: + case NodeKind::LParen: { Tokens.get(); auto E = parseExpression(); - auto T2 = static_cast(expectToken(NodeType::RParen)); + auto T2 = static_cast(expectToken(NodeKind::RParen)); return new NestedExpression(static_cast(T0), E, T2); } - case NodeType::IntegerLiteral: - case NodeType::StringLiteral: + case NodeKind::IntegerLiteral: + case NodeKind::StringLiteral: Tokens.get(); return new ConstantExpression(T0); default: - throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::Identifier, NodeType::IntegerLiteral, NodeType::StringLiteral }); + throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier, NodeKind::IntegerLiteral, NodeKind::StringLiteral }); } } @@ -169,7 +242,7 @@ namespace bolt { std::vector Args; for (;;) { auto T1 = Tokens.peek(); - if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::RParen || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) { + if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || ExprOperators.isInfix(T1)) { 
break; } Args.push_back(parsePrimitiveExpression()); @@ -192,7 +265,7 @@ namespace bolt { } auto E = parseCallExpression(); for (auto Iter = Prefix.rbegin(); Iter != Prefix.rend(); Iter++) { - E = new UnaryExpression(*Iter, E); + E = new PrefixExpression(*Iter, E); } return E; } @@ -230,10 +303,10 @@ namespace bolt { } ReturnStatement* Parser::parseReturnStatement() { - auto T0 = static_cast(expectToken(NodeType::ReturnKeyword)); + auto T0 = static_cast(expectToken(NodeKind::ReturnKeyword)); Expression* Expression = nullptr; auto T1 = Tokens.peek(); - if (T1->Type != NodeType::LineFoldEnd) { + if (T1->getKind() != NodeKind::LineFoldEnd) { Expression = parseExpression(); } BOLT_EXPECT_TOKEN(LineFoldEnd); @@ -242,13 +315,13 @@ namespace bolt { IfStatement* Parser::parseIfStatement() { std::vector Parts; - auto T0 = expectToken(NodeType::IfKeyword); + auto T0 = expectToken(NodeKind::IfKeyword); auto Test = parseExpression(); - auto T1 = static_cast(expectToken(NodeType::BlockStart)); + auto T1 = static_cast(expectToken(NodeKind::BlockStart)); std::vector Then; for (;;) { auto T2 = Tokens.peek(); - if (T2->Type == NodeType::BlockEnd) { + if (T2->getKind() == NodeKind::BlockEnd) { Tokens.get(); break; } @@ -257,13 +330,13 @@ namespace bolt { Parts.push_back(new IfStatementPart(T0, Test, T1, Then)); BOLT_EXPECT_TOKEN(LineFoldEnd) auto T3 = Tokens.peek(); - if (T3->Type == NodeType::ElseKeyword) { + if (T3->getKind() == NodeKind::ElseKeyword) { Tokens.get(); - auto T4 = static_cast(expectToken(NodeType::BlockStart)); + auto T4 = static_cast(expectToken(NodeKind::BlockStart)); std::vector Else; for (;;) { auto T5 = Tokens.peek(); - if (T5->Type == NodeType::BlockEnd) { + if (T5->getKind() == NodeKind::BlockEnd) { Tokens.get(); break; } @@ -281,41 +354,41 @@ namespace bolt { LetKeyword* Let; MutKeyword* Mut = nullptr; auto T0 = Tokens.get(); - if (T0->Type == NodeType::PubKeyword) { + if (T0->getKind() == NodeKind::PubKeyword) { Pub = static_cast(T0); T0 = Tokens.get(); } - 
if (T0->Type != NodeType::LetKeyword) { - throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeType::LetKeyword }); + if (T0->getKind() != NodeKind::LetKeyword) { + throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::LetKeyword }); } Let = static_cast(T0); auto T1 = Tokens.peek(); - if (T1->Type == NodeType::MutKeyword) { + if (T1->getKind() == NodeKind::MutKeyword) { Mut = static_cast(T1); Tokens.get(); } auto Patt = parsePattern(); - std::vector Params; + std::vector Params; Token* T2; for (;;) { T2 = Tokens.peek(); - switch (T2->Type) { - case NodeType::LineFoldEnd: - case NodeType::BlockStart: - case NodeType::Equals: - case NodeType::Colon: + switch (T2->getKind()) { + case NodeKind::LineFoldEnd: + case NodeKind::BlockStart: + case NodeKind::Equals: + case NodeKind::Colon: goto after_params; default: - Params.push_back(new Param(parsePattern(), nullptr)); + Params.push_back(new Parameter(parsePattern(), nullptr)); } } after_params: TypeAssert* TA = nullptr; - if (T2->Type == NodeType::Colon) { + if (T2->getKind() == NodeKind::Colon) { Tokens.get(); auto TE = parseTypeExpression(); TA = new TypeAssert(static_cast(T2), TE); @@ -323,14 +396,14 @@ after_params: } LetBody* Body; - switch (T2->Type) { - case NodeType::BlockStart: + switch (T2->getKind()) { + case NodeKind::BlockStart: { Tokens.get(); std::vector Elements; for (;;) { auto T3 = Tokens.peek(); - if (T3->Type == NodeType::BlockEnd) { + if (T3->getKind() == NodeKind::BlockEnd) { break; } Elements.push_back(parseLetBodyElement()); @@ -339,20 +412,20 @@ after_params: Body = new LetBlockBody(static_cast(T2), Elements); break; } - case NodeType::Equals: + case NodeKind::Equals: Tokens.get(); Body = new LetExprBody(static_cast(T2), parseExpression()); break; - case NodeType::LineFoldEnd: + case NodeKind::LineFoldEnd: Body = nullptr; break; default: - std::vector Expected { NodeType::BlockStart, NodeType::LineFoldEnd, NodeType::Equals }; + std::vector Expected { NodeKind::BlockStart, 
NodeKind::LineFoldEnd, NodeKind::Equals }; if (TA == nullptr) { // First tokens of TypeAssert - Expected.push_back(NodeType::Colon); + Expected.push_back(NodeKind::Colon); // First tokens of Pattern - Expected.push_back(NodeType::Identifier); + Expected.push_back(NodeKind::Identifier); } throw UnexpectedTokenDiagnostic(File, T2, Expected); } @@ -372,25 +445,161 @@ after_params: Node* Parser::parseLetBodyElement() { auto T0 = peekFirstTokenAfterModifiers(); - switch (T0->Type) { - case NodeType::LetKeyword: + switch (T0->getKind()) { + case NodeKind::LetKeyword: return parseLetDeclaration(); - case NodeType::ReturnKeyword: + case NodeKind::ReturnKeyword: return parseReturnStatement(); - case NodeType::IfKeyword: + case NodeKind::IfKeyword: return parseIfStatement(); default: return parseExpressionStatement(); } } + ConstraintExpression* Parser::parseConstraintExpression() { + bool HasTilde = false; + for (std::size_t I = 0; ; I++) { + auto Tok = Tokens.peek(I); + switch (Tok->getKind()) { + case NodeKind::Tilde: + HasTilde = true; + goto after_seek; + case NodeKind::RParen: + case NodeKind::Comma: + case NodeKind::RArrowAlt: + case NodeKind::EndOfFile: + goto after_seek; + default: + continue; + } + } +after_seek: + if (HasTilde) { + auto Left = parseArrowTypeExpression(); + auto Tilde = expectToken(); + auto Right = parseArrowTypeExpression(); + return new EqualityConstraintExpression { Left, Tilde, Right }; + } + auto Name = expectToken(); + std::vector TEs; + for (;;) { + auto T1 = Tokens.peek(); + switch (T1->getKind()) { + case NodeKind::RParen: + case NodeKind::RArrowAlt: + case NodeKind::Comma: + goto after_vars; + case NodeKind::Identifier: + Tokens.get(); + TEs.push_back(new VarTypeExpression { static_cast(T1) }); + break; + default: + throw UnexpectedTokenDiagnostic(File, T1, std::vector { NodeKind::RParen, NodeKind::RArrowAlt, NodeKind::Comma, NodeKind::Identifier }); + } + } +after_vars: + return new TypeclassConstraintExpression { Name, TEs }; + } + + 
VarTypeExpression* Parser::parseVarTypeExpression() { + auto Name = expectToken(); + // TODO reject constructor symbols (starting with a capital letter) + return new VarTypeExpression { Name }; + } + + InstanceDeclaration* Parser::parseInstanceDeclaration() { + auto InstanceKeyword = expectToken(); + auto Name = expectToken(); + std::vector TypeExps; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->is()) { + break; + } + TypeExps.push_back(parseTypeExpression()); + } + auto BlockStart = expectToken(); + std::vector Elements; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->is()) { + Tokens.get(); + break; + } + Elements.push_back(parseClassElement()); + } + expectToken(NodeKind::LineFoldEnd); + return new InstanceDeclaration( + InstanceKeyword, + Name, + TypeExps, + BlockStart, + Elements + ); + } + + ClassDeclaration* Parser::parseClassDeclaration() { + PubKeyword* PubKeyword = nullptr; + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::PubKeyword) { + Tokens.get(); + PubKeyword = static_cast(T0); + } + auto ClassKeyword = expectToken(); + auto Name = expectToken(); + std::vector TypeVars; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::BlockStart) { + break; + } + TypeVars.push_back(parseVarTypeExpression()); + } + auto BlockStart = expectToken(); + std::vector Elements; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->is()) { + Tokens.get(); + break; + } + Elements.push_back(parseClassElement()); + } + expectToken(NodeKind::LineFoldEnd); + return new ClassDeclaration( + PubKeyword, + ClassKeyword, + Name, + TypeVars, + BlockStart, + Elements + ); + } + + Node* Parser::parseClassElement() { + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::LetKeyword: + return parseLetDeclaration(); + case NodeKind::TypeKeyword: + // TODO + default: + throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::LetKeyword, NodeKind::TypeKeyword }); + } + } + Node* Parser::parseSourceElement() { auto T0 = 
peekFirstTokenAfterModifiers(); - switch (T0->Type) { - case NodeType::LetKeyword: + switch (T0->getKind()) { + case NodeKind::LetKeyword: return parseLetDeclaration(); - case NodeType::IfKeyword: + case NodeKind::IfKeyword: return parseIfStatement(); + case NodeKind::ClassKeyword: + return parseClassDeclaration(); + case NodeKind::InstanceKeyword: + return parseInstanceDeclaration(); default: return parseExpressionStatement(); } @@ -400,7 +609,7 @@ after_params: std::vector Elements; for (;;) { auto T0 = Tokens.peek(); - if (T0->Type == NodeType::EndOfFile) { + if (T0->is()) { break; } Elements.push_back(parseSourceElement()); diff --git a/src/Scanner.cc b/src/Scanner.cc index 7943e71e5..a1d5a6528 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -3,6 +3,8 @@ #include "zen/config.hpp" +#include "llvm/Support/Casting.h" + #include "bolt/Text.hpp" #include "bolt/Integer.hpp" #include "bolt/CST.hpp" @@ -57,16 +59,18 @@ namespace bolt { return Chr - 48; } - std::unordered_map Keywords = { - { "pub", NodeType::PubKeyword }, - { "let", NodeType::LetKeyword }, - { "mut", NodeType::MutKeyword }, - { "return", NodeType::ReturnKeyword }, - { "type", NodeType::TypeKeyword }, - { "mod", NodeType::ModKeyword }, - { "if", NodeType::IfKeyword }, - { "else", NodeType::ElseKeyword }, - { "elif", NodeType::ElifKeyword }, + std::unordered_map Keywords = { + { "pub", NodeKind::PubKeyword }, + { "let", NodeKind::LetKeyword }, + { "mut", NodeKind::MutKeyword }, + { "return", NodeKind::ReturnKeyword }, + { "type", NodeKind::TypeKeyword }, + { "mod", NodeKind::ModKeyword }, + { "if", NodeKind::IfKeyword }, + { "else", NodeKind::ElseKeyword }, + { "elif", NodeKind::ElifKeyword }, + { "class", NodeKind::ClassKeyword }, + { "instance", NodeKind::InstanceKeyword }, }; Scanner::Scanner(TextFile& File, Stream& Chars): @@ -202,22 +206,26 @@ digit_finish: auto Match = Keywords.find(Text); if (Match != Keywords.end()) { switch (Match->second) { - case NodeType::PubKeyword: + case 
NodeKind::PubKeyword: return new PubKeyword(StartLoc); - case NodeType::LetKeyword: + case NodeKind::LetKeyword: return new LetKeyword(StartLoc); - case NodeType::MutKeyword: + case NodeKind::MutKeyword: return new MutKeyword(StartLoc); - case NodeType::TypeKeyword: + case NodeKind::TypeKeyword: return new TypeKeyword(StartLoc); - case NodeType::ReturnKeyword: + case NodeKind::ReturnKeyword: return new ReturnKeyword(StartLoc); - case NodeType::IfKeyword: + case NodeKind::IfKeyword: return new IfKeyword(StartLoc); - case NodeType::ElifKeyword: + case NodeKind::ElifKeyword: return new ElifKeyword(StartLoc); - case NodeType::ElseKeyword: + case NodeKind::ElseKeyword: return new ElseKeyword(StartLoc); + case NodeKind::ClassKeyword: + return new ClassKeyword(StartLoc); + case NodeKind::InstanceKeyword: + return new InstanceKeyword(StartLoc); default: ZEN_UNREACHABLE } @@ -305,6 +313,8 @@ after_string_contents: } if (Text == "->") { return new RArrow(StartLoc); + } else if (Text == "=>") { + return new RArrowAlt(StartLoc); } else if (Text == "=") { return new Equals(StartLoc); } else if (Text.back() == '=' && Text[Text.size()-2] != '=') { @@ -316,7 +326,7 @@ after_string_contents: #define BOLT_SIMPLE_TOKEN(ch, name) case ch: return new name(StartLoc); - //BOLT_SIMPLE_TOKEN(',', Comma) + BOLT_SIMPLE_TOKEN(',', Comma) BOLT_SIMPLE_TOKEN(':', Colon) BOLT_SIMPLE_TOKEN('(', LParen) BOLT_SIMPLE_TOKEN(')', RParen) @@ -324,6 +334,7 @@ after_string_contents: BOLT_SIMPLE_TOKEN(']', RBracket) BOLT_SIMPLE_TOKEN('{', LBrace) BOLT_SIMPLE_TOKEN('}', RBrace) + BOLT_SIMPLE_TOKEN('~', Tilde) default: throw UnexpectedStringDiagnostic(File, StartLoc, String { static_cast(C0) }); @@ -342,7 +353,7 @@ after_string_contents: auto T0 = Tokens.peek(); - if (T0->Type == NodeType::EndOfFile) { + if (llvm::isa(T0)) { if (Frames.size() == 1) { return T0; } @@ -366,7 +377,7 @@ after_string_contents: Locations.pop(); return new LineFoldEnd(T0->getStartLoc()); } - if (T0->Type == NodeType::Dot) { + if 
(llvm::isa(T0)) { auto T1 = Tokens.peek(1); if (T1->getStartLine() > T0->getEndLine()) { Tokens.get(); diff --git a/src/TestChecker.cc b/src/TestChecker.cc index 1a26e7066..663630477 100644 --- a/src/TestChecker.cc +++ b/src/TestChecker.cc @@ -16,18 +16,18 @@ auto checkExpression(std::string Input) { Scanner S(T, Chars); Punctuator PT(S); Parser P(T, PT); + LanguageConfig Config; auto SF = P.parseSourceFile(); - Checker C(DS); - auto Solution = C.check(SF); + Checker C(Config, DS); + C.check(SF); return std::make_tuple( static_cast(SF->Elements[0])->Expression, - C, - Solution + C ); } TEST(CheckerTest, InfersIntFromIntegerLiteral) { - auto [Expression, Checker, Solution] = checkExpression("1"); - ASSERT_EQ(Checker.getType(Expression, Solution), Checker.getIntType()); + auto [Expression, Checker] = checkExpression("1"); + ASSERT_EQ(Checker.getType(Expression), Checker.getIntType()); } diff --git a/src/main.cc b/src/main.cc index 9523f83c7..917ab0009 100644 --- a/src/main.cc +++ b/src/main.cc @@ -37,6 +37,7 @@ int main(int argc, const char* argv[]) { } ConsoleDiagnostics DE; + LanguageConfig Config; auto Text = readFile(argv[1]); TextFile File { argv[1], Text }; @@ -56,7 +57,7 @@ int main(int argc, const char* argv[]) { SF->setParents(); - Checker TheChecker { DE }; + Checker TheChecker { Config, DE }; TheChecker.check(SF); return 0;