diff --git a/bootstrap/cxx/include/bolt/ByteString.hpp b/bootstrap/cxx/include/bolt/ByteString.hpp index 7d40d45c9..671c279ea 100644 --- a/bootstrap/cxx/include/bolt/ByteString.hpp +++ b/bootstrap/cxx/include/bolt/ByteString.hpp @@ -6,9 +6,9 @@ namespace bolt { - using ByteString = std::string; +using ByteString = std::string; - using ByteStringView = std::string_view; +using ByteStringView = std::string_view; } diff --git a/bootstrap/cxx/include/bolt/CST.hpp b/bootstrap/cxx/include/bolt/CST.hpp index 06feeae6e..4f9295176 100644 --- a/bootstrap/cxx/include/bolt/CST.hpp +++ b/bootstrap/cxx/include/bolt/CST.hpp @@ -14,2553 +14,2553 @@ namespace bolt { - class Type; - class InferContext; +class Type; +class InferContext; - class Token; - class SourceFile; - class Scope; - class Pattern; - class Expression; - class Statement; +class Token; +class SourceFile; +class Scope; +class Pattern; +class Expression; +class Statement; - class TextLoc { - public: +class TextLoc { +public: - size_t Line = 1; - size_t Column = 1; + size_t Line = 1; + size_t Column = 1; - inline bool isEmpty() const noexcept { - return Line == 0 && Column == 0; - } + inline bool isEmpty() const noexcept { + return Line == 0 && Column == 0; + } - inline void advance(const ByteString& Text) { - for (auto Chr: Text) { - if (Chr == '\n') { - Line++; - Column = 1; - } else { - Column++; - } + inline void advance(const ByteString& Text) { + for (auto Chr: Text) { + if (Chr == '\n') { + Line++; + Column = 1; + } else { + Column++; } } - - inline TextLoc operator+(const ByteString& Text) const { - TextLoc Out { Line, Column }; - Out.advance(Text); - return Out; - } - - static TextLoc empty() { - return TextLoc { 0, 0 }; - } - - }; - - struct TextRange { - TextLoc Start; - TextLoc End; - }; - - class TextFile { - - ByteString Path; - ByteString Text; - - std::vector LineOffsets; - - public: - - TextFile(ByteString Path, ByteString Text); - - size_t getLine(size_t Offset) const; - size_t getColumn(size_t Offset) const; - size_t getStartOffsetOfLine(size_t Line) const; - size_t getEndOffsetOfLine(size_t Line) const; - - size_t getLineCount() const; - - ByteString getPath() const; - - ByteString getText() const; - - }; - - enum class NodeKind { - VBar, - Equals, - Colon, - Comma, - Dot, - DotDot, - Tilde, - At, - LParen, - RParen, - LBracket, - RBracket, - LBrace, - RBrace, - RArrow, - RArrowAlt, - LetKeyword, - MutKeyword, - PubKeyword, - ForeignKeyword, - TypeKeyword, - ReturnKeyword, - ModKeyword, - StructKeyword, - EnumKeyword, - ClassKeyword, - InstanceKeyword, - ElifKeyword, - IfKeyword, - ElseKeyword, - MatchKeyword, - Invalid, - EndOfFile, - BlockStart, - BlockEnd, - LineFoldEnd, - CustomOperator, - Assignment, - StringLiteral, - IntegerLiteral, - Identifier, - IdentifierAlt, - WrappedOperator, - ExpressionAnnotation, - TypeAssertAnnotation, - TypeclassConstraintExpression, - EqualityConstraintExpression, - RecordTypeExpressionField, - RecordTypeExpression, - QualifiedTypeExpression, - ReferenceTypeExpression, - ArrowTypeExpression, - AppTypeExpression, - VarTypeExpression, - NestedTypeExpression, - TupleTypeExpression, - BindPattern, - LiteralPattern, - RecordPatternField, - RecordPattern, - NamedRecordPattern, - NamedTuplePattern, - TuplePattern, - NestedPattern, - ListPattern, - ReferenceExpression, - MatchCase, - MatchExpression, - MemberExpression, - TupleExpression, - NestedExpression, - LiteralExpression, - CallExpression, - InfixExpression, - PrefixExpression, - RecordExpressionField, - RecordExpression, - ExpressionStatement, - ReturnStatement, - IfStatement, - IfStatementPart, - TypeAssert, - Parameter, - LetBlockBody, - LetExprBody, - LetDeclaration, - RecordDeclarationField, - RecordDeclaration, - VariantDeclaration, - TupleVariantDeclarationMember, - RecordVariantDeclarationMember, - ClassDeclaration, - InstanceDeclaration, - SourceFile, - }; - - struct SymbolPath { - std::vector Modules; - ByteString Name; - }; - - template - NodeKind getNodeType(); - - enum NodeFlags { - NodeFlags_TypeIsSolved = 1, - }; - - using NodeFlagsMask = unsigned; - - class Node; - - template - bool _is_helper(const Node* N) noexcept; - - class Node { - - unsigned RefCount = 1; - - const NodeKind Kind; - - public: - - NodeFlagsMask Flags = 0; - Node* Parent = nullptr; - - inline void ref() { - ++RefCount; - } - - void unref(); - - void setParents(); - - virtual Token* getFirstToken() const = 0; - virtual Token* getLastToken() const = 0; - - virtual std::size_t getStartLine() const; - virtual std::size_t getStartColumn() const; - virtual std::size_t getEndLine() const; - virtual std::size_t getEndColumn() const; - - inline NodeKind getKind() const noexcept { - return Kind; - } - - template - bool is() const noexcept { - return _is_helper(this); - } - - template - T* as() { - ZEN_ASSERT(is()); - return static_cast(this); - } - - virtual TextRange getRange() const; - - inline Node(NodeKind Type): - Kind(Type) {} - - const SourceFile* getSourceFile() const; - SourceFile* getSourceFile(); - - virtual Scope* getScope(); - - virtual ~Node() {} - - }; - - template - bool _is_helper(const Node* N) noexcept { - return N->getKind() == getNodeType(); } - template<> - inline bool _is_helper(const Node* N) noexcept { - return N->getKind() == NodeKind::ReferenceExpression - || N->getKind() == NodeKind::LiteralExpression - || N->getKind() == NodeKind::PrefixExpression - || N->getKind() == NodeKind::InfixExpression - || N->getKind() == NodeKind::CallExpression - || N->getKind() == NodeKind::NestedExpression; + inline TextLoc operator+(const ByteString& Text) const { + TextLoc Out { Line, Column }; + Out.advance(Text); + return Out; } - enum class SymbolKind { - Var, - Class, - Type, - Constructor, - }; - - class Scope { - - Node* Source; - std::unordered_multimap> Mapping; - - void addSymbol(ByteString Name, Node* Decl, SymbolKind Kind); - - void scan(Node* X); - void scanChild(Node* X); - - void visitPattern(Pattern* P, Node* ToInsert); - - public: - - Scope(Node* Source); - - /** - * Performs a direct lookup in this scope for the given symbol. - * - * This method will never traverse to parent scopes and will always return a - * symbol that belongs to this scope, if any is found. - * - * \returns nullptr when no such symbol could be found in this scope. - */ - Node* lookupDirect(SymbolPath Path, SymbolKind Kind = SymbolKind::Var); - - /** - * Find the symbol with the given name, either in this scope or in any of - * the parent ones. - * - * \returns nullptr when no such symbol could be found in any of the scopes. - */ - Node* lookup(SymbolPath Path, SymbolKind Kind = SymbolKind::Var); - - Scope* getParentScope(); - - }; - - class Token : public Node { - - TextLoc StartLoc; - - public: - - Token(NodeKind Type, TextLoc StartLoc): Node(Type), StartLoc(StartLoc) {} - - virtual std::string getText() const = 0; - - inline Token* getFirstToken() const override { - ZEN_UNREACHABLE - } - - inline Token* getLastToken() const override { - ZEN_UNREACHABLE - } - - inline TextLoc getStartLoc() const { - return StartLoc; - } - - TextLoc getEndLoc() const; - - inline size_t getStartLine() const override { - return StartLoc.Line; - } - - inline size_t getStartColumn() const override { - return StartLoc.Column; - } - - inline size_t getEndLine() const override { - return getEndLoc().Line; - } - - inline size_t getEndColumn() const override { - return getEndLoc().Column; - } - - TextRange getRange() const override { - return { getStartLoc(), getEndLoc() }; - } - - }; - - /// Any node that can be used as an operator - /// - /// This includes the following nodes: - /// - VBar - /// - CustomOperator - using Operator = Token; - - /// Any node that can be used as a kind of identifier. - /// - /// This includes the following nodes: - /// - Identifier - /// - IdentifierAlt - /// - WrappedOperator - using Symbol = Node; - - inline bool isSymbol(const Node* N) { - return N->getKind() == NodeKind::Identifier - || N->getKind() == NodeKind::IdentifierAlt - || N->getKind() == NodeKind::WrappedOperator; + static TextLoc empty() { + return TextLoc { 0, 0 }; } - /// Get the text that is actually represented by a symbol, without all the - /// syntactic sugar. - ByteString getCanonicalText(const Symbol* N); +}; - class Equals : public Token { - public: +struct TextRange { + TextLoc Start; + TextLoc End; +}; - inline Equals(TextLoc StartLoc): - Token(NodeKind::Equals, StartLoc) {} +class TextFile { - std::string getText() const override; + ByteString Path; + ByteString Text; - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Equals; + std::vector LineOffsets; + +public: + + TextFile(ByteString Path, ByteString Text); + + size_t getLine(size_t Offset) const; + size_t getColumn(size_t Offset) const; + size_t getStartOffsetOfLine(size_t Line) const; + size_t getEndOffsetOfLine(size_t Line) const; + + size_t getLineCount() const; + + ByteString getPath() const; + + ByteString getText() const; + +}; + +enum class NodeKind { + VBar, + Equals, + Colon, + Comma, + Dot, + DotDot, + Tilde, + At, + LParen, + RParen, + LBracket, + RBracket, + LBrace, + RBrace, + RArrow, + RArrowAlt, + LetKeyword, + MutKeyword, + PubKeyword, + ForeignKeyword, + TypeKeyword, + ReturnKeyword, + ModKeyword, + StructKeyword, + EnumKeyword, + ClassKeyword, + InstanceKeyword, + ElifKeyword, + IfKeyword, + ElseKeyword, + MatchKeyword, + Invalid, + EndOfFile, + BlockStart, + BlockEnd, + LineFoldEnd, + CustomOperator, + Assignment, + StringLiteral, + IntegerLiteral, + Identifier, + IdentifierAlt, + WrappedOperator, + ExpressionAnnotation, + TypeAssertAnnotation, + TypeclassConstraintExpression, + EqualityConstraintExpression, + RecordTypeExpressionField, + RecordTypeExpression, + QualifiedTypeExpression, + ReferenceTypeExpression, + ArrowTypeExpression, + AppTypeExpression, + VarTypeExpression, + NestedTypeExpression, + TupleTypeExpression, + BindPattern, + LiteralPattern, + RecordPatternField, + RecordPattern, + NamedRecordPattern, + NamedTuplePattern, + TuplePattern, + NestedPattern, + ListPattern, + ReferenceExpression, + MatchCase, + MatchExpression, + MemberExpression, + TupleExpression, + NestedExpression, + LiteralExpression, + CallExpression, + InfixExpression, + PrefixExpression, + RecordExpressionField, + RecordExpression, + ExpressionStatement, + ReturnStatement, + IfStatement, + IfStatementPart, + TypeAssert, + Parameter, + LetBlockBody, + LetExprBody, + LetDeclaration, + RecordDeclarationField, + RecordDeclaration, + VariantDeclaration, + TupleVariantDeclarationMember, + RecordVariantDeclarationMember, + ClassDeclaration, + InstanceDeclaration, + SourceFile, +}; + +struct SymbolPath { + std::vector Modules; + ByteString Name; +}; + +template +NodeKind getNodeType(); + +enum NodeFlags { + NodeFlags_TypeIsSolved = 1, +}; + +using NodeFlagsMask = unsigned; + +class Node; + +template +bool _is_helper(const Node* N) noexcept; + +class Node { + + unsigned RefCount = 1; + + const NodeKind Kind; + +public: + + NodeFlagsMask Flags = 0; + Node* Parent = nullptr; + + inline void ref() { + ++RefCount; + } + + void unref(); + + void setParents(); + + virtual Token* getFirstToken() const = 0; + virtual Token* getLastToken() const = 0; + + virtual std::size_t getStartLine() const; + virtual std::size_t getStartColumn() const; + virtual std::size_t getEndLine() const; + virtual std::size_t getEndColumn() const; + + inline NodeKind getKind() const noexcept { + return Kind; + } + + template + bool is() const noexcept { + return _is_helper(this); + } + + template + T* as() { + ZEN_ASSERT(is()); + return static_cast(this); + } + + virtual TextRange getRange() const; + + inline Node(NodeKind Type): + Kind(Type) {} + + const SourceFile* getSourceFile() const; + SourceFile* getSourceFile(); + + virtual Scope* getScope(); + + virtual ~Node() {} + +}; + +template +bool _is_helper(const Node* N) noexcept { + return N->getKind() == getNodeType(); +} + +template<> +inline bool _is_helper(const Node* N) noexcept { + return N->getKind() == NodeKind::ReferenceExpression + || N->getKind() == NodeKind::LiteralExpression + || N->getKind() == NodeKind::PrefixExpression + || N->getKind() == NodeKind::InfixExpression + || N->getKind() == NodeKind::CallExpression + || N->getKind() == NodeKind::NestedExpression; +} + +enum class SymbolKind { + Var, + Class, + Type, + Constructor, +}; + +class Scope { + + Node* Source; + std::unordered_multimap> Mapping; + + void addSymbol(ByteString Name, Node* Decl, SymbolKind Kind); + + void scan(Node* X); + void scanChild(Node* X); + + void visitPattern(Pattern* P, Node* ToInsert); + +public: + + Scope(Node* Source); + + /** + * Performs a direct lookup in this scope for the given symbol. + * + * This method will never traverse to parent scopes and will always return a + * symbol that belongs to this scope, if any is found. + * + * \returns nullptr when no such symbol could be found in this scope. + */ + Node* lookupDirect(SymbolPath Path, SymbolKind Kind = SymbolKind::Var); + + /** + * Find the symbol with the given name, either in this scope or in any of + * the parent ones. + * + * \returns nullptr when no such symbol could be found in any of the scopes. + */ + Node* lookup(SymbolPath Path, SymbolKind Kind = SymbolKind::Var); + + Scope* getParentScope(); + +}; + +class Token : public Node { + + TextLoc StartLoc; + +public: + + Token(NodeKind Type, TextLoc StartLoc): Node(Type), StartLoc(StartLoc) {} + + virtual std::string getText() const = 0; + + inline Token* getFirstToken() const override { + ZEN_UNREACHABLE + } + + inline Token* getLastToken() const override { + ZEN_UNREACHABLE + } + + inline TextLoc getStartLoc() const { + return StartLoc; + } + + TextLoc getEndLoc() const; + + inline size_t getStartLine() const override { + return StartLoc.Line; + } + + inline size_t getStartColumn() const override { + return StartLoc.Column; + } + + inline size_t getEndLine() const override { + return getEndLoc().Line; + } + + inline size_t getEndColumn() const override { + return getEndLoc().Column; + } + + TextRange getRange() const override { + return { getStartLoc(), getEndLoc() }; + } + +}; + +/// Any node that can be used as an operator +/// +/// This includes the following nodes: +/// - VBar +/// - CustomOperator +using Operator = Token; + +/// Any node that can be used as a kind of identifier. +/// +/// This includes the following nodes: +/// - Identifier +/// - IdentifierAlt +/// - WrappedOperator +using Symbol = Node; + +inline bool isSymbol(const Node* N) { + return N->getKind() == NodeKind::Identifier + || N->getKind() == NodeKind::IdentifierAlt + || N->getKind() == NodeKind::WrappedOperator; +} + +/// Get the text that is actually represented by a symbol, without all the +/// syntactic sugar. +ByteString getCanonicalText(const Symbol* N); + +class Equals : public Token { +public: + + inline Equals(TextLoc StartLoc): + Token(NodeKind::Equals, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Equals; + } + +}; + +class VBar : public Token { +public: + + inline VBar(TextLoc StartLoc): + Token(NodeKind::VBar, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::VBar; + } + +}; + +class Colon : public Token { +public: + + inline Colon(TextLoc StartLoc): + Token(NodeKind::Colon, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Colon; + } + +}; + +class Comma : public Token { +public: + + inline Comma(TextLoc StartLoc): + Token(NodeKind::Comma, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Comma; + } + +}; + +class Dot : public Token { +public: + + inline Dot(TextLoc StartLoc): + Token(NodeKind::Dot, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Dot; + } + +}; + +class DotDot : public Token { +public: + + inline DotDot(TextLoc StartLoc): + Token(NodeKind::DotDot, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::DotDot; + } + +}; + +class Tilde : public Token { +public: + + inline Tilde(TextLoc StartLoc): + Token(NodeKind::Tilde, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Tilde; + } + +}; + +class At : public Token { +public: + + inline At(TextLoc StartLoc): + Token(NodeKind::At, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::At; + } + +}; + +class LParen : public Token { +public: + + inline LParen(TextLoc StartLoc): + Token(NodeKind::LParen, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LParen; + } + +}; + +class RParen : public Token { +public: + + inline RParen(TextLoc StartLoc): + Token(NodeKind::RParen, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RParen; + } + +}; + +class LBracket : public Token { +public: + + inline LBracket(TextLoc StartLoc): + Token(NodeKind::LBracket, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LBracket; + } + +}; + +class RBracket : public Token { +public: + + inline RBracket(TextLoc StartLoc): + Token(NodeKind::RBracket, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RBracket; + } + +}; + +class LBrace : public Token { +public: + + inline LBrace(TextLoc StartLoc): + Token(NodeKind::LBrace, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LBrace; + } + +}; + +class RBrace : public Token { +public: + + inline RBrace(TextLoc StartLoc): + Token(NodeKind::RBrace, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RBrace; + } + +}; + +class RArrow : public Token { +public: + + inline RArrow(TextLoc StartLoc): + Token(NodeKind::RArrow, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RArrow; + } + +}; + +class RArrowAlt : public Token { +public: + + inline RArrowAlt(TextLoc StartLoc): + Token(NodeKind::RArrowAlt, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::RArrowAlt; + } + +}; + +class LetKeyword : public Token { +public: + + inline LetKeyword(TextLoc StartLoc): + Token(NodeKind::LetKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LetKeyword; + } + +}; + +class MutKeyword : public Token { +public: + + inline MutKeyword(TextLoc StartLoc): + Token(NodeKind::MutKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::MutKeyword; + } + +}; + +class PubKeyword : public Token { +public: + + inline PubKeyword(TextLoc StartLoc): + Token(NodeKind::PubKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::PubKeyword; + } + +}; + +class ForeignKeyword : public Token { +public: + + inline ForeignKeyword(TextLoc StartLoc): + Token(NodeKind::ForeignKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ForeignKeyword; + } + +}; + +class TypeKeyword : public Token { +public: + + inline TypeKeyword(TextLoc StartLoc): + Token(NodeKind::TypeKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::TypeKeyword; + } + +}; + +class ReturnKeyword : public Token { +public: + + inline ReturnKeyword(TextLoc StartLoc): + Token(NodeKind::ReturnKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ReturnKeyword; + } + +}; + +class ModKeyword : public Token { +public: + + inline ModKeyword(TextLoc StartLoc): + Token(NodeKind::ModKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ModKeyword; + } + +}; + +class StructKeyword : public Token { +public: + + inline StructKeyword(TextLoc StartLoc): + Token(NodeKind::StructKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::StructKeyword; + } + +}; + +class EnumKeyword : public Token { +public: + + inline EnumKeyword(TextLoc StartLoc): + Token(NodeKind::EnumKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::EnumKeyword; + } + +}; + +class ClassKeyword : public Token { +public: + + inline ClassKeyword(TextLoc StartLoc): + Token(NodeKind::ClassKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ClassKeyword; + } + +}; + +class InstanceKeyword : public Token { +public: + + inline InstanceKeyword(TextLoc StartLoc): + Token(NodeKind::InstanceKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::InstanceKeyword; + } + +}; + +class ElifKeyword : public Token { +public: + + inline ElifKeyword(TextLoc StartLoc): + Token(NodeKind::ElifKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ElifKeyword; + } + +}; + +class IfKeyword : public Token { +public: + + inline IfKeyword(TextLoc StartLoc): + Token(NodeKind::IfKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::IfKeyword; + } + +}; + +class ElseKeyword : public Token { +public: + + inline ElseKeyword(TextLoc StartLoc): + Token(NodeKind::ElseKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ElseKeyword; + } + +}; + +class MatchKeyword : public Token { +public: + + inline MatchKeyword(TextLoc StartLoc): + Token(NodeKind::MatchKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::MatchKeyword; + } + +}; + +class Invalid : public Token { +public: + + inline Invalid(TextLoc StartLoc): + Token(NodeKind::Invalid, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Invalid; + } + +}; + +class EndOfFile : public Token { +public: + + inline EndOfFile(TextLoc StartLoc): + Token(NodeKind::EndOfFile, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::EndOfFile; + } + +}; + +class BlockStart : public Token { +public: + + inline BlockStart(TextLoc StartLoc): + Token(NodeKind::BlockStart, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::BlockStart; + } + +}; + +class BlockEnd : public Token { +public: + + inline BlockEnd(TextLoc StartLoc): + Token(NodeKind::BlockEnd, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::BlockEnd; + } + +}; + +class LineFoldEnd : public Token { +public: + + inline LineFoldEnd(TextLoc StartLoc): + Token(NodeKind::LineFoldEnd, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LineFoldEnd; + } + +}; + +class CustomOperator : public Token { +public: + + ByteString Text; + + CustomOperator(ByteString Text, TextLoc StartLoc): + Token(NodeKind::CustomOperator, StartLoc), Text(Text) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::CustomOperator; + } + +}; + +class Assignment : public Token { +public: + + ByteString Text; + + Assignment(ByteString Text, TextLoc StartLoc): + Token(NodeKind::Assignment, StartLoc), Text(Text) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Assignment; + } + +}; + +class Identifier : public Token { +public: + + ByteString Text; + + Identifier(ByteString Text, TextLoc StartLoc = TextLoc::empty()): + Token(NodeKind::Identifier, StartLoc), Text(Text) {} + + std::string getText() const override; + + bool isTypeVar() const; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Identifier; + } + +}; + +class IdentifierAlt : public Token { +public: + + ByteString Text; + + IdentifierAlt(ByteString Text, TextLoc StartLoc): + Token(NodeKind::IdentifierAlt, StartLoc), Text(Text) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::IdentifierAlt; + } + +}; + +using LiteralValue = std::variant; + +class Literal : public Token { +public: + + inline Literal(NodeKind Kind, TextLoc StartLoc): + Token(Kind, StartLoc) {} + + virtual LiteralValue getValue() = 0; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::StringLiteral + || N->getKind() == NodeKind::IntegerLiteral; + } + +}; + +class StringLiteral : public Literal { +public: + + ByteString Text; + + StringLiteral(ByteString Text, TextLoc StartLoc): + Literal(NodeKind::StringLiteral, StartLoc), Text(Text) {} + + std::string getText() const override; + + LiteralValue getValue() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::StringLiteral; + } + +}; + +class IntegerLiteral : public Literal { +public: + + Integer V; + + IntegerLiteral(Integer Value, TextLoc StartLoc): + Literal(NodeKind::IntegerLiteral, StartLoc), V(Value) {} + + std::string getText() const override; + + inline Integer getInteger() const noexcept { + return V; + } + + inline int asInt() const { + ZEN_ASSERT(V >= std::numeric_limits::min() && V <= std::numeric_limits::max()); + return V; + } + + LiteralValue getValue() override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::IntegerLiteral; + } + +}; + +class Annotation : public Node { +public: + + inline Annotation(NodeKind Kind): + Node(Kind) {} + +}; + +class AnnotationContainer { +public: + std::vector Annotations; +}; + +class ExpressionAnnotation : public Annotation { +public: + + class At* At; + class Expression* Expression; + + inline ExpressionAnnotation( + class At* At, + class Expression* Expression + ): Annotation(NodeKind::ExpressionAnnotation), + At(At), + Expression(Expression) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline class Expression* getExpression() const noexcept { + return Expression; + } + +}; + +class TypeExpression; + +class TypeAssertAnnotation : public Annotation { +public: + + class At* At; + class Colon* Colon; + TypeExpression* TE; + + inline TypeAssertAnnotation( + class At* At, + class Colon* Colon, + TypeExpression* TE + ): Annotation(NodeKind::TypeAssertAnnotation), + At(At), + Colon(Colon), + TE(TE) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline TypeExpression* getTypeExpression() const noexcept { + return TE; + } + +}; + +class TypedNode : public Node { +protected: + + Type* Ty; + + inline TypedNode(NodeKind Kind): + Node(Kind) {} + +public: + + inline void setType(Type* Ty2) { + Ty = Ty2; + } + + inline Type* getType() const noexcept { + ZEN_ASSERT(Ty != nullptr); + return Ty; + } + +}; + +class TypeExpression : public TypedNode, AnnotationContainer { +protected: + + inline TypeExpression(NodeKind Kind, std::vector Annotations = {}): + TypedNode(Kind), AnnotationContainer(Annotations) {} + +}; + +class ConstraintExpression : public Node { +public: + + inline ConstraintExpression(NodeKind Kind): + Node(Kind) {} + +}; + +class RecordTypeExpressionField : public Node { +public: + + Identifier* Name; + class Colon* Colon; + TypeExpression* TE; + + inline RecordTypeExpressionField( + Identifier* Name, + class Colon* Colon, + TypeExpression* TE + ): Node(NodeKind::RecordTypeExpressionField), + Name(Name), + Colon(Colon), + TE(TE) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class RecordTypeExpression : public TypeExpression { +public: + + class LBrace* LBrace; + std::vector> Fields; + class VBar* VBar; + TypeExpression* Rest; + class RBrace* RBrace; + + inline RecordTypeExpression( + class LBrace* LBrace, + std::vector> Fields, + class VBar* VBar, + TypeExpression* Rest, + class RBrace* RBrace + ): TypeExpression(NodeKind::RecordTypeExpression), + LBrace(LBrace), + Fields(Fields), + VBar(VBar), + Rest(Rest), + RBrace(RBrace) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class VarTypeExpression; + +class TypeclassConstraintExpression : public ConstraintExpression { +public: + + IdentifierAlt* Name; + std::vector TEs; + + TypeclassConstraintExpression( + IdentifierAlt* Name, + std::vector TEs + ): ConstraintExpression(NodeKind::TypeclassConstraintExpression), + Name(Name), + TEs(TEs) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::TypeclassConstraintExpression; + } + +}; + +class EqualityConstraintExpression : public ConstraintExpression { +public: + + TypeExpression* Left; + class Tilde* Tilde; + TypeExpression* Right; + + inline EqualityConstraintExpression( + TypeExpression* Left, + class Tilde* Tilde, + TypeExpression* Right + ): ConstraintExpression(NodeKind::EqualityConstraintExpression), + Left(Left), + Tilde(Tilde), + Right(Right) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::EqualityConstraintExpression; + } + +}; + +class QualifiedTypeExpression : public TypeExpression { +public: + + std::vector> Constraints; + class RArrowAlt* RArrowAlt; + TypeExpression* TE; + + QualifiedTypeExpression( + std::vector> Constraints, + class RArrowAlt* RArrowAlt, + TypeExpression* TE + ): TypeExpression(NodeKind::QualifiedTypeExpression), + Constraints(Constraints), + RArrowAlt(RArrowAlt), + TE(TE) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::QualifiedTypeExpression; + } + +}; + +class ReferenceTypeExpression : public TypeExpression { +public: + + std::vector> ModulePath; + IdentifierAlt* Name; + + ReferenceTypeExpression( + std::vector> ModulePath, + IdentifierAlt* Name + ): TypeExpression(NodeKind::ReferenceTypeExpression), + ModulePath(ModulePath), + Name(Name) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + SymbolPath getSymbolPath() const; + +}; + +class ArrowTypeExpression : public TypeExpression { +public: + + std::vector ParamTypes; + TypeExpression* ReturnType; + + inline ArrowTypeExpression( + std::vector ParamTypes, + TypeExpression* ReturnType + ): TypeExpression(NodeKind::ArrowTypeExpression), + ParamTypes(ParamTypes), + ReturnType(ReturnType) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class AppTypeExpression : public TypeExpression { +public: + + TypeExpression* Op; + std::vector Args; + + inline AppTypeExpression( + TypeExpression* Op, + std::vector Args + ): TypeExpression(NodeKind::AppTypeExpression), + Op(Op), + Args(Args) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class VarTypeExpression : public TypeExpression { +public: + + Identifier* Name; + + inline VarTypeExpression(Identifier* Name): + TypeExpression(NodeKind::VarTypeExpression), Name(Name) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class NestedTypeExpression : public TypeExpression { +public: + + class LParen* LParen; + TypeExpression* TE; + class RParen* RParen; + + inline NestedTypeExpression( + class LParen* LParen, + TypeExpression* TE, + class RParen* RParen + ): TypeExpression(NodeKind::NestedTypeExpression), + LParen(LParen), + TE(TE), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class TupleTypeExpression : public TypeExpression { +public: + + class LParen* LParen; + std::vector> Elements; + class RParen* RParen; + + inline TupleTypeExpression( + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): TypeExpression(NodeKind::TupleTypeExpression), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class WrappedOperator : public Symbol { +public: + + class LParen* LParen; + Token* Op; + class RParen* RParen; + + WrappedOperator( + class LParen* LParen, + Token* Operator, + class RParen* RParen + ): Symbol(NodeKind::WrappedOperator), + LParen(LParen), + Op(Operator), + RParen(RParen) {} + + inline Token* getOperator() const { + return Op; + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + + +class Pattern : public Node { +protected: + + inline Pattern(NodeKind Type): + Node(Type) {} + +}; + +class BindPattern : public Pattern { +public: + + Symbol* Name; + + BindPattern( + Symbol* Name + ): Pattern(NodeKind::BindPattern), + Name(Name) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::BindPattern; + } + +}; + +class LiteralPattern : public Pattern { +public: + + class Literal* Literal; + + LiteralPattern(class Literal* Literal): + Pattern(NodeKind::LiteralPattern), + Literal(Literal) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LiteralPattern; + } + +}; + +class RecordPatternField : public Node { +public: + + class DotDot* DotDot; + Identifier* Name; + class Equals* Equals; + class Pattern* Pattern; + + inline RecordPatternField( + class DotDot* DotDot, + Identifier* Name, + class Equals* Equals, + class Pattern* Pattern + ): Node(NodeKind::RecordPatternField), + DotDot(DotDot), + Name(Name), + Equals(Equals), + Pattern(Pattern) {} + + inline RecordPatternField( + Identifier* Name, + class Equals* Equals, + class Pattern* Pattern + ): RecordPatternField(nullptr, Name, Equals, Pattern) {} + + inline RecordPatternField( + class DotDot* DotDot + ): RecordPatternField(DotDot, nullptr, nullptr, nullptr) {} + + inline RecordPatternField( + class DotDot* DotDot, + class Pattern* Pattern + ): RecordPatternField(DotDot, nullptr, nullptr, Pattern) {} + + inline RecordPatternField( + Identifier* Name + ): RecordPatternField(nullptr, Name, nullptr, nullptr) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class RecordPattern : public Pattern { +public: + + class LBrace* LBrace; + std::vector> Fields; + class RBrace* RBrace; + + inline RecordPattern( + class LBrace* LBrace, + std::vector> Fields, + class RBrace* RBrace + ): Pattern(NodeKind::RecordPattern), + LBrace(LBrace), + Fields(Fields), + RBrace(RBrace) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class NamedRecordPattern : public Pattern { +public: + + std::vector> ModulePath; + IdentifierAlt* Name; + class LBrace* LBrace; + std::vector> Fields; + class RBrace* RBrace; + + inline NamedRecordPattern( + std::vector> ModulePath, + IdentifierAlt* Name, + class LBrace* LBrace, + std::vector> Fields, + class RBrace* RBrace + ): Pattern(NodeKind::NamedRecordPattern), + ModulePath(ModulePath), + Name(Name), + LBrace(LBrace), + Fields(Fields), + RBrace(RBrace) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class NamedTuplePattern : public Pattern { +public: + + IdentifierAlt* Name; + std::vector Patterns; + + inline NamedTuplePattern( + IdentifierAlt* Name, + std::vector Patterns + ): Pattern(NodeKind::NamedTuplePattern), + Name(Name), + Patterns(Patterns) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class TuplePattern : public Pattern { +public: + + class LParen* LParen; + std::vector> Elements; + class RParen* RParen; + + inline TuplePattern( + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): Pattern(NodeKind::TuplePattern), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class NestedPattern : public Pattern { +public: + + class LParen* LParen; + Pattern* P; + class RParen* RParen; + + inline NestedPattern( + class LParen* LParen, + Pattern* P, + class RParen* RParen + ): Pattern(NodeKind::NestedPattern), + LParen(LParen), + P(P), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class ListPattern : public Pattern { +public: + + class LBracket* LBracket; + std::vector> Elements; + class RBracket* RBracket; + + inline ListPattern( + class LBracket* LBracket, + std::vector> Elements, + class RBracket* RBracket + ): Pattern(NodeKind::ListPattern), + LBracket(LBracket), + Elements(Elements), + RBracket(RBracket) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class Expression : public TypedNode, public AnnotationContainer { +protected: + + inline Expression(NodeKind Kind, std::vector Annotations = {}): + TypedNode(Kind), AnnotationContainer(Annotations) {} + +}; + +class ReferenceExpression : public Expression { +public: + + std::vector> ModulePath; + Symbol* Name; + + inline ReferenceExpression( + std::vector> ModulePath, + Symbol* Name + ): Expression(NodeKind::ReferenceExpression), + ModulePath(ModulePath), + Name(Name) {} + + inline ReferenceExpression( + std::vector Annotations, + std::vector> ModulePath, + Symbol* Name + ): Expression(NodeKind::ReferenceExpression, Annotations), + ModulePath(ModulePath), + Name(Name) {} + + inline ByteString getNameAsString() const noexcept { + return getCanonicalText(Name); + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + SymbolPath getSymbolPath() const; + +}; + +class MatchCase : public Node { + + Scope* TheScope = nullptr; + +public: + + InferContext* Ctx; + + class Pattern* Pattern; + class RArrowAlt* RArrowAlt; + class Expression* Expression; + + inline MatchCase( + class Pattern* Pattern, + class RArrowAlt* RArrowAlt, + class Expression* Expression + ): Node(NodeKind::MatchCase), + Pattern(Pattern), + RArrowAlt(RArrowAlt), + Expression(Expression) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline Scope* getScope() override { + if (TheScope == nullptr) { + TheScope = new Scope(this); } + return TheScope; + } + + +}; + +class MatchExpression : public Expression { +public: + + class MatchKeyword* MatchKeyword; + Expression* Value; + class BlockStart* BlockStart; + std::vector Cases; + + inline MatchExpression( + class MatchKeyword* MatchKeyword, + Expression* Value, + class BlockStart* BlockStart, + std::vector Cases + ): Expression(NodeKind::MatchExpression), + MatchKeyword(MatchKeyword), + Value(Value), + BlockStart(BlockStart), + Cases(Cases) {} + + inline MatchExpression( + std::vector Annotations, + class MatchKeyword* MatchKeyword, + Expression* Value, + class BlockStart* BlockStart, + std::vector Cases + ): Expression(NodeKind::MatchExpression, Annotations), + MatchKeyword(MatchKeyword), + Value(Value), + BlockStart(BlockStart), + Cases(Cases) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class MemberExpression : public Expression { +public: + + Expression* E; + class Dot* Dot; + Token* Name; + + inline MemberExpression( + Expression* E, + class Dot* Dot, + Token* Name + ): Expression(NodeKind::MemberExpression), + E(E), + Dot(Dot), + Name(Name) {} + + inline MemberExpression( + std::vector Annotations, + class Expression* E, + class Dot* Dot, + Token* Name + ): Expression(NodeKind::MemberExpression, Annotations), + E(E), + Dot(Dot), + Name(Name) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline Expression* getExpression() const { + return E; + } + +}; + +class TupleExpression : public Expression { +public: + + class LParen* LParen; + std::vector> Elements; + class RParen* RParen; + + inline TupleExpression( + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): Expression(NodeKind::TupleExpression), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + inline TupleExpression( + std::vector Annotations, + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): Expression(NodeKind::TupleExpression, Annotations), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class NestedExpression : public Expression { +public: + + class LParen* LParen; + Expression* Inner; + class RParen* RParen; + + inline NestedExpression( + class LParen* LParen, + Expression* Inner, + class RParen* RParen + ): Expression(NodeKind::NestedExpression), + LParen(LParen), + Inner(Inner), + RParen(RParen) {} + + inline NestedExpression( + std::vector Annotations, + class LParen* LParen, + Expression* Inner, + class RParen* RParen + ): Expression(NodeKind::NestedExpression, Annotations), + LParen(LParen), + Inner(Inner), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class LiteralExpression : public Expression { +public: + + Literal* Token; + + LiteralExpression( + Literal* Token + ): Expression(NodeKind::LiteralExpression), + Token(Token) {} + + LiteralExpression( + std::vector Annotations, + Literal* Token + ): Expression(NodeKind::LiteralExpression, Annotations), + Token(Token) {} + + inline ByteString getAsText() { + ZEN_ASSERT(Token->getKind() == NodeKind::StringLiteral); + return static_cast(Token)->Text; + } + + inline int getAsInt() { + ZEN_ASSERT(Token->getKind() == NodeKind::IntegerLiteral); + return static_cast(Token)->asInt(); + } + + class Token* getFirstToken() const override; + class Token* getLastToken() const override; + +}; + +class CallExpression : public Expression { +public: + + Expression* Function; + std::vector Args; + + inline CallExpression( + Expression* Function, + std::vector Args + ): Expression(NodeKind::CallExpression), + Function(Function), + Args(Args) {} + + inline CallExpression( + std::vector Annotations, + Expression* Function, + std::vector Args + ): Expression(NodeKind::CallExpression, Annotations), + Function(Function), + Args(Args) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class InfixExpression : public Expression { +public: + + Expression* Left; + Token* Operator; + Expression* Right; + + inline InfixExpression( + Expression* Left, + Token* Operator, + Expression* Right + ): Expression(NodeKind::InfixExpression), + Left(Left), + Operator(Operator), + Right(Right) {} + + inline InfixExpression( + std::vector Annotations, + Expression* Left, + Token* Operator, + Expression* Right + ): Expression(NodeKind::InfixExpression, Annotations), + Left(Left), + Operator(Operator), + Right(Right) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class PrefixExpression : public Expression { +public: + + Token* Operator; + Expression* Argument; + + PrefixExpression( + Token* Operator, + Expression* Argument + ): Expression(NodeKind::PrefixExpression), + Operator(Operator), + Argument(Argument) {} + + PrefixExpression( + std::vector Annotations, + Token* Operator, + Expression* Argument + ): Expression(NodeKind::PrefixExpression, Annotations), + Operator(Operator), + Argument(Argument) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class RecordExpressionField : public Node { +public: + + Identifier* Name; + class Equals* Equals; + Expression* E; + + inline RecordExpressionField( + Identifier* Name, + class Equals* Equals, + Expression* E + ): Node(NodeKind::RecordExpressionField), + Name(Name), + Equals(Equals), + E(E) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline Expression* getExpression() const { + return E; + } + +}; + +class RecordExpression : public Expression { +public: + + class LBrace* LBrace; + std::vector> Fields; + class RBrace* RBrace; + + inline RecordExpression( + class LBrace* LBrace, + std::vector> Fields, + class RBrace* RBrace + ): Expression(NodeKind::RecordExpression), + LBrace(LBrace), + Fields(Fields), + RBrace(RBrace) {} + + inline RecordExpression( + std::vector Annotations, + class LBrace* LBrace, + std::vector> Fields, + class RBrace* RBrace + ): Expression(NodeKind::RecordExpression, Annotations), + LBrace(LBrace), + Fields(Fields), + RBrace(RBrace) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class Statement : public Node, public AnnotationContainer { +protected: + + inline Statement(NodeKind Type, std::vector Annotations = {}): + Node(Type), AnnotationContainer(Annotations) {} + +}; + +class ExpressionStatement : public Statement { +public: + + class Expression* Expression; + + ExpressionStatement(class Expression* Expression): + Statement(NodeKind::ExpressionStatement), Expression(Expression) {} + + ExpressionStatement( + std::vector Annotations, + class Expression* Expression + ): Statement(NodeKind::ExpressionStatement, Annotations), + Expression(Expression) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class IfStatementPart : public Node, public AnnotationContainer { +public: + + Token* Keyword; + Expression* Test; + class BlockStart* BlockStart; + std::vector Elements; + + inline IfStatementPart( + Token* Keyword, + Expression* Test, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeKind::IfStatementPart), + Keyword(Keyword), + Test(Test), + BlockStart(BlockStart), + Elements(Elements) {} - }; + inline IfStatementPart( + std::vector Annotations, + Token* Keyword, + Expression* Test, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeKind::IfStatementPart), + AnnotationContainer(Annotations), + Keyword(Keyword), + Test(Test), + BlockStart(BlockStart), + Elements(Elements) {} - class VBar : public Token { - public: + Token* getFirstToken() const override; + Token* getLastToken() const override; - inline VBar(TextLoc StartLoc): - Token(NodeKind::VBar, StartLoc) {} +}; - std::string getText() const override; +class IfStatement : public Statement { +public: - static bool classof(const Node* N) { - return N->getKind() == NodeKind::VBar; - } + std::vector Parts; - }; + inline IfStatement(std::vector Parts): + Statement(NodeKind::IfStatement), Parts(Parts) {} - class Colon : public Token { - public: + Token* getFirstToken() const override; + Token* getLastToken() const override; - inline Colon(TextLoc StartLoc): - Token(NodeKind::Colon, StartLoc) {} +}; - std::string getText() const override; +class ReturnStatement : public Statement { +public: - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Colon; - } + class ReturnKeyword* ReturnKeyword; + class Expression* Expression; - }; + ReturnStatement( + class ReturnKeyword* ReturnKeyword, + class Expression* Expression + ): Statement(NodeKind::ReturnStatement), + ReturnKeyword(ReturnKeyword), + Expression(Expression) {} - class Comma : public Token { - public: + ReturnStatement( + std::vector Annotations, + class ReturnKeyword* ReturnKeyword, + class Expression* Expression + ): Statement(NodeKind::ReturnStatement, Annotations), + ReturnKeyword(ReturnKeyword), + Expression(Expression) {} - inline Comma(TextLoc StartLoc): - Token(NodeKind::Comma, StartLoc) {} + Token* getFirstToken() const override; + Token* getLastToken() const override; - std::string getText() const override; +}; - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Comma; - } +class TypeAssert : public Node { +public: - }; + class Colon* Colon; + class TypeExpression* TypeExpression; - class Dot : public Token { - public: - - inline Dot(TextLoc StartLoc): - Token(NodeKind::Dot, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Dot; - } - - }; - - class DotDot : public Token { - public: - - inline DotDot(TextLoc StartLoc): - Token(NodeKind::DotDot, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::DotDot; - } - - }; - - class Tilde : public Token { - public: - - inline Tilde(TextLoc StartLoc): - Token(NodeKind::Tilde, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Tilde; - } - - }; - - class At : public Token { - public: - - inline At(TextLoc StartLoc): - Token(NodeKind::At, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::At; - } - - }; - - class LParen : public Token { - public: - - inline LParen(TextLoc StartLoc): - Token(NodeKind::LParen, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LParen; - } - - }; - - class RParen : public Token { - public: - - inline RParen(TextLoc StartLoc): - Token(NodeKind::RParen, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::RParen; - } - - }; - - class LBracket : public Token { - public: - - inline LBracket(TextLoc StartLoc): - Token(NodeKind::LBracket, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LBracket; - } - - }; - - class RBracket : public Token { - public: - - inline RBracket(TextLoc StartLoc): - Token(NodeKind::RBracket, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::RBracket; - } - - }; - - class LBrace : public Token { - public: - - inline LBrace(TextLoc StartLoc): - Token(NodeKind::LBrace, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LBrace; - } - - }; - - class RBrace : public Token { - public: - - inline RBrace(TextLoc StartLoc): - Token(NodeKind::RBrace, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::RBrace; - } - - }; - - class RArrow : public Token { - public: - - inline RArrow(TextLoc StartLoc): - Token(NodeKind::RArrow, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::RArrow; - } - - }; - - class RArrowAlt : public Token { - public: - - inline RArrowAlt(TextLoc StartLoc): - Token(NodeKind::RArrowAlt, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::RArrowAlt; - } - - }; - - class LetKeyword : public Token { - public: - - inline LetKeyword(TextLoc StartLoc): - Token(NodeKind::LetKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LetKeyword; - } - - }; - - class MutKeyword : public Token { - public: - - inline MutKeyword(TextLoc StartLoc): - Token(NodeKind::MutKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::MutKeyword; - } - - }; - - class PubKeyword : public Token { - public: - - inline PubKeyword(TextLoc StartLoc): - Token(NodeKind::PubKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::PubKeyword; - } - - }; - - class ForeignKeyword : public Token { - public: - - inline ForeignKeyword(TextLoc StartLoc): - Token(NodeKind::ForeignKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ForeignKeyword; - } - - }; - - class TypeKeyword : public Token { - public: - - inline TypeKeyword(TextLoc StartLoc): - Token(NodeKind::TypeKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::TypeKeyword; - } - - }; - - class ReturnKeyword : public Token { - public: - - inline ReturnKeyword(TextLoc StartLoc): - Token(NodeKind::ReturnKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ReturnKeyword; - } - - }; - - class ModKeyword : public Token { - public: - - inline ModKeyword(TextLoc StartLoc): - Token(NodeKind::ModKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ModKeyword; - } - - }; - - class StructKeyword : public Token { - public: - - inline StructKeyword(TextLoc StartLoc): - Token(NodeKind::StructKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::StructKeyword; - } - - }; - - class EnumKeyword : public Token { - public: - - inline EnumKeyword(TextLoc StartLoc): - Token(NodeKind::EnumKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::EnumKeyword; - } - - }; - - class ClassKeyword : public Token { - public: - - inline ClassKeyword(TextLoc StartLoc): - Token(NodeKind::ClassKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ClassKeyword; - } - - }; - - class InstanceKeyword : public Token { - public: - - inline InstanceKeyword(TextLoc StartLoc): - Token(NodeKind::InstanceKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::InstanceKeyword; - } - - }; - - class ElifKeyword : public Token { - public: - - inline ElifKeyword(TextLoc StartLoc): - Token(NodeKind::ElifKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ElifKeyword; - } - - }; - - class IfKeyword : public Token { - public: - - inline IfKeyword(TextLoc StartLoc): - Token(NodeKind::IfKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::IfKeyword; - } - - }; - - class ElseKeyword : public Token { - public: - - inline ElseKeyword(TextLoc StartLoc): - Token(NodeKind::ElseKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ElseKeyword; - } - - }; - - class MatchKeyword : public Token { - public: - - inline MatchKeyword(TextLoc StartLoc): - Token(NodeKind::MatchKeyword, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::MatchKeyword; - } - - }; - - class Invalid : public Token { - public: - - inline Invalid(TextLoc StartLoc): - Token(NodeKind::Invalid, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Invalid; - } - - }; - - class EndOfFile : public Token { - public: - - inline EndOfFile(TextLoc StartLoc): - Token(NodeKind::EndOfFile, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::EndOfFile; - } - - }; - - class BlockStart : public Token { - public: - - inline BlockStart(TextLoc StartLoc): - Token(NodeKind::BlockStart, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::BlockStart; - } - - }; - - class BlockEnd : public Token { - public: - - inline BlockEnd(TextLoc StartLoc): - Token(NodeKind::BlockEnd, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::BlockEnd; - } - - }; - - class LineFoldEnd : public Token { - public: - - inline LineFoldEnd(TextLoc StartLoc): - Token(NodeKind::LineFoldEnd, StartLoc) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LineFoldEnd; - } - - }; - - class CustomOperator : public Token { - public: - - ByteString Text; - - CustomOperator(ByteString Text, TextLoc StartLoc): - Token(NodeKind::CustomOperator, StartLoc), Text(Text) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::CustomOperator; - } - - }; - - class Assignment : public Token { - public: - - ByteString Text; - - Assignment(ByteString Text, TextLoc StartLoc): - Token(NodeKind::Assignment, StartLoc), Text(Text) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Assignment; - } - - }; - - class Identifier : public Token { - public: - - ByteString Text; - - Identifier(ByteString Text, TextLoc StartLoc = TextLoc::empty()): - Token(NodeKind::Identifier, StartLoc), Text(Text) {} - - std::string getText() const override; - - bool isTypeVar() const; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::Identifier; - } - - }; - - class IdentifierAlt : public Token { - public: - - ByteString Text; - - IdentifierAlt(ByteString Text, TextLoc StartLoc): - Token(NodeKind::IdentifierAlt, StartLoc), Text(Text) {} - - std::string getText() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::IdentifierAlt; - } - - }; - - using LiteralValue = std::variant; - - class Literal : public Token { - public: - - inline Literal(NodeKind Kind, TextLoc StartLoc): - Token(Kind, StartLoc) {} - - virtual LiteralValue getValue() = 0; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::StringLiteral - || N->getKind() == NodeKind::IntegerLiteral; - } - - }; - - class StringLiteral : public Literal { - public: - - ByteString Text; - - StringLiteral(ByteString Text, TextLoc StartLoc): - Literal(NodeKind::StringLiteral, StartLoc), Text(Text) {} - - std::string getText() const override; - - LiteralValue getValue() override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::StringLiteral; - } - - }; - - class IntegerLiteral : public Literal { - public: - - Integer V; - - IntegerLiteral(Integer Value, TextLoc StartLoc): - Literal(NodeKind::IntegerLiteral, StartLoc), V(Value) {} - - std::string getText() const override; - - inline Integer getInteger() const noexcept { - return V; - } - - inline int asInt() const { - ZEN_ASSERT(V >= std::numeric_limits::min() && V <= std::numeric_limits::max()); - return V; - } - - LiteralValue getValue() override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::IntegerLiteral; - } - - }; - - class Annotation : public Node { - public: - - inline Annotation(NodeKind Kind): - Node(Kind) {} - - }; - - class AnnotationContainer { - public: - std::vector Annotations; - }; - - class ExpressionAnnotation : public Annotation { - public: - - class At* At; - class Expression* Expression; - - inline ExpressionAnnotation( - class At* At, - class Expression* Expression - ): Annotation(NodeKind::ExpressionAnnotation), - At(At), - Expression(Expression) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - inline class Expression* getExpression() const noexcept { - return Expression; - } - - }; - - class TypeExpression; - - class TypeAssertAnnotation : public Annotation { - public: - - class At* At; - class Colon* Colon; - TypeExpression* TE; - - inline TypeAssertAnnotation( - class At* At, - class Colon* Colon, - TypeExpression* TE - ): Annotation(NodeKind::TypeAssertAnnotation), - At(At), - Colon(Colon), - TE(TE) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - inline TypeExpression* getTypeExpression() const noexcept { - return TE; - } - - }; - - class TypedNode : public Node { - protected: - - Type* Ty; - - inline TypedNode(NodeKind Kind): - Node(Kind) {} - - public: - - inline void setType(Type* Ty2) { - Ty = Ty2; - } - - inline Type* getType() const noexcept { - ZEN_ASSERT(Ty != nullptr); - return Ty; - } - - }; - - class TypeExpression : public TypedNode, AnnotationContainer { - protected: - - inline TypeExpression(NodeKind Kind, std::vector Annotations = {}): - TypedNode(Kind), AnnotationContainer(Annotations) {} - - }; - - class ConstraintExpression : public Node { - public: - - inline ConstraintExpression(NodeKind Kind): - Node(Kind) {} - - }; - - class RecordTypeExpressionField : public Node { - public: - - Identifier* Name; - class Colon* Colon; - TypeExpression* TE; - - inline RecordTypeExpressionField( - Identifier* Name, - class Colon* Colon, - TypeExpression* TE - ): Node(NodeKind::RecordTypeExpressionField), - Name(Name), - Colon(Colon), - TE(TE) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class RecordTypeExpression : public TypeExpression { - public: - - class LBrace* LBrace; - std::vector> Fields; - class VBar* VBar; - TypeExpression* Rest; - class RBrace* RBrace; - - inline RecordTypeExpression( - class LBrace* LBrace, - std::vector> Fields, - class VBar* VBar, - TypeExpression* Rest, - class RBrace* RBrace - ): TypeExpression(NodeKind::RecordTypeExpression), - LBrace(LBrace), - Fields(Fields), - VBar(VBar), - Rest(Rest), - RBrace(RBrace) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class VarTypeExpression; - - class TypeclassConstraintExpression : public ConstraintExpression { - public: - - IdentifierAlt* Name; - std::vector TEs; - - TypeclassConstraintExpression( - IdentifierAlt* Name, - std::vector TEs - ): ConstraintExpression(NodeKind::TypeclassConstraintExpression), - Name(Name), - TEs(TEs) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::TypeclassConstraintExpression; - } - - }; - - class EqualityConstraintExpression : public ConstraintExpression { - public: - - TypeExpression* Left; - class Tilde* Tilde; - TypeExpression* Right; - - inline EqualityConstraintExpression( - TypeExpression* Left, - class Tilde* Tilde, - TypeExpression* Right - ): ConstraintExpression(NodeKind::EqualityConstraintExpression), - Left(Left), - Tilde(Tilde), - Right(Right) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::EqualityConstraintExpression; - } - - }; - - class QualifiedTypeExpression : public TypeExpression { - public: - - std::vector> Constraints; - class RArrowAlt* RArrowAlt; - TypeExpression* TE; - - QualifiedTypeExpression( - std::vector> Constraints, - class RArrowAlt* RArrowAlt, - TypeExpression* TE - ): TypeExpression(NodeKind::QualifiedTypeExpression), - Constraints(Constraints), - RArrowAlt(RArrowAlt), - TE(TE) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::QualifiedTypeExpression; - } - - }; - - class ReferenceTypeExpression : public TypeExpression { - public: - - std::vector> ModulePath; - IdentifierAlt* Name; - - ReferenceTypeExpression( - std::vector> ModulePath, - IdentifierAlt* Name - ): TypeExpression(NodeKind::ReferenceTypeExpression), - ModulePath(ModulePath), - Name(Name) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - SymbolPath getSymbolPath() const; - - }; - - class ArrowTypeExpression : public TypeExpression { - public: - - std::vector ParamTypes; - TypeExpression* ReturnType; - - inline ArrowTypeExpression( - std::vector ParamTypes, - TypeExpression* ReturnType - ): TypeExpression(NodeKind::ArrowTypeExpression), - ParamTypes(ParamTypes), - ReturnType(ReturnType) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class AppTypeExpression : public TypeExpression { - public: - - TypeExpression* Op; - std::vector Args; - - inline AppTypeExpression( - TypeExpression* Op, - std::vector Args - ): TypeExpression(NodeKind::AppTypeExpression), - Op(Op), - Args(Args) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class VarTypeExpression : public TypeExpression { - public: - - Identifier* Name; - - inline VarTypeExpression(Identifier* Name): - TypeExpression(NodeKind::VarTypeExpression), Name(Name) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class NestedTypeExpression : public TypeExpression { - public: - - class LParen* LParen; - TypeExpression* TE; - class RParen* RParen; - - inline NestedTypeExpression( - class LParen* LParen, - TypeExpression* TE, - class RParen* RParen - ): TypeExpression(NodeKind::NestedTypeExpression), - LParen(LParen), - TE(TE), - RParen(RParen) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class TupleTypeExpression : public TypeExpression { - public: - - class LParen* LParen; - std::vector> Elements; - class RParen* RParen; - - inline TupleTypeExpression( - class LParen* LParen, - std::vector> Elements, - class RParen* RParen - ): TypeExpression(NodeKind::TupleTypeExpression), - LParen(LParen), - Elements(Elements), - RParen(RParen) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class WrappedOperator : public Symbol { - public: - - class LParen* LParen; - Token* Op; - class RParen* RParen; - - WrappedOperator( - class LParen* LParen, - Token* Operator, - class RParen* RParen - ): Symbol(NodeKind::WrappedOperator), - LParen(LParen), - Op(Operator), - RParen(RParen) {} - - inline Token* getOperator() const { - return Op; - } - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - - class Pattern : public Node { - protected: - - inline Pattern(NodeKind Type): - Node(Type) {} - - }; - - class BindPattern : public Pattern { - public: - - Symbol* Name; - - BindPattern( - Symbol* Name - ): Pattern(NodeKind::BindPattern), - Name(Name) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::BindPattern; - } - - }; - - class LiteralPattern : public Pattern { - public: - - class Literal* Literal; - - LiteralPattern(class Literal* Literal): - Pattern(NodeKind::LiteralPattern), - Literal(Literal) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LiteralPattern; - } - - }; - - class RecordPatternField : public Node { - public: - - class DotDot* DotDot; - Identifier* Name; - class Equals* Equals; - class Pattern* Pattern; - - inline RecordPatternField( - class DotDot* DotDot, - Identifier* Name, - class Equals* Equals, - class Pattern* Pattern - ): Node(NodeKind::RecordPatternField), - DotDot(DotDot), - Name(Name), - Equals(Equals), - Pattern(Pattern) {} - - inline RecordPatternField( - Identifier* Name, - class Equals* Equals, - class Pattern* Pattern - ): RecordPatternField(nullptr, Name, Equals, Pattern) {} - - inline RecordPatternField( - class DotDot* DotDot - ): RecordPatternField(DotDot, nullptr, nullptr, nullptr) {} - - inline RecordPatternField( - class DotDot* DotDot, - class Pattern* Pattern - ): RecordPatternField(DotDot, nullptr, nullptr, Pattern) {} - - inline RecordPatternField( - Identifier* Name - ): RecordPatternField(nullptr, Name, nullptr, nullptr) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class RecordPattern : public Pattern { - public: - - class LBrace* LBrace; - std::vector> Fields; - class RBrace* RBrace; - - inline RecordPattern( - class LBrace* LBrace, - std::vector> Fields, - class RBrace* RBrace - ): Pattern(NodeKind::RecordPattern), - LBrace(LBrace), - Fields(Fields), - RBrace(RBrace) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class NamedRecordPattern : public Pattern { - public: - - std::vector> ModulePath; - IdentifierAlt* Name; - class LBrace* LBrace; - std::vector> Fields; - class RBrace* RBrace; - - inline NamedRecordPattern( - std::vector> ModulePath, - IdentifierAlt* Name, - class LBrace* LBrace, - std::vector> Fields, - class RBrace* RBrace - ): Pattern(NodeKind::NamedRecordPattern), - ModulePath(ModulePath), - Name(Name), - LBrace(LBrace), - Fields(Fields), - RBrace(RBrace) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class NamedTuplePattern : public Pattern { - public: - - IdentifierAlt* Name; - std::vector Patterns; - - inline NamedTuplePattern( - IdentifierAlt* Name, - std::vector Patterns - ): Pattern(NodeKind::NamedTuplePattern), - Name(Name), - Patterns(Patterns) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class TuplePattern : public Pattern { - public: - - class LParen* LParen; - std::vector> Elements; - class RParen* RParen; - - inline TuplePattern( - class LParen* LParen, - std::vector> Elements, - class RParen* RParen - ): Pattern(NodeKind::TuplePattern), - LParen(LParen), - Elements(Elements), - RParen(RParen) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class NestedPattern : public Pattern { - public: - - class LParen* LParen; - Pattern* P; - class RParen* RParen; - - inline NestedPattern( - class LParen* LParen, - Pattern* P, - class RParen* RParen - ): Pattern(NodeKind::NestedPattern), - LParen(LParen), - P(P), - RParen(RParen) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class ListPattern : public Pattern { - public: - - class LBracket* LBracket; - std::vector> Elements; - class RBracket* RBracket; - - inline ListPattern( - class LBracket* LBracket, - std::vector> Elements, - class RBracket* RBracket - ): Pattern(NodeKind::ListPattern), - LBracket(LBracket), - Elements(Elements), - RBracket(RBracket) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class Expression : public TypedNode, public AnnotationContainer { - protected: - - inline Expression(NodeKind Kind, std::vector Annotations = {}): - TypedNode(Kind), AnnotationContainer(Annotations) {} - - }; - - class ReferenceExpression : public Expression { - public: - - std::vector> ModulePath; - Symbol* Name; - - inline ReferenceExpression( - std::vector> ModulePath, - Symbol* Name - ): Expression(NodeKind::ReferenceExpression), - ModulePath(ModulePath), - Name(Name) {} - - inline ReferenceExpression( - std::vector Annotations, - std::vector> ModulePath, - Symbol* Name - ): Expression(NodeKind::ReferenceExpression, Annotations), - ModulePath(ModulePath), - Name(Name) {} - - inline ByteString getNameAsString() const noexcept { - return getCanonicalText(Name); - } - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - SymbolPath getSymbolPath() const; - - }; - - class MatchCase : public Node { - - Scope* TheScope = nullptr; - - public: - - InferContext* Ctx; - - class Pattern* Pattern; - class RArrowAlt* RArrowAlt; - class Expression* Expression; - - inline MatchCase( - class Pattern* Pattern, - class RArrowAlt* RArrowAlt, - class Expression* Expression - ): Node(NodeKind::MatchCase), - Pattern(Pattern), - RArrowAlt(RArrowAlt), - Expression(Expression) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - inline Scope* getScope() override { + TypeAssert( + class Colon* Colon, + class TypeExpression* TypeExpression + ): Node(NodeKind::TypeAssert), + Colon(Colon), + TypeExpression(TypeExpression) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class Parameter : public Node { +public: + + Parameter( + class Pattern* Pattern, + class TypeAssert* TypeAssert + ): Node(NodeKind::Parameter), + Pattern(Pattern), + TypeAssert(TypeAssert) {} + + class Pattern* Pattern; + class TypeAssert* TypeAssert; + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class LetBody : public Node { +public: + + LetBody(NodeKind Type): Node(Type) {} + +}; + +class LetBlockBody : public LetBody { +public: + + class BlockStart* BlockStart; + std::vector Elements; + + LetBlockBody( + class BlockStart* BlockStart, + std::vector Elements + ): LetBody(NodeKind::LetBlockBody), + BlockStart(BlockStart), + Elements(Elements) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class LetExprBody : public LetBody { +public: + + class Equals* Equals; + class Expression* Expression; + + LetExprBody( + class Equals* Equals, + class Expression* Expression + ): LetBody(NodeKind::LetExprBody), + Equals(Equals), + Expression(Expression) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class LetDeclaration : public TypedNode, public AnnotationContainer { + + Scope* TheScope = nullptr; + +public: + + bool IsCycleActive = false; + bool Visited = false; + InferContext* Ctx; + + class PubKeyword* PubKeyword; + class ForeignKeyword* ForeignKeyword; + class LetKeyword* LetKeyword; + class MutKeyword* MutKeyword; + class Pattern* Pattern; + std::vector Params; + class TypeAssert* TypeAssert; + LetBody* Body; + + LetDeclaration( + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + class MutKeyword* MutKeyword, + class Pattern* Pattern, + std::vector Params, + class TypeAssert* TypeAssert, + LetBody* Body + ): TypedNode(NodeKind::LetDeclaration), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), + MutKeyword(MutKeyword), + Pattern(Pattern), + Params(Params), + TypeAssert(TypeAssert), + Body(Body) {} + + LetDeclaration( + std::vector Annotations, + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + class MutKeyword* MutKeyword, + class Pattern* Pattern, + std::vector Params, + class TypeAssert* TypeAssert, + LetBody* Body + ): TypedNode(NodeKind::LetDeclaration), + AnnotationContainer(Annotations), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), + MutKeyword(MutKeyword), + Pattern(Pattern), + Params(Params), + TypeAssert(TypeAssert), + Body(Body) {} + + inline Scope* getScope() override { + if (isFunction()) { if (TheScope == nullptr) { TheScope = new Scope(this); } return TheScope; } + return Parent->getScope(); + } + bool isInstance() const noexcept { + return Parent->getKind() == NodeKind::InstanceDeclaration; + } - }; + bool isClass() const noexcept { + return Parent->getKind() == NodeKind::ClassDeclaration; + } - class MatchExpression : public Expression { - public: + bool isSignature() const noexcept { + return ForeignKeyword; + } - class MatchKeyword* MatchKeyword; - Expression* Value; - class BlockStart* BlockStart; - std::vector Cases; + bool isVariable() const noexcept { + // Variables in classes and instances are never possible, so we reflect this by excluding them here. + return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body); + } - inline MatchExpression( - class MatchKeyword* MatchKeyword, - Expression* Value, - class BlockStart* BlockStart, - std::vector Cases - ): Expression(NodeKind::MatchExpression), - MatchKeyword(MatchKeyword), - Value(Value), - BlockStart(BlockStart), - Cases(Cases) {} + bool isFunction() const noexcept { + return !isSignature() && !isVariable(); + } - inline MatchExpression( - std::vector Annotations, - class MatchKeyword* MatchKeyword, - Expression* Value, - class BlockStart* BlockStart, - std::vector Cases - ): Expression(NodeKind::MatchExpression, Annotations), - MatchKeyword(MatchKeyword), - Value(Value), - BlockStart(BlockStart), - Cases(Cases) {} + Symbol* getName() const noexcept { + ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern); + return static_cast(Pattern)->Name; + } - Token* getFirstToken() const override; - Token* getLastToken() const override; + ByteString getNameAsString() const noexcept { + return getCanonicalText(getName()); + } - }; + Token* getFirstToken() const override; + Token* getLastToken() const override; - class MemberExpression : public Expression { - public: + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LetDeclaration; + } - Expression* E; - class Dot* Dot; - Token* Name; +}; - inline MemberExpression( - Expression* E, - class Dot* Dot, - Token* Name - ): Expression(NodeKind::MemberExpression), - E(E), - Dot(Dot), - Name(Name) {} +class InstanceDeclaration : public Node { +public: - inline MemberExpression( - std::vector Annotations, - class Expression* E, - class Dot* Dot, - Token* Name - ): Expression(NodeKind::MemberExpression, Annotations), - E(E), - Dot(Dot), - Name(Name) {} + class InstanceKeyword* InstanceKeyword; + IdentifierAlt* Name; + std::vector TypeExps; + class BlockStart* BlockStart; + std::vector Elements; - Token* getFirstToken() const override; - Token* getLastToken() const override; + InstanceDeclaration( + class InstanceKeyword* InstanceKeyword, + IdentifierAlt* Name, + std::vector TypeExps, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeKind::InstanceDeclaration), + InstanceKeyword(InstanceKeyword), + Name(Name), + TypeExps(TypeExps), + BlockStart(BlockStart), + Elements(Elements) {} - inline Expression* getExpression() const { - return E; + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::InstanceDeclaration; + } + +}; + +class ClassDeclaration : public Node { +public: + + class PubKeyword* PubKeyword; + class ClassKeyword* ClassKeyword; + IdentifierAlt* Name; + std::vector TypeVars; + class BlockStart* BlockStart; + std::vector Elements; + + ClassDeclaration( + class PubKeyword* PubKeyword, + class ClassKeyword* ClassKeyword, + IdentifierAlt* Name, + std::vector TypeVars, + class BlockStart* BlockStart, + std::vector Elements + ): Node(NodeKind::ClassDeclaration), + PubKeyword(PubKeyword), + ClassKeyword(ClassKeyword), + Name(Name), + TypeVars(TypeVars), + BlockStart(BlockStart), + Elements(Elements) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::ClassDeclaration; + } + +}; + +class RecordDeclarationField : public Node { +public: + + RecordDeclarationField( + Identifier* Name, + class Colon* Colon, + class TypeExpression* TypeExpression + ): Node(NodeKind::RecordDeclarationField), + Name(Name), + Colon(Colon), + TypeExpression(TypeExpression) {} + + Identifier* Name; + class Colon* Colon; + class TypeExpression* TypeExpression; + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class RecordDeclaration : public Node { +public: + + InferContext* Ctx; + + class PubKeyword* PubKeyword; + class StructKeyword* StructKeyword; + IdentifierAlt* Name; + std::vector Vars; + class BlockStart* BlockStart; + std::vector Fields; + + RecordDeclaration( + class PubKeyword* PubKeyword, + class StructKeyword* StructKeyword, + IdentifierAlt* Name, + std::vector Vars, + class BlockStart* BlockStart, + std::vector Fields + ): Node(NodeKind::RecordDeclaration), + PubKeyword(PubKeyword), + StructKeyword(StructKeyword), + Name(Name), + Vars(Vars), + BlockStart(BlockStart), + Fields(Fields) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class VariantDeclarationMember : public Node { +public: + + inline VariantDeclarationMember(NodeKind Kind): + Node(Kind) {} + +}; + +class TupleVariantDeclarationMember : public VariantDeclarationMember { +public: + + IdentifierAlt* Name; + std::vector Elements; + + inline TupleVariantDeclarationMember( + IdentifierAlt* Name, + std::vector Elements + ): VariantDeclarationMember(NodeKind::TupleVariantDeclarationMember), + Name(Name), + Elements(Elements) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class RecordVariantDeclarationMember : public VariantDeclarationMember { +public: + + IdentifierAlt* Name; + class BlockStart* BlockStart; + std::vector Fields; + + inline RecordVariantDeclarationMember( + IdentifierAlt* Name, + class BlockStart* BlockStart, + std::vector Fields + ): VariantDeclarationMember(NodeKind::RecordVariantDeclarationMember), + Name(Name), + BlockStart(BlockStart), + Fields(Fields) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class VariantDeclaration : public Node { +public: + + InferContext* Ctx; + + class PubKeyword* PubKeyword; + class EnumKeyword* EnumKeyword; + class IdentifierAlt* Name; + std::vector TVs; + class BlockStart* BlockStart; + std::vector Members; + + inline VariantDeclaration( + class PubKeyword* PubKeyword, + class EnumKeyword* EnumKeyword, + class IdentifierAlt* Name, + std::vector TVs, + class BlockStart* BlockStart, + std::vector Members + ): Node(NodeKind::VariantDeclaration), + PubKeyword(PubKeyword), + EnumKeyword(EnumKeyword), + Name(Name), + TVs(TVs), + BlockStart(BlockStart), + Members(Members) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + +}; + +class SourceFile : public Node { + + Scope* TheScope = nullptr; + +public: + + TextFile File; + InferContext* Ctx; + + std::vector Elements; + + SourceFile(TextFile& File, std::vector Elements): + Node(NodeKind::SourceFile), File(File), Elements(Elements) {} + + inline TextFile& getTextFile() { + return File; + } + + inline const TextFile& getTextFile() const { + return File; + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + inline Scope* getScope() override { + if (TheScope == nullptr) { + TheScope = new Scope(this); } - - }; - - class TupleExpression : public Expression { - public: - - class LParen* LParen; - std::vector> Elements; - class RParen* RParen; - - inline TupleExpression( - class LParen* LParen, - std::vector> Elements, - class RParen* RParen - ): Expression(NodeKind::TupleExpression), - LParen(LParen), - Elements(Elements), - RParen(RParen) {} - - inline TupleExpression( - std::vector Annotations, - class LParen* LParen, - std::vector> Elements, - class RParen* RParen - ): Expression(NodeKind::TupleExpression, Annotations), - LParen(LParen), - Elements(Elements), - RParen(RParen) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class NestedExpression : public Expression { - public: - - class LParen* LParen; - Expression* Inner; - class RParen* RParen; - - inline NestedExpression( - class LParen* LParen, - Expression* Inner, - class RParen* RParen - ): Expression(NodeKind::NestedExpression), - LParen(LParen), - Inner(Inner), - RParen(RParen) {} - - inline NestedExpression( - std::vector Annotations, - class LParen* LParen, - Expression* Inner, - class RParen* RParen - ): Expression(NodeKind::NestedExpression, Annotations), - LParen(LParen), - Inner(Inner), - RParen(RParen) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class LiteralExpression : public Expression { - public: - - Literal* Token; - - LiteralExpression( - Literal* Token - ): Expression(NodeKind::LiteralExpression), - Token(Token) {} - - LiteralExpression( - std::vector Annotations, - Literal* Token - ): Expression(NodeKind::LiteralExpression, Annotations), - Token(Token) {} - - inline ByteString getAsText() { - ZEN_ASSERT(Token->getKind() == NodeKind::StringLiteral); - return static_cast(Token)->Text; - } - - inline int getAsInt() { - ZEN_ASSERT(Token->getKind() == NodeKind::IntegerLiteral); - return static_cast(Token)->asInt(); - } - - class Token* getFirstToken() const override; - class Token* getLastToken() const override; - - }; - - class CallExpression : public Expression { - public: - - Expression* Function; - std::vector Args; - - inline CallExpression( - Expression* Function, - std::vector Args - ): Expression(NodeKind::CallExpression), - Function(Function), - Args(Args) {} - - inline CallExpression( - std::vector Annotations, - Expression* Function, - std::vector Args - ): Expression(NodeKind::CallExpression, Annotations), - Function(Function), - Args(Args) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class InfixExpression : public Expression { - public: - - Expression* Left; - Token* Operator; - Expression* Right; - - inline InfixExpression( - Expression* Left, - Token* Operator, - Expression* Right - ): Expression(NodeKind::InfixExpression), - Left(Left), - Operator(Operator), - Right(Right) {} - - inline InfixExpression( - std::vector Annotations, - Expression* Left, - Token* Operator, - Expression* Right - ): Expression(NodeKind::InfixExpression, Annotations), - Left(Left), - Operator(Operator), - Right(Right) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class PrefixExpression : public Expression { - public: - - Token* Operator; - Expression* Argument; - - PrefixExpression( - Token* Operator, - Expression* Argument - ): Expression(NodeKind::PrefixExpression), - Operator(Operator), - Argument(Argument) {} - - PrefixExpression( - std::vector Annotations, - Token* Operator, - Expression* Argument - ): Expression(NodeKind::PrefixExpression, Annotations), - Operator(Operator), - Argument(Argument) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class RecordExpressionField : public Node { - public: - - Identifier* Name; - class Equals* Equals; - Expression* E; - - inline RecordExpressionField( - Identifier* Name, - class Equals* Equals, - Expression* E - ): Node(NodeKind::RecordExpressionField), - Name(Name), - Equals(Equals), - E(E) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - inline Expression* getExpression() const { - return E; - } - - }; - - class RecordExpression : public Expression { - public: - - class LBrace* LBrace; - std::vector> Fields; - class RBrace* RBrace; - - inline RecordExpression( - class LBrace* LBrace, - std::vector> Fields, - class RBrace* RBrace - ): Expression(NodeKind::RecordExpression), - LBrace(LBrace), - Fields(Fields), - RBrace(RBrace) {} - - inline RecordExpression( - std::vector Annotations, - class LBrace* LBrace, - std::vector> Fields, - class RBrace* RBrace - ): Expression(NodeKind::RecordExpression, Annotations), - LBrace(LBrace), - Fields(Fields), - RBrace(RBrace) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class Statement : public Node, public AnnotationContainer { - protected: - - inline Statement(NodeKind Type, std::vector Annotations = {}): - Node(Type), AnnotationContainer(Annotations) {} - - }; - - class ExpressionStatement : public Statement { - public: - - class Expression* Expression; - - ExpressionStatement(class Expression* Expression): - Statement(NodeKind::ExpressionStatement), Expression(Expression) {} - - ExpressionStatement( - std::vector Annotations, - class Expression* Expression - ): Statement(NodeKind::ExpressionStatement, Annotations), - Expression(Expression) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class IfStatementPart : public Node, public AnnotationContainer { - public: - - Token* Keyword; - Expression* Test; - class BlockStart* BlockStart; - std::vector Elements; - - inline IfStatementPart( - Token* Keyword, - Expression* Test, - class BlockStart* BlockStart, - std::vector Elements - ): Node(NodeKind::IfStatementPart), - Keyword(Keyword), - Test(Test), - BlockStart(BlockStart), - Elements(Elements) {} - - inline IfStatementPart( - std::vector Annotations, - Token* Keyword, - Expression* Test, - class BlockStart* BlockStart, - std::vector Elements - ): Node(NodeKind::IfStatementPart), - AnnotationContainer(Annotations), - Keyword(Keyword), - Test(Test), - BlockStart(BlockStart), - Elements(Elements) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class IfStatement : public Statement { - public: - - std::vector Parts; - - inline IfStatement(std::vector Parts): - Statement(NodeKind::IfStatement), Parts(Parts) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class ReturnStatement : public Statement { - public: - - class ReturnKeyword* ReturnKeyword; - class Expression* Expression; - - ReturnStatement( - class ReturnKeyword* ReturnKeyword, - class Expression* Expression - ): Statement(NodeKind::ReturnStatement), - ReturnKeyword(ReturnKeyword), - Expression(Expression) {} - - ReturnStatement( - std::vector Annotations, - class ReturnKeyword* ReturnKeyword, - class Expression* Expression - ): Statement(NodeKind::ReturnStatement, Annotations), - ReturnKeyword(ReturnKeyword), - Expression(Expression) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class TypeAssert : public Node { - public: - - class Colon* Colon; - class TypeExpression* TypeExpression; - - TypeAssert( - class Colon* Colon, - class TypeExpression* TypeExpression - ): Node(NodeKind::TypeAssert), - Colon(Colon), - TypeExpression(TypeExpression) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class Parameter : public Node { - public: - - Parameter( - class Pattern* Pattern, - class TypeAssert* TypeAssert - ): Node(NodeKind::Parameter), - Pattern(Pattern), - TypeAssert(TypeAssert) {} - - class Pattern* Pattern; - class TypeAssert* TypeAssert; - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class LetBody : public Node { - public: - - LetBody(NodeKind Type): Node(Type) {} - - }; - - class LetBlockBody : public LetBody { - public: - - class BlockStart* BlockStart; - std::vector Elements; - - LetBlockBody( - class BlockStart* BlockStart, - std::vector Elements - ): LetBody(NodeKind::LetBlockBody), - BlockStart(BlockStart), - Elements(Elements) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class LetExprBody : public LetBody { - public: - - class Equals* Equals; - class Expression* Expression; - - LetExprBody( - class Equals* Equals, - class Expression* Expression - ): LetBody(NodeKind::LetExprBody), - Equals(Equals), - Expression(Expression) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class LetDeclaration : public TypedNode, public AnnotationContainer { - - Scope* TheScope = nullptr; - - public: - - bool IsCycleActive = false; - bool Visited = false; - InferContext* Ctx; - - class PubKeyword* PubKeyword; - class ForeignKeyword* ForeignKeyword; - class LetKeyword* LetKeyword; - class MutKeyword* MutKeyword; - class Pattern* Pattern; - std::vector Params; - class TypeAssert* TypeAssert; - LetBody* Body; - - LetDeclaration( - class PubKeyword* PubKeyword, - class ForeignKeyword* ForeignKeyword, - class LetKeyword* LetKeyword, - class MutKeyword* MutKeyword, - class Pattern* Pattern, - std::vector Params, - class TypeAssert* TypeAssert, - LetBody* Body - ): TypedNode(NodeKind::LetDeclaration), - PubKeyword(PubKeyword), - ForeignKeyword(ForeignKeyword), - LetKeyword(LetKeyword), - MutKeyword(MutKeyword), - Pattern(Pattern), - Params(Params), - TypeAssert(TypeAssert), - Body(Body) {} - - LetDeclaration( - std::vector Annotations, - class PubKeyword* PubKeyword, - class ForeignKeyword* ForeignKeyword, - class LetKeyword* LetKeyword, - class MutKeyword* MutKeyword, - class Pattern* Pattern, - std::vector Params, - class TypeAssert* TypeAssert, - LetBody* Body - ): TypedNode(NodeKind::LetDeclaration), - AnnotationContainer(Annotations), - PubKeyword(PubKeyword), - ForeignKeyword(ForeignKeyword), - LetKeyword(LetKeyword), - MutKeyword(MutKeyword), - Pattern(Pattern), - Params(Params), - TypeAssert(TypeAssert), - Body(Body) {} - - inline Scope* getScope() override { - if (isFunction()) { - if (TheScope == nullptr) { - TheScope = new Scope(this); - } - return TheScope; - } - return Parent->getScope(); - } - - bool isInstance() const noexcept { - return Parent->getKind() == NodeKind::InstanceDeclaration; - } - - bool isClass() const noexcept { - return Parent->getKind() == NodeKind::ClassDeclaration; - } - - bool isSignature() const noexcept { - return ForeignKeyword; - } - - bool isVariable() const noexcept { - // Variables in classes and instances are never possible, so we reflect this by excluding them here. - return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body); - } - - bool isFunction() const noexcept { - return !isSignature() && !isVariable(); - } - - Symbol* getName() const noexcept { - ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern); - return static_cast(Pattern)->Name; - } - - ByteString getNameAsString() const noexcept { - return getCanonicalText(getName()); - } - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::LetDeclaration; - } - - }; - - class InstanceDeclaration : public Node { - public: - - class InstanceKeyword* InstanceKeyword; - IdentifierAlt* Name; - std::vector TypeExps; - class BlockStart* BlockStart; - std::vector Elements; - - InstanceDeclaration( - class InstanceKeyword* InstanceKeyword, - IdentifierAlt* Name, - std::vector TypeExps, - class BlockStart* BlockStart, - std::vector Elements - ): Node(NodeKind::InstanceDeclaration), - InstanceKeyword(InstanceKeyword), - Name(Name), - TypeExps(TypeExps), - BlockStart(BlockStart), - Elements(Elements) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::InstanceDeclaration; - } - - }; - - class ClassDeclaration : public Node { - public: - - class PubKeyword* PubKeyword; - class ClassKeyword* ClassKeyword; - IdentifierAlt* Name; - std::vector TypeVars; - class BlockStart* BlockStart; - std::vector Elements; - - ClassDeclaration( - class PubKeyword* PubKeyword, - class ClassKeyword* ClassKeyword, - IdentifierAlt* Name, - std::vector TypeVars, - class BlockStart* BlockStart, - std::vector Elements - ): Node(NodeKind::ClassDeclaration), - PubKeyword(PubKeyword), - ClassKeyword(ClassKeyword), - Name(Name), - TypeVars(TypeVars), - BlockStart(BlockStart), - Elements(Elements) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::ClassDeclaration; - } - - }; - - class RecordDeclarationField : public Node { - public: - - RecordDeclarationField( - Identifier* Name, - class Colon* Colon, - class TypeExpression* TypeExpression - ): Node(NodeKind::RecordDeclarationField), - Name(Name), - Colon(Colon), - TypeExpression(TypeExpression) {} - - Identifier* Name; - class Colon* Colon; - class TypeExpression* TypeExpression; - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class RecordDeclaration : public Node { - public: - - InferContext* Ctx; - - class PubKeyword* PubKeyword; - class StructKeyword* StructKeyword; - IdentifierAlt* Name; - std::vector Vars; - class BlockStart* BlockStart; - std::vector Fields; - - RecordDeclaration( - class PubKeyword* PubKeyword, - class StructKeyword* StructKeyword, - IdentifierAlt* Name, - std::vector Vars, - class BlockStart* BlockStart, - std::vector Fields - ): Node(NodeKind::RecordDeclaration), - PubKeyword(PubKeyword), - StructKeyword(StructKeyword), - Name(Name), - Vars(Vars), - BlockStart(BlockStart), - Fields(Fields) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class VariantDeclarationMember : public Node { - public: - - inline VariantDeclarationMember(NodeKind Kind): - Node(Kind) {} - - }; - - class TupleVariantDeclarationMember : public VariantDeclarationMember { - public: - - IdentifierAlt* Name; - std::vector Elements; - - inline TupleVariantDeclarationMember( - IdentifierAlt* Name, - std::vector Elements - ): VariantDeclarationMember(NodeKind::TupleVariantDeclarationMember), - Name(Name), - Elements(Elements) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class RecordVariantDeclarationMember : public VariantDeclarationMember { - public: - - IdentifierAlt* Name; - class BlockStart* BlockStart; - std::vector Fields; - - inline RecordVariantDeclarationMember( - IdentifierAlt* Name, - class BlockStart* BlockStart, - std::vector Fields - ): VariantDeclarationMember(NodeKind::RecordVariantDeclarationMember), - Name(Name), - BlockStart(BlockStart), - Fields(Fields) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class VariantDeclaration : public Node { - public: - - InferContext* Ctx; - - class PubKeyword* PubKeyword; - class EnumKeyword* EnumKeyword; - class IdentifierAlt* Name; - std::vector TVs; - class BlockStart* BlockStart; - std::vector Members; - - inline VariantDeclaration( - class PubKeyword* PubKeyword, - class EnumKeyword* EnumKeyword, - class IdentifierAlt* Name, - std::vector TVs, - class BlockStart* BlockStart, - std::vector Members - ): Node(NodeKind::VariantDeclaration), - PubKeyword(PubKeyword), - EnumKeyword(EnumKeyword), - Name(Name), - TVs(TVs), - BlockStart(BlockStart), - Members(Members) {} - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - }; - - class SourceFile : public Node { - - Scope* TheScope = nullptr; - - public: - - TextFile File; - InferContext* Ctx; - - std::vector Elements; - - SourceFile(TextFile& File, std::vector Elements): - Node(NodeKind::SourceFile), File(File), Elements(Elements) {} - - inline TextFile& getTextFile() { - return File; - } - - inline const TextFile& getTextFile() const { - return File; - } - - Token* getFirstToken() const override; - Token* getLastToken() const override; - - inline Scope* getScope() override { - if (TheScope == nullptr) { - TheScope = new Scope(this); - } - return TheScope; - } - - static bool classof(const Node* N) { - return N->getKind() == NodeKind::SourceFile; - } - - }; - - template<> inline NodeKind getNodeType() { return NodeKind::Equals; } - template<> inline NodeKind getNodeType() { return NodeKind::Colon; } - template<> inline NodeKind getNodeType() { return NodeKind::Dot; } - template<> inline NodeKind getNodeType() { return NodeKind::DotDot; } - template<> inline NodeKind getNodeType() { return NodeKind::Tilde; } - template<> inline NodeKind getNodeType() { return NodeKind::LParen; } - template<> inline NodeKind getNodeType() { return NodeKind::RParen; } - template<> inline NodeKind getNodeType() { return NodeKind::LBracket; } - template<> inline NodeKind getNodeType() { return NodeKind::RBracket; } - template<> inline NodeKind getNodeType() { return NodeKind::LBrace; } - template<> inline NodeKind getNodeType() { return NodeKind::RBrace; } - template<> inline NodeKind getNodeType() { return NodeKind::RArrow; } - template<> inline NodeKind getNodeType() { return NodeKind::RArrowAlt; } - template<> inline NodeKind getNodeType() { return NodeKind::LetKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::ForeignKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::MutKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::PubKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::TypeKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::ReturnKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::ModKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::StructKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::EnumKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::ClassKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::InstanceKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::ElifKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::IfKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::MatchKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::ElseKeyword; } - template<> inline NodeKind getNodeType() { return NodeKind::Invalid; } - template<> inline NodeKind getNodeType() { return NodeKind::EndOfFile; } - template<> inline NodeKind getNodeType() { return NodeKind::BlockStart; } - template<> inline NodeKind getNodeType() { return NodeKind::BlockEnd; } - template<> inline NodeKind getNodeType() { return NodeKind::LineFoldEnd; } - template<> inline NodeKind getNodeType() { return NodeKind::CustomOperator; } - template<> inline NodeKind getNodeType() { return NodeKind::Assignment; } - template<> inline NodeKind getNodeType() { return NodeKind::Identifier; } - template<> inline NodeKind getNodeType() { return NodeKind::IdentifierAlt; } - template<> inline NodeKind getNodeType() { return NodeKind::StringLiteral; } - template<> inline NodeKind getNodeType() { return NodeKind::IntegerLiteral; } - template<> inline NodeKind getNodeType() { return NodeKind::QualifiedTypeExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::ReferenceTypeExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::ArrowTypeExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::BindPattern; } - template<> inline NodeKind getNodeType() { return NodeKind::ReferenceExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::NestedExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::LiteralExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::CallExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::InfixExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::PrefixExpression; } - template<> inline NodeKind getNodeType() { return NodeKind::ExpressionStatement; } - template<> inline NodeKind getNodeType() { return NodeKind::ReturnStatement; } - template<> inline NodeKind getNodeType() { return NodeKind::IfStatement; } - template<> inline NodeKind getNodeType() { return NodeKind::IfStatementPart; } - template<> inline NodeKind getNodeType() { return NodeKind::TypeAssert; } - template<> inline NodeKind getNodeType() { return NodeKind::Parameter; } - template<> inline NodeKind getNodeType() { return NodeKind::LetBlockBody; } - template<> inline NodeKind getNodeType() { return NodeKind::LetExprBody; } - template<> inline NodeKind getNodeType() { return NodeKind::LetDeclaration; } - template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclarationField; } - template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclaration; } - template<> inline NodeKind getNodeType() { return NodeKind::ClassDeclaration; } - template<> inline NodeKind getNodeType() { return NodeKind::InstanceDeclaration; } - template<> inline NodeKind getNodeType() { return NodeKind::SourceFile; } + return TheScope; + } + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::SourceFile; + } + +}; + +template<> inline NodeKind getNodeType() { return NodeKind::Equals; } +template<> inline NodeKind getNodeType() { return NodeKind::Colon; } +template<> inline NodeKind getNodeType() { return NodeKind::Dot; } +template<> inline NodeKind getNodeType() { return NodeKind::DotDot; } +template<> inline NodeKind getNodeType() { return NodeKind::Tilde; } +template<> inline NodeKind getNodeType() { return NodeKind::LParen; } +template<> inline NodeKind getNodeType() { return NodeKind::RParen; } +template<> inline NodeKind getNodeType() { return NodeKind::LBracket; } +template<> inline NodeKind getNodeType() { return NodeKind::RBracket; } +template<> inline NodeKind getNodeType() { return NodeKind::LBrace; } +template<> inline NodeKind getNodeType() { return NodeKind::RBrace; } +template<> inline NodeKind getNodeType() { return NodeKind::RArrow; } +template<> inline NodeKind getNodeType() { return NodeKind::RArrowAlt; } +template<> inline NodeKind getNodeType() { return NodeKind::LetKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::ForeignKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::MutKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::PubKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::TypeKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::ReturnKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::ModKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::StructKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::EnumKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::ClassKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::InstanceKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::ElifKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::IfKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::MatchKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::ElseKeyword; } +template<> inline NodeKind getNodeType() { return NodeKind::Invalid; } +template<> inline NodeKind getNodeType() { return NodeKind::EndOfFile; } +template<> inline NodeKind getNodeType() { return NodeKind::BlockStart; } +template<> inline NodeKind getNodeType() { return NodeKind::BlockEnd; } +template<> inline NodeKind getNodeType() { return NodeKind::LineFoldEnd; } +template<> inline NodeKind getNodeType() { return NodeKind::CustomOperator; } +template<> inline NodeKind getNodeType() { return NodeKind::Assignment; } +template<> inline NodeKind getNodeType() { return NodeKind::Identifier; } +template<> inline NodeKind getNodeType() { return NodeKind::IdentifierAlt; } +template<> inline NodeKind getNodeType() { return NodeKind::StringLiteral; } +template<> inline NodeKind getNodeType() { return NodeKind::IntegerLiteral; } +template<> inline NodeKind getNodeType() { return NodeKind::QualifiedTypeExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::ReferenceTypeExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::ArrowTypeExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::BindPattern; } +template<> inline NodeKind getNodeType() { return NodeKind::ReferenceExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::NestedExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::LiteralExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::CallExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::InfixExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::PrefixExpression; } +template<> inline NodeKind getNodeType() { return NodeKind::ExpressionStatement; } +template<> inline NodeKind getNodeType() { return NodeKind::ReturnStatement; } +template<> inline NodeKind getNodeType() { return NodeKind::IfStatement; } +template<> inline NodeKind getNodeType() { return NodeKind::IfStatementPart; } +template<> inline NodeKind getNodeType() { return NodeKind::TypeAssert; } +template<> inline NodeKind getNodeType() { return NodeKind::Parameter; } +template<> inline NodeKind getNodeType() { return NodeKind::LetBlockBody; } +template<> inline NodeKind getNodeType() { return NodeKind::LetExprBody; } +template<> inline NodeKind getNodeType() { return NodeKind::LetDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclarationField; } +template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::ClassDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::InstanceDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::SourceFile; } } diff --git a/bootstrap/cxx/include/bolt/CSTVisitor.hpp b/bootstrap/cxx/include/bolt/CSTVisitor.hpp index 133fc83fd..90228669d 100644 --- a/bootstrap/cxx/include/bolt/CSTVisitor.hpp +++ b/bootstrap/cxx/include/bolt/CSTVisitor.hpp @@ -8,1242 +8,1242 @@ namespace bolt { - template - class CSTVisitor { - public: +template +class CSTVisitor { +public: - void visit(Node* N) { + void visit(Node* N) { #define BOLT_GEN_CASE(name) \ - case NodeKind::name: \ - return static_cast(this)->visit ## name(static_cast(N)); - - switch (N->getKind()) { - BOLT_GEN_CASE(VBar) - BOLT_GEN_CASE(Equals) - BOLT_GEN_CASE(Colon) - BOLT_GEN_CASE(Comma) - BOLT_GEN_CASE(Dot) - BOLT_GEN_CASE(DotDot) - BOLT_GEN_CASE(Tilde) - BOLT_GEN_CASE(At) - BOLT_GEN_CASE(LParen) - BOLT_GEN_CASE(RParen) - BOLT_GEN_CASE(LBracket) - BOLT_GEN_CASE(RBracket) - BOLT_GEN_CASE(LBrace) - BOLT_GEN_CASE(RBrace) - BOLT_GEN_CASE(RArrow) - BOLT_GEN_CASE(RArrowAlt) - BOLT_GEN_CASE(LetKeyword) - BOLT_GEN_CASE(ForeignKeyword) - BOLT_GEN_CASE(MutKeyword) - BOLT_GEN_CASE(PubKeyword) - BOLT_GEN_CASE(TypeKeyword) - BOLT_GEN_CASE(ReturnKeyword) - BOLT_GEN_CASE(ModKeyword) - BOLT_GEN_CASE(StructKeyword) - BOLT_GEN_CASE(EnumKeyword) - BOLT_GEN_CASE(ClassKeyword) - BOLT_GEN_CASE(InstanceKeyword) - BOLT_GEN_CASE(ElifKeyword) - BOLT_GEN_CASE(IfKeyword) - BOLT_GEN_CASE(ElseKeyword) - BOLT_GEN_CASE(MatchKeyword) - BOLT_GEN_CASE(Invalid) - BOLT_GEN_CASE(EndOfFile) - BOLT_GEN_CASE(BlockStart) - BOLT_GEN_CASE(BlockEnd) - BOLT_GEN_CASE(LineFoldEnd) - BOLT_GEN_CASE(CustomOperator) - BOLT_GEN_CASE(Assignment) - BOLT_GEN_CASE(Identifier) - BOLT_GEN_CASE(IdentifierAlt) - BOLT_GEN_CASE(WrappedOperator) - BOLT_GEN_CASE(StringLiteral) - BOLT_GEN_CASE(IntegerLiteral) - BOLT_GEN_CASE(ExpressionAnnotation) - BOLT_GEN_CASE(TypeAssertAnnotation) - BOLT_GEN_CASE(TypeclassConstraintExpression) - BOLT_GEN_CASE(EqualityConstraintExpression) - BOLT_GEN_CASE(RecordTypeExpressionField) - BOLT_GEN_CASE(RecordTypeExpression) - BOLT_GEN_CASE(QualifiedTypeExpression) - BOLT_GEN_CASE(ReferenceTypeExpression) - BOLT_GEN_CASE(ArrowTypeExpression) - BOLT_GEN_CASE(AppTypeExpression) - BOLT_GEN_CASE(VarTypeExpression) - BOLT_GEN_CASE(NestedTypeExpression) - BOLT_GEN_CASE(TupleTypeExpression) - BOLT_GEN_CASE(BindPattern) - BOLT_GEN_CASE(LiteralPattern) - BOLT_GEN_CASE(RecordPatternField) - BOLT_GEN_CASE(RecordPattern) - BOLT_GEN_CASE(NamedRecordPattern) - BOLT_GEN_CASE(NamedTuplePattern) - BOLT_GEN_CASE(TuplePattern) - BOLT_GEN_CASE(NestedPattern) - BOLT_GEN_CASE(ListPattern) - BOLT_GEN_CASE(ReferenceExpression) - BOLT_GEN_CASE(MatchCase) - BOLT_GEN_CASE(MatchExpression) - BOLT_GEN_CASE(MemberExpression) - BOLT_GEN_CASE(TupleExpression) - BOLT_GEN_CASE(NestedExpression) - BOLT_GEN_CASE(LiteralExpression) - BOLT_GEN_CASE(CallExpression) - BOLT_GEN_CASE(InfixExpression) - BOLT_GEN_CASE(PrefixExpression) - BOLT_GEN_CASE(RecordExpressionField) - BOLT_GEN_CASE(RecordExpression) - BOLT_GEN_CASE(ExpressionStatement) - BOLT_GEN_CASE(ReturnStatement) - BOLT_GEN_CASE(IfStatement) - BOLT_GEN_CASE(IfStatementPart) - BOLT_GEN_CASE(TypeAssert) - BOLT_GEN_CASE(Parameter) - BOLT_GEN_CASE(LetBlockBody) - BOLT_GEN_CASE(LetExprBody) - BOLT_GEN_CASE(LetDeclaration) - BOLT_GEN_CASE(RecordDeclaration) - BOLT_GEN_CASE(RecordDeclarationField) - BOLT_GEN_CASE(VariantDeclaration) - BOLT_GEN_CASE(TupleVariantDeclarationMember) - BOLT_GEN_CASE(RecordVariantDeclarationMember) - BOLT_GEN_CASE(ClassDeclaration) - BOLT_GEN_CASE(InstanceDeclaration) - BOLT_GEN_CASE(SourceFile) - } - } - - protected: - - void visitNode(Node* N) { - visitEachChild(N); - } - - void visitToken(Token* N) { - static_cast(this)->visitNode(N); - } - - void visitVBar(VBar* N) { - static_cast(this)->visitToken(N); - } - - void visitEquals(Equals* N) { - static_cast(this)->visitToken(N); - } - - void visitColon(Colon* N) { - static_cast(this)->visitToken(N); - } - - void visitComma(Comma* N) { - static_cast(this)->visitToken(N); - } - - void visitDot(Dot* N) { - static_cast(this)->visitToken(N); - } - - void visitDotDot(DotDot* N) { - static_cast(this)->visitToken(N); - } - - void visitTilde(Tilde* N) { - static_cast(this)->visitToken(N); - } - - void visitAt(At* N) { - static_cast(this)->visitToken(N); - } - - void visitLParen(LParen* N) { - static_cast(this)->visitToken(N); - } - - void visitRParen(RParen* N) { - static_cast(this)->visitToken(N); - } - - void visitLBracket(LBracket* N) { - static_cast(this)->visitToken(N); - } - - void visitRBracket(RBracket* N) { - static_cast(this)->visitToken(N); - } - - void visitLBrace(LBrace* N) { - static_cast(this)->visitToken(N); - } + case NodeKind::name: \ + return static_cast(this)->visit ## name(static_cast(N)); + + switch (N->getKind()) { + BOLT_GEN_CASE(VBar) + BOLT_GEN_CASE(Equals) + BOLT_GEN_CASE(Colon) + BOLT_GEN_CASE(Comma) + BOLT_GEN_CASE(Dot) + BOLT_GEN_CASE(DotDot) + BOLT_GEN_CASE(Tilde) + BOLT_GEN_CASE(At) + BOLT_GEN_CASE(LParen) + BOLT_GEN_CASE(RParen) + BOLT_GEN_CASE(LBracket) + BOLT_GEN_CASE(RBracket) + BOLT_GEN_CASE(LBrace) + BOLT_GEN_CASE(RBrace) + BOLT_GEN_CASE(RArrow) + BOLT_GEN_CASE(RArrowAlt) + BOLT_GEN_CASE(LetKeyword) + BOLT_GEN_CASE(ForeignKeyword) + BOLT_GEN_CASE(MutKeyword) + BOLT_GEN_CASE(PubKeyword) + BOLT_GEN_CASE(TypeKeyword) + BOLT_GEN_CASE(ReturnKeyword) + BOLT_GEN_CASE(ModKeyword) + BOLT_GEN_CASE(StructKeyword) + BOLT_GEN_CASE(EnumKeyword) + BOLT_GEN_CASE(ClassKeyword) + BOLT_GEN_CASE(InstanceKeyword) + BOLT_GEN_CASE(ElifKeyword) + BOLT_GEN_CASE(IfKeyword) + BOLT_GEN_CASE(ElseKeyword) + BOLT_GEN_CASE(MatchKeyword) + BOLT_GEN_CASE(Invalid) + BOLT_GEN_CASE(EndOfFile) + BOLT_GEN_CASE(BlockStart) + BOLT_GEN_CASE(BlockEnd) + BOLT_GEN_CASE(LineFoldEnd) + BOLT_GEN_CASE(CustomOperator) + BOLT_GEN_CASE(Assignment) + BOLT_GEN_CASE(Identifier) + BOLT_GEN_CASE(IdentifierAlt) + BOLT_GEN_CASE(WrappedOperator) + BOLT_GEN_CASE(StringLiteral) + BOLT_GEN_CASE(IntegerLiteral) + BOLT_GEN_CASE(ExpressionAnnotation) + BOLT_GEN_CASE(TypeAssertAnnotation) + BOLT_GEN_CASE(TypeclassConstraintExpression) + BOLT_GEN_CASE(EqualityConstraintExpression) + BOLT_GEN_CASE(RecordTypeExpressionField) + BOLT_GEN_CASE(RecordTypeExpression) + BOLT_GEN_CASE(QualifiedTypeExpression) + BOLT_GEN_CASE(ReferenceTypeExpression) + BOLT_GEN_CASE(ArrowTypeExpression) + BOLT_GEN_CASE(AppTypeExpression) + BOLT_GEN_CASE(VarTypeExpression) + BOLT_GEN_CASE(NestedTypeExpression) + BOLT_GEN_CASE(TupleTypeExpression) + BOLT_GEN_CASE(BindPattern) + BOLT_GEN_CASE(LiteralPattern) + BOLT_GEN_CASE(RecordPatternField) + BOLT_GEN_CASE(RecordPattern) + BOLT_GEN_CASE(NamedRecordPattern) + BOLT_GEN_CASE(NamedTuplePattern) + BOLT_GEN_CASE(TuplePattern) + BOLT_GEN_CASE(NestedPattern) + BOLT_GEN_CASE(ListPattern) + BOLT_GEN_CASE(ReferenceExpression) + BOLT_GEN_CASE(MatchCase) + BOLT_GEN_CASE(MatchExpression) + BOLT_GEN_CASE(MemberExpression) + BOLT_GEN_CASE(TupleExpression) + BOLT_GEN_CASE(NestedExpression) + BOLT_GEN_CASE(LiteralExpression) + BOLT_GEN_CASE(CallExpression) + BOLT_GEN_CASE(InfixExpression) + BOLT_GEN_CASE(PrefixExpression) + BOLT_GEN_CASE(RecordExpressionField) + BOLT_GEN_CASE(RecordExpression) + BOLT_GEN_CASE(ExpressionStatement) + BOLT_GEN_CASE(ReturnStatement) + BOLT_GEN_CASE(IfStatement) + BOLT_GEN_CASE(IfStatementPart) + BOLT_GEN_CASE(TypeAssert) + BOLT_GEN_CASE(Parameter) + BOLT_GEN_CASE(LetBlockBody) + BOLT_GEN_CASE(LetExprBody) + BOLT_GEN_CASE(LetDeclaration) + BOLT_GEN_CASE(RecordDeclaration) + BOLT_GEN_CASE(RecordDeclarationField) + BOLT_GEN_CASE(VariantDeclaration) + BOLT_GEN_CASE(TupleVariantDeclarationMember) + BOLT_GEN_CASE(RecordVariantDeclarationMember) + BOLT_GEN_CASE(ClassDeclaration) + BOLT_GEN_CASE(InstanceDeclaration) + BOLT_GEN_CASE(SourceFile) + } + } + +protected: + + void visitNode(Node* N) { + visitEachChild(N); + } + + void visitToken(Token* N) { + static_cast(this)->visitNode(N); + } + + void visitVBar(VBar* N) { + static_cast(this)->visitToken(N); + } + + void visitEquals(Equals* N) { + static_cast(this)->visitToken(N); + } + + void visitColon(Colon* N) { + static_cast(this)->visitToken(N); + } + + void visitComma(Comma* N) { + static_cast(this)->visitToken(N); + } + + void visitDot(Dot* N) { + static_cast(this)->visitToken(N); + } + + void visitDotDot(DotDot* N) { + static_cast(this)->visitToken(N); + } + + void visitTilde(Tilde* N) { + static_cast(this)->visitToken(N); + } + + void visitAt(At* N) { + static_cast(this)->visitToken(N); + } + + void visitLParen(LParen* N) { + static_cast(this)->visitToken(N); + } + + void visitRParen(RParen* N) { + static_cast(this)->visitToken(N); + } + + void visitLBracket(LBracket* N) { + static_cast(this)->visitToken(N); + } + + void visitRBracket(RBracket* N) { + static_cast(this)->visitToken(N); + } + + void visitLBrace(LBrace* N) { + static_cast(this)->visitToken(N); + } + + void visitRBrace(RBrace* N) { + static_cast(this)->visitToken(N); + } + + void visitRArrow(RArrow* N) { + static_cast(this)->visitToken(N); + } + + void visitRArrowAlt(RArrowAlt* N) { + static_cast(this)->visitToken(N); + } + + void visitLetKeyword(LetKeyword* N) { + static_cast(this)->visitToken(N); + } + + void visitForeignKeyword(ForeignKeyword* N) { + static_cast(this)->visitToken(N); + } + + void visitMutKeyword(MutKeyword* N) { + static_cast(this)->visitToken(N); + } + + void visitPubKeyword(PubKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitRBrace(RBrace* N) { - static_cast(this)->visitToken(N); - } - - void visitRArrow(RArrow* N) { - static_cast(this)->visitToken(N); - } - - void visitRArrowAlt(RArrowAlt* N) { - static_cast(this)->visitToken(N); - } - - void visitLetKeyword(LetKeyword* N) { - static_cast(this)->visitToken(N); - } - - void visitForeignKeyword(ForeignKeyword* N) { - static_cast(this)->visitToken(N); - } - - void visitMutKeyword(MutKeyword* N) { - static_cast(this)->visitToken(N); - } - - void visitPubKeyword(PubKeyword* N) { - static_cast(this)->visitToken(N); - } - - void visitTypeKeyword(TypeKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitTypeKeyword(TypeKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitReturnKeyword(ReturnKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitReturnKeyword(ReturnKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitModKeyword(ModKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitModKeyword(ModKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitStructKeyword(StructKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitStructKeyword(StructKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitEnumKeyword(EnumKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitEnumKeyword(EnumKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitClassKeyword(ClassKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitClassKeyword(ClassKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitInstanceKeyword(InstanceKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitInstanceKeyword(InstanceKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitElifKeyword(ElifKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitElifKeyword(ElifKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitIfKeyword(IfKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitIfKeyword(IfKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitElseKeyword(ElseKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitElseKeyword(ElseKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitMatchKeyword(MatchKeyword* N) { - static_cast(this)->visitToken(N); - } + void visitMatchKeyword(MatchKeyword* N) { + static_cast(this)->visitToken(N); + } - void visitInvalid(Invalid* N) { - static_cast(this)->visitToken(N); - } + void visitInvalid(Invalid* N) { + static_cast(this)->visitToken(N); + } - void visitEndOfFile(EndOfFile* N) { - static_cast(this)->visitToken(N); - } + void visitEndOfFile(EndOfFile* N) { + static_cast(this)->visitToken(N); + } - void visitBlockStart(BlockStart* N) { - static_cast(this)->visitToken(N); - } + void visitBlockStart(BlockStart* N) { + static_cast(this)->visitToken(N); + } - void visitBlockEnd(BlockEnd* N) { - static_cast(this)->visitToken(N); - } + void visitBlockEnd(BlockEnd* N) { + static_cast(this)->visitToken(N); + } - void visitLineFoldEnd(LineFoldEnd* N) { - static_cast(this)->visitToken(N); - } + void visitLineFoldEnd(LineFoldEnd* N) { + static_cast(this)->visitToken(N); + } - void visitCustomOperator(CustomOperator* N) { - static_cast(this)->visitToken(N); - } + void visitCustomOperator(CustomOperator* N) { + static_cast(this)->visitToken(N); + } - void visitAssignment(Assignment* N) { - static_cast(this)->visitToken(N); - } + void visitAssignment(Assignment* N) { + static_cast(this)->visitToken(N); + } - void visitIdentifier(Identifier* N) { - static_cast(this)->visitToken(N); - } + void visitIdentifier(Identifier* N) { + static_cast(this)->visitToken(N); + } - void visitIdentifierAlt(IdentifierAlt* N) { - static_cast(this)->visitToken(N); - } + void visitIdentifierAlt(IdentifierAlt* N) { + static_cast(this)->visitToken(N); + } - void visitStringLiteral(StringLiteral* N) { - static_cast(this)->visitToken(N); - } + void visitStringLiteral(StringLiteral* N) { + static_cast(this)->visitToken(N); + } - void visitIntegerLiteral(IntegerLiteral* N) { - static_cast(this)->visitToken(N); - } + void visitIntegerLiteral(IntegerLiteral* N) { + static_cast(this)->visitToken(N); + } - void visitAnnotation(Annotation* N) { - static_cast(this)->visitNode(N); - } + void visitAnnotation(Annotation* N) { + static_cast(this)->visitNode(N); + } - void visitTypeAssertAnnotation(TypeAssertAnnotation* N) { - static_cast(this)->visitAnnotation(N); - } + void visitTypeAssertAnnotation(TypeAssertAnnotation* N) { + static_cast(this)->visitAnnotation(N); + } - void visitExpressionAnnotation(ExpressionAnnotation* N) { - static_cast(this)->visitAnnotation(N); - } + void visitExpressionAnnotation(ExpressionAnnotation* N) { + static_cast(this)->visitAnnotation(N); + } - void visitConstraintExpression(ConstraintExpression* N) { - static_cast(this)->visitNode(N); - } + void visitConstraintExpression(ConstraintExpression* N) { + static_cast(this)->visitNode(N); + } - void visitTypeclassConstraintExpression(TypeclassConstraintExpression* N) { - static_cast(this)->visitConstraintExpression(N); - } + void visitTypeclassConstraintExpression(TypeclassConstraintExpression* N) { + static_cast(this)->visitConstraintExpression(N); + } - void visitEqualityConstraintExpression(EqualityConstraintExpression* N) { - static_cast(this)->visitConstraintExpression(N); - } + void visitEqualityConstraintExpression(EqualityConstraintExpression* N) { + static_cast(this)->visitConstraintExpression(N); + } - void visitTypeExpression(TypeExpression* N) { - static_cast(this)->visitNode(N); - } + void visitTypeExpression(TypeExpression* N) { + static_cast(this)->visitNode(N); + } - void visitRecordTypeExpressionField(RecordTypeExpressionField * N) { - static_cast(this)->visitNode(N); - } + void visitRecordTypeExpressionField(RecordTypeExpressionField * N) { + static_cast(this)->visitNode(N); + } - void visitRecordTypeExpression(RecordTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitRecordTypeExpression(RecordTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitQualifiedTypeExpression(QualifiedTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitQualifiedTypeExpression(QualifiedTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitReferenceTypeExpression(ReferenceTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitReferenceTypeExpression(ReferenceTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitArrowTypeExpression(ArrowTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitArrowTypeExpression(ArrowTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitAppTypeExpression(AppTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitAppTypeExpression(AppTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitVarTypeExpression(VarTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitVarTypeExpression(VarTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitNestedTypeExpression(NestedTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitNestedTypeExpression(NestedTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitTupleTypeExpression(TupleTypeExpression* N) { - static_cast(this)->visitTypeExpression(N); - } + void visitTupleTypeExpression(TupleTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } - void visitWrappedOperator(WrappedOperator* N) { - static_cast(this)->visitNode(N); - } + void visitWrappedOperator(WrappedOperator* N) { + static_cast(this)->visitNode(N); + } - void visitPattern(Pattern* N) { - static_cast(this)->visitNode(N); - } + void visitPattern(Pattern* N) { + static_cast(this)->visitNode(N); + } - void visitBindPattern(BindPattern* N) { - static_cast(this)->visitPattern(N); - } + void visitBindPattern(BindPattern* N) { + static_cast(this)->visitPattern(N); + } - void visitLiteralPattern(LiteralPattern* N) { - static_cast(this)->visitPattern(N); - } + void visitLiteralPattern(LiteralPattern* N) { + static_cast(this)->visitPattern(N); + } - void visitRecordPatternField(RecordPatternField* N) { - static_cast(this)->visitNode(N); - } + void visitRecordPatternField(RecordPatternField* N) { + static_cast(this)->visitNode(N); + } - void visitRecordPattern(RecordPattern* N) { - static_cast(this)->visitPattern(N); - } + void visitRecordPattern(RecordPattern* N) { + static_cast(this)->visitPattern(N); + } - void visitNamedRecordPattern(NamedRecordPattern* N) { - static_cast(this)->visitPattern(N); - } + void visitNamedRecordPattern(NamedRecordPattern* N) { + static_cast(this)->visitPattern(N); + } - void visitNamedTuplePattern(NamedTuplePattern* N) { - static_cast(this)->visitPattern(N); - } + void visitNamedTuplePattern(NamedTuplePattern* N) { + static_cast(this)->visitPattern(N); + } - void visitTuplePattern(TuplePattern* N) { - static_cast(this)->visitPattern(N); - } + void visitTuplePattern(TuplePattern* N) { + static_cast(this)->visitPattern(N); + } - void visitNestedPattern(NestedPattern* N) { - static_cast(this)->visitPattern(N); - } + void visitNestedPattern(NestedPattern* N) { + static_cast(this)->visitPattern(N); + } - void visitListPattern(ListPattern* N) { - static_cast(this)->visitPattern(N); - } + void visitListPattern(ListPattern* N) { + static_cast(this)->visitPattern(N); + } - void visitExpression(Expression* N) { - static_cast(this)->visitNode(N); - } + void visitExpression(Expression* N) { + static_cast(this)->visitNode(N); + } - void visitReferenceExpression(ReferenceExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitReferenceExpression(ReferenceExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitMatchCase(MatchCase* N) { - static_cast(this)->visitNode(N); - } + void visitMatchCase(MatchCase* N) { + static_cast(this)->visitNode(N); + } - void visitMatchExpression(MatchExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitMatchExpression(MatchExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitMemberExpression(MemberExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitMemberExpression(MemberExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitTupleExpression(TupleExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitTupleExpression(TupleExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitNestedExpression(NestedExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitNestedExpression(NestedExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitLiteralExpression(LiteralExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitLiteralExpression(LiteralExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitCallExpression(CallExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitCallExpression(CallExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitInfixExpression(InfixExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitInfixExpression(InfixExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitPrefixExpression(PrefixExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitPrefixExpression(PrefixExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitRecordExpressionField(RecordExpressionField* N) { - static_cast(this)->visitNode(N); - } + void visitRecordExpressionField(RecordExpressionField* N) { + static_cast(this)->visitNode(N); + } - void visitRecordExpression(RecordExpression* N) { - static_cast(this)->visitExpression(N); - } + void visitRecordExpression(RecordExpression* N) { + static_cast(this)->visitExpression(N); + } - void visitStatement(Statement* N) { - static_cast(this)->visitNode(N); - } + void visitStatement(Statement* N) { + static_cast(this)->visitNode(N); + } - void visitExpressionStatement(ExpressionStatement* N) { - static_cast(this)->visitStatement(N); - } + void visitExpressionStatement(ExpressionStatement* N) { + static_cast(this)->visitStatement(N); + } - void visitReturnStatement(ReturnStatement* N) { - static_cast(this)->visitStatement(N); - } + void visitReturnStatement(ReturnStatement* N) { + static_cast(this)->visitStatement(N); + } - void visitIfStatement(IfStatement* N) { - static_cast(this)->visitStatement(N); - } + void visitIfStatement(IfStatement* N) { + static_cast(this)->visitStatement(N); + } - void visitIfStatementPart(IfStatementPart* N) { - static_cast(this)->visitNode(N); - } + void visitIfStatementPart(IfStatementPart* N) { + static_cast(this)->visitNode(N); + } - void visitTypeAssert(TypeAssert* N) { - static_cast(this)->visitNode(N); - } + void visitTypeAssert(TypeAssert* N) { + static_cast(this)->visitNode(N); + } - void visitParameter(Parameter* N) { - static_cast(this)->visitNode(N); - } + void visitParameter(Parameter* N) { + static_cast(this)->visitNode(N); + } - void visitLetBody(LetBody* N) { - static_cast(this)->visitNode(N); - } + void visitLetBody(LetBody* N) { + static_cast(this)->visitNode(N); + } - void visitLetBlockBody(LetBlockBody* N) { - static_cast(this)->visitLetBody(N); - } + void visitLetBlockBody(LetBlockBody* N) { + static_cast(this)->visitLetBody(N); + } - void visitLetExprBody(LetExprBody* N) { - static_cast(this)->visitLetBody(N); - } + void visitLetExprBody(LetExprBody* N) { + static_cast(this)->visitLetBody(N); + } - void visitLetDeclaration(LetDeclaration* N) { - static_cast(this)->visitNode(N); - } + void visitLetDeclaration(LetDeclaration* N) { + static_cast(this)->visitNode(N); + } - void visitRecordDeclarationField(RecordDeclarationField* N) { - static_cast(this)->visitNode(N); - } + void visitRecordDeclarationField(RecordDeclarationField* N) { + static_cast(this)->visitNode(N); + } - void visitRecordDeclaration(RecordDeclaration* N) { - static_cast(this)->visitNode(N); - } + void visitRecordDeclaration(RecordDeclaration* N) { + static_cast(this)->visitNode(N); + } - void visitVariantDeclaration(VariantDeclaration* N) { - static_cast(this)->visitNode(N); - } + void visitVariantDeclaration(VariantDeclaration* N) { + static_cast(this)->visitNode(N); + } - void visitVariantDeclarationMember(VariantDeclarationMember* N) { - static_cast(this)->visitNode(N); - } + void visitVariantDeclarationMember(VariantDeclarationMember* N) { + static_cast(this)->visitNode(N); + } - void visitTupleVariantDeclarationMember(TupleVariantDeclarationMember* N) { - static_cast(this)->visitVariantDeclarationMember(N); - } + void visitTupleVariantDeclarationMember(TupleVariantDeclarationMember* N) { + static_cast(this)->visitVariantDeclarationMember(N); + } - void visitRecordVariantDeclarationMember(RecordVariantDeclarationMember* N) { - static_cast(this)->visitVariantDeclarationMember(N); - } + void visitRecordVariantDeclarationMember(RecordVariantDeclarationMember* N) { + static_cast(this)->visitVariantDeclarationMember(N); + } - void visitClassDeclaration(ClassDeclaration* N) { - static_cast(this)->visitNode(N); - } + void visitClassDeclaration(ClassDeclaration* N) { + static_cast(this)->visitNode(N); + } - void visitInstanceDeclaration(InstanceDeclaration* N) { - static_cast(this)->visitNode(N); - } + void visitInstanceDeclaration(InstanceDeclaration* N) { + static_cast(this)->visitNode(N); + } - void visitSourceFile(SourceFile* N) { - static_cast(this)->visitNode(N); - } + void visitSourceFile(SourceFile* N) { + static_cast(this)->visitNode(N); + } - public: +public: - void visitEachChild(Node* N) { + void visitEachChild(Node* N) { #define BOLT_GEN_CHILD_CASE(name) \ - case NodeKind::name: \ - visitEachChild(static_cast(N)); \ - break; + case NodeKind::name: \ + visitEachChild(static_cast(N)); \ + break; - switch (N->getKind()) { - BOLT_GEN_CHILD_CASE(VBar) - BOLT_GEN_CHILD_CASE(Equals) - BOLT_GEN_CHILD_CASE(Colon) - BOLT_GEN_CHILD_CASE(Comma) - BOLT_GEN_CHILD_CASE(Dot) - BOLT_GEN_CHILD_CASE(DotDot) - BOLT_GEN_CHILD_CASE(Tilde) - BOLT_GEN_CHILD_CASE(At) - BOLT_GEN_CHILD_CASE(LParen) - BOLT_GEN_CHILD_CASE(RParen) - BOLT_GEN_CHILD_CASE(LBracket) - BOLT_GEN_CHILD_CASE(RBracket) - BOLT_GEN_CHILD_CASE(LBrace) - BOLT_GEN_CHILD_CASE(RBrace) - BOLT_GEN_CHILD_CASE(RArrow) - BOLT_GEN_CHILD_CASE(RArrowAlt) - BOLT_GEN_CHILD_CASE(LetKeyword) - BOLT_GEN_CHILD_CASE(ForeignKeyword) - BOLT_GEN_CHILD_CASE(MutKeyword) - BOLT_GEN_CHILD_CASE(PubKeyword) - BOLT_GEN_CHILD_CASE(TypeKeyword) - BOLT_GEN_CHILD_CASE(ReturnKeyword) - BOLT_GEN_CHILD_CASE(ModKeyword) - BOLT_GEN_CHILD_CASE(StructKeyword) - BOLT_GEN_CHILD_CASE(EnumKeyword) - BOLT_GEN_CHILD_CASE(ClassKeyword) - BOLT_GEN_CHILD_CASE(InstanceKeyword) - BOLT_GEN_CHILD_CASE(ElifKeyword) - BOLT_GEN_CHILD_CASE(IfKeyword) - BOLT_GEN_CHILD_CASE(ElseKeyword) - BOLT_GEN_CHILD_CASE(MatchKeyword) - BOLT_GEN_CHILD_CASE(Invalid) - BOLT_GEN_CHILD_CASE(EndOfFile) - BOLT_GEN_CHILD_CASE(BlockStart) - BOLT_GEN_CHILD_CASE(BlockEnd) - BOLT_GEN_CHILD_CASE(LineFoldEnd) - BOLT_GEN_CHILD_CASE(CustomOperator) - BOLT_GEN_CHILD_CASE(Assignment) - BOLT_GEN_CHILD_CASE(Identifier) - BOLT_GEN_CHILD_CASE(IdentifierAlt) - BOLT_GEN_CHILD_CASE(WrappedOperator) - BOLT_GEN_CHILD_CASE(StringLiteral) - BOLT_GEN_CHILD_CASE(IntegerLiteral) - BOLT_GEN_CHILD_CASE(ExpressionAnnotation) - BOLT_GEN_CHILD_CASE(TypeAssertAnnotation) - BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression) - BOLT_GEN_CHILD_CASE(EqualityConstraintExpression) - BOLT_GEN_CHILD_CASE(RecordTypeExpressionField) - BOLT_GEN_CHILD_CASE(RecordTypeExpression) - BOLT_GEN_CHILD_CASE(QualifiedTypeExpression) - BOLT_GEN_CHILD_CASE(ReferenceTypeExpression) - BOLT_GEN_CHILD_CASE(ArrowTypeExpression) - BOLT_GEN_CHILD_CASE(AppTypeExpression) - BOLT_GEN_CHILD_CASE(VarTypeExpression) - BOLT_GEN_CHILD_CASE(NestedTypeExpression) - BOLT_GEN_CHILD_CASE(TupleTypeExpression) - BOLT_GEN_CHILD_CASE(BindPattern) - BOLT_GEN_CHILD_CASE(LiteralPattern) - BOLT_GEN_CHILD_CASE(RecordPatternField) - BOLT_GEN_CHILD_CASE(RecordPattern) - BOLT_GEN_CHILD_CASE(NamedRecordPattern) - BOLT_GEN_CHILD_CASE(NamedTuplePattern) - BOLT_GEN_CHILD_CASE(TuplePattern) - BOLT_GEN_CHILD_CASE(NestedPattern) - BOLT_GEN_CHILD_CASE(ListPattern) - BOLT_GEN_CHILD_CASE(ReferenceExpression) - BOLT_GEN_CHILD_CASE(MatchCase) - BOLT_GEN_CHILD_CASE(MatchExpression) - BOLT_GEN_CHILD_CASE(MemberExpression) - BOLT_GEN_CHILD_CASE(TupleExpression) - BOLT_GEN_CHILD_CASE(NestedExpression) - BOLT_GEN_CHILD_CASE(LiteralExpression) - BOLT_GEN_CHILD_CASE(CallExpression) - BOLT_GEN_CHILD_CASE(InfixExpression) - BOLT_GEN_CHILD_CASE(PrefixExpression) - BOLT_GEN_CHILD_CASE(RecordExpressionField) - BOLT_GEN_CHILD_CASE(RecordExpression) - BOLT_GEN_CHILD_CASE(ExpressionStatement) - BOLT_GEN_CHILD_CASE(ReturnStatement) - BOLT_GEN_CHILD_CASE(IfStatement) - BOLT_GEN_CHILD_CASE(IfStatementPart) - BOLT_GEN_CHILD_CASE(TypeAssert) - BOLT_GEN_CHILD_CASE(Parameter) - BOLT_GEN_CHILD_CASE(LetBlockBody) - BOLT_GEN_CHILD_CASE(LetExprBody) - BOLT_GEN_CHILD_CASE(LetDeclaration) - BOLT_GEN_CHILD_CASE(RecordDeclaration) - BOLT_GEN_CHILD_CASE(RecordDeclarationField) - BOLT_GEN_CHILD_CASE(VariantDeclaration) - BOLT_GEN_CHILD_CASE(TupleVariantDeclarationMember) - BOLT_GEN_CHILD_CASE(RecordVariantDeclarationMember) - BOLT_GEN_CHILD_CASE(ClassDeclaration) - BOLT_GEN_CHILD_CASE(InstanceDeclaration) - BOLT_GEN_CHILD_CASE(SourceFile) - } + switch (N->getKind()) { + BOLT_GEN_CHILD_CASE(VBar) + BOLT_GEN_CHILD_CASE(Equals) + BOLT_GEN_CHILD_CASE(Colon) + BOLT_GEN_CHILD_CASE(Comma) + BOLT_GEN_CHILD_CASE(Dot) + BOLT_GEN_CHILD_CASE(DotDot) + BOLT_GEN_CHILD_CASE(Tilde) + BOLT_GEN_CHILD_CASE(At) + BOLT_GEN_CHILD_CASE(LParen) + BOLT_GEN_CHILD_CASE(RParen) + BOLT_GEN_CHILD_CASE(LBracket) + BOLT_GEN_CHILD_CASE(RBracket) + BOLT_GEN_CHILD_CASE(LBrace) + BOLT_GEN_CHILD_CASE(RBrace) + BOLT_GEN_CHILD_CASE(RArrow) + BOLT_GEN_CHILD_CASE(RArrowAlt) + BOLT_GEN_CHILD_CASE(LetKeyword) + BOLT_GEN_CHILD_CASE(ForeignKeyword) + BOLT_GEN_CHILD_CASE(MutKeyword) + BOLT_GEN_CHILD_CASE(PubKeyword) + BOLT_GEN_CHILD_CASE(TypeKeyword) + BOLT_GEN_CHILD_CASE(ReturnKeyword) + BOLT_GEN_CHILD_CASE(ModKeyword) + BOLT_GEN_CHILD_CASE(StructKeyword) + BOLT_GEN_CHILD_CASE(EnumKeyword) + BOLT_GEN_CHILD_CASE(ClassKeyword) + BOLT_GEN_CHILD_CASE(InstanceKeyword) + BOLT_GEN_CHILD_CASE(ElifKeyword) + BOLT_GEN_CHILD_CASE(IfKeyword) + BOLT_GEN_CHILD_CASE(ElseKeyword) + BOLT_GEN_CHILD_CASE(MatchKeyword) + BOLT_GEN_CHILD_CASE(Invalid) + BOLT_GEN_CHILD_CASE(EndOfFile) + BOLT_GEN_CHILD_CASE(BlockStart) + BOLT_GEN_CHILD_CASE(BlockEnd) + BOLT_GEN_CHILD_CASE(LineFoldEnd) + BOLT_GEN_CHILD_CASE(CustomOperator) + BOLT_GEN_CHILD_CASE(Assignment) + BOLT_GEN_CHILD_CASE(Identifier) + BOLT_GEN_CHILD_CASE(IdentifierAlt) + BOLT_GEN_CHILD_CASE(WrappedOperator) + BOLT_GEN_CHILD_CASE(StringLiteral) + BOLT_GEN_CHILD_CASE(IntegerLiteral) + BOLT_GEN_CHILD_CASE(ExpressionAnnotation) + BOLT_GEN_CHILD_CASE(TypeAssertAnnotation) + BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression) + BOLT_GEN_CHILD_CASE(EqualityConstraintExpression) + BOLT_GEN_CHILD_CASE(RecordTypeExpressionField) + BOLT_GEN_CHILD_CASE(RecordTypeExpression) + BOLT_GEN_CHILD_CASE(QualifiedTypeExpression) + BOLT_GEN_CHILD_CASE(ReferenceTypeExpression) + BOLT_GEN_CHILD_CASE(ArrowTypeExpression) + BOLT_GEN_CHILD_CASE(AppTypeExpression) + BOLT_GEN_CHILD_CASE(VarTypeExpression) + BOLT_GEN_CHILD_CASE(NestedTypeExpression) + BOLT_GEN_CHILD_CASE(TupleTypeExpression) + BOLT_GEN_CHILD_CASE(BindPattern) + BOLT_GEN_CHILD_CASE(LiteralPattern) + BOLT_GEN_CHILD_CASE(RecordPatternField) + BOLT_GEN_CHILD_CASE(RecordPattern) + BOLT_GEN_CHILD_CASE(NamedRecordPattern) + BOLT_GEN_CHILD_CASE(NamedTuplePattern) + BOLT_GEN_CHILD_CASE(TuplePattern) + BOLT_GEN_CHILD_CASE(NestedPattern) + BOLT_GEN_CHILD_CASE(ListPattern) + BOLT_GEN_CHILD_CASE(ReferenceExpression) + BOLT_GEN_CHILD_CASE(MatchCase) + BOLT_GEN_CHILD_CASE(MatchExpression) + BOLT_GEN_CHILD_CASE(MemberExpression) + BOLT_GEN_CHILD_CASE(TupleExpression) + BOLT_GEN_CHILD_CASE(NestedExpression) + BOLT_GEN_CHILD_CASE(LiteralExpression) + BOLT_GEN_CHILD_CASE(CallExpression) + BOLT_GEN_CHILD_CASE(InfixExpression) + BOLT_GEN_CHILD_CASE(PrefixExpression) + BOLT_GEN_CHILD_CASE(RecordExpressionField) + BOLT_GEN_CHILD_CASE(RecordExpression) + BOLT_GEN_CHILD_CASE(ExpressionStatement) + BOLT_GEN_CHILD_CASE(ReturnStatement) + BOLT_GEN_CHILD_CASE(IfStatement) + BOLT_GEN_CHILD_CASE(IfStatementPart) + BOLT_GEN_CHILD_CASE(TypeAssert) + BOLT_GEN_CHILD_CASE(Parameter) + BOLT_GEN_CHILD_CASE(LetBlockBody) + BOLT_GEN_CHILD_CASE(LetExprBody) + BOLT_GEN_CHILD_CASE(LetDeclaration) + BOLT_GEN_CHILD_CASE(RecordDeclaration) + BOLT_GEN_CHILD_CASE(RecordDeclarationField) + BOLT_GEN_CHILD_CASE(VariantDeclaration) + BOLT_GEN_CHILD_CASE(TupleVariantDeclarationMember) + BOLT_GEN_CHILD_CASE(RecordVariantDeclarationMember) + BOLT_GEN_CHILD_CASE(ClassDeclaration) + BOLT_GEN_CHILD_CASE(InstanceDeclaration) + BOLT_GEN_CHILD_CASE(SourceFile) } + } #define BOLT_VISIT(node) static_cast(this)->visit(node) - void visitEachChild(VBar* N) { - } + void visitEachChild(VBar* N) { + } - void visitEachChild(Equals* N) { - } + void visitEachChild(Equals* N) { + } - void visitEachChild(Colon* N) { - } + void visitEachChild(Colon* N) { + } - void visitEachChild(Comma* N) { - } + void visitEachChild(Comma* N) { + } - void visitEachChild(Dot* N) { - } + void visitEachChild(Dot* N) { + } - void visitEachChild(DotDot* N) { - } + void visitEachChild(DotDot* N) { + } - void visitEachChild(Tilde* N) { - } + void visitEachChild(Tilde* N) { + } - void visitEachChild(At* N) { - } + void visitEachChild(At* N) { + } - void visitEachChild(LParen* N) { - } + void visitEachChild(LParen* N) { + } - void visitEachChild(RParen* N) { - } + void visitEachChild(RParen* N) { + } - void visitEachChild(LBracket* N) { - } + void visitEachChild(LBracket* N) { + } - void visitEachChild(RBracket* N) { - } + void visitEachChild(RBracket* N) { + } - void visitEachChild(LBrace* N) { - } + void visitEachChild(LBrace* N) { + } - void visitEachChild(RBrace* N) { - } + void visitEachChild(RBrace* N) { + } - void visitEachChild(RArrow* N) { - } + void visitEachChild(RArrow* N) { + } - void visitEachChild(RArrowAlt* N) { - } + void visitEachChild(RArrowAlt* N) { + } - void visitEachChild(LetKeyword* N) { - } + void visitEachChild(LetKeyword* N) { + } - void visitEachChild(ForeignKeyword* N) { - } + void visitEachChild(ForeignKeyword* N) { + } - void visitEachChild(MutKeyword* N) { - } + void visitEachChild(MutKeyword* N) { + } - void visitEachChild(PubKeyword* N) { - } + void visitEachChild(PubKeyword* N) { + } - void visitEachChild(TypeKeyword* N) { - } + void visitEachChild(TypeKeyword* N) { + } - void visitEachChild(ReturnKeyword* N) { - } + void visitEachChild(ReturnKeyword* N) { + } - void visitEachChild(ModKeyword* N) { - } + void visitEachChild(ModKeyword* N) { + } - void visitEachChild(StructKeyword* N) { - } + void visitEachChild(StructKeyword* N) { + } - void visitEachChild(EnumKeyword* N) { - } + void visitEachChild(EnumKeyword* N) { + } - void visitEachChild(ClassKeyword* N) { - } + void visitEachChild(ClassKeyword* N) { + } - void visitEachChild(InstanceKeyword* N) { - } + void visitEachChild(InstanceKeyword* N) { + } - void visitEachChild(ElifKeyword* N) { - } + void visitEachChild(ElifKeyword* N) { + } - void visitEachChild(IfKeyword* N) { - } + void visitEachChild(IfKeyword* N) { + } - void visitEachChild(ElseKeyword* N) { - } + void visitEachChild(ElseKeyword* N) { + } - void visitEachChild(MatchKeyword* N) { - } + void visitEachChild(MatchKeyword* N) { + } - void visitEachChild(Invalid* N) { - } + void visitEachChild(Invalid* N) { + } - void visitEachChild(EndOfFile* N) { - } + void visitEachChild(EndOfFile* N) { + } - void visitEachChild(BlockStart* N) { - } + void visitEachChild(BlockStart* N) { + } - void visitEachChild(BlockEnd* N) { - } + void visitEachChild(BlockEnd* N) { + } - void visitEachChild(LineFoldEnd* N) { - } + void visitEachChild(LineFoldEnd* N) { + } - void visitEachChild(CustomOperator* N) { - } + void visitEachChild(CustomOperator* N) { + } - void visitEachChild(Assignment* N) { - } + void visitEachChild(Assignment* N) { + } - void visitEachChild(Identifier* N) { - } + void visitEachChild(Identifier* N) { + } - void visitEachChild(IdentifierAlt* N) { - } + void visitEachChild(IdentifierAlt* N) { + } - void visitEachChild(StringLiteral* N) { - } + void visitEachChild(StringLiteral* N) { + } - void visitEachChild(IntegerLiteral* N) { - } + void visitEachChild(IntegerLiteral* N) { + } - void visitEachChild(WrappedOperator* N) { - BOLT_VISIT(N->LParen); - BOLT_VISIT(N->Op); - BOLT_VISIT(N->RParen); - } + void visitEachChild(WrappedOperator* N) { + BOLT_VISIT(N->LParen); + BOLT_VISIT(N->Op); + BOLT_VISIT(N->RParen); + } - void visitEachChild(ExpressionAnnotation* N) { - BOLT_VISIT(N->At); - BOLT_VISIT(N->Expression); - } + void visitEachChild(ExpressionAnnotation* N) { + BOLT_VISIT(N->At); + BOLT_VISIT(N->Expression); + } - void visitEachChild(TypeAssertAnnotation* N) { - BOLT_VISIT(N->At); - BOLT_VISIT(N->Colon); - BOLT_VISIT(N->TE); - } + void visitEachChild(TypeAssertAnnotation* N) { + BOLT_VISIT(N->At); + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TE); + } - void visitEachChild(TypeclassConstraintExpression* N) { + void visitEachChild(TypeclassConstraintExpression* N) { + BOLT_VISIT(N->Name); + for (auto TE: N->TEs) { + BOLT_VISIT(TE); + } + } + + void visitEachChild(EqualityConstraintExpression* N) { + BOLT_VISIT(N->Left); + BOLT_VISIT(N->Tilde); + BOLT_VISIT(N->Right); + } + + void visitEachChild(RecordTypeExpressionField* N) { + BOLT_VISIT(N->Name); + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TE); + } + + void visitEachChild(RecordTypeExpression* N) { + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); + } + } + if (N->VBar) { + BOLT_VISIT(N->VBar); + } + if (N->Rest) { + BOLT_VISIT(N->Rest); + } + BOLT_VISIT(N->RBrace); + } + + void visitEachChild(QualifiedTypeExpression* N) { + for (auto [CE, Comma]: N->Constraints) { + BOLT_VISIT(CE); + if (Comma) { + BOLT_VISIT(Comma); + } + } + BOLT_VISIT(N->RArrowAlt); + BOLT_VISIT(N->TE); + } + + void visitEachChild(ReferenceTypeExpression* N) { + for (auto [Name, Dot]: N->ModulePath) { + BOLT_VISIT(Name); + BOLT_VISIT(Dot); + } + BOLT_VISIT(N->Name); + } + + void visitEachChild(ArrowTypeExpression* N) { + for (auto PT: N->ParamTypes) { + BOLT_VISIT(PT); + } + BOLT_VISIT(N->ReturnType); + } + + void visitEachChild(AppTypeExpression* N) { + BOLT_VISIT(N->Op); + for (auto Arg: N->Args) { + BOLT_VISIT(Arg); + } + } + + void visitEachChild(VarTypeExpression* N) { + BOLT_VISIT(N->Name); + } + + void visitEachChild(NestedTypeExpression* N) { + BOLT_VISIT(N->LParen); + BOLT_VISIT(N->TE); + BOLT_VISIT(N->RParen); + } + + void visitEachChild(TupleTypeExpression* N) { + BOLT_VISIT(N->LParen); + for (auto [TE, Comma]: N->Elements) { + if (Comma) { + BOLT_VISIT(Comma); + } + BOLT_VISIT(TE); + } + BOLT_VISIT(N->RParen); + } + + void visitEachChild(BindPattern* N) { + BOLT_VISIT(N->Name); + } + + void visitEachChild(LiteralPattern* N) { + BOLT_VISIT(N->Literal); + } + + void visitEachChild(RecordPatternField* N) { + if (N->DotDot) { + BOLT_VISIT(N->DotDot); + } + if (N->Name) { BOLT_VISIT(N->Name); - for (auto TE: N->TEs) { - BOLT_VISIT(TE); + } + if (N->Equals) { + BOLT_VISIT(N->Equals); + } + if (N->Pattern) { + BOLT_VISIT(N->Pattern); + } + } + + void visitEachChild(RecordPattern* N) { + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); } } + BOLT_VISIT(N->RBrace); + } - void visitEachChild(EqualityConstraintExpression* N) { - BOLT_VISIT(N->Left); - BOLT_VISIT(N->Tilde); - BOLT_VISIT(N->Right); - } - - void visitEachChild(RecordTypeExpressionField* N) { - BOLT_VISIT(N->Name); - BOLT_VISIT(N->Colon); - BOLT_VISIT(N->TE); - } - - void visitEachChild(RecordTypeExpression* N) { - BOLT_VISIT(N->LBrace); - for (auto [Field, Comma]: N->Fields) { - BOLT_VISIT(Field); - if (Comma) { - BOLT_VISIT(Comma); - } - } - if (N->VBar) { - BOLT_VISIT(N->VBar); - } - if (N->Rest) { - BOLT_VISIT(N->Rest); - } - BOLT_VISIT(N->RBrace); - } - - void visitEachChild(QualifiedTypeExpression* N) { - for (auto [CE, Comma]: N->Constraints) { - BOLT_VISIT(CE); - if (Comma) { - BOLT_VISIT(Comma); - } - } - BOLT_VISIT(N->RArrowAlt); - BOLT_VISIT(N->TE); - } - - void visitEachChild(ReferenceTypeExpression* N) { - for (auto [Name, Dot]: N->ModulePath) { - BOLT_VISIT(Name); + void visitEachChild(NamedRecordPattern* N) { + for (auto [Name, Dot]: N->ModulePath) { + BOLT_VISIT(Name); + if (Dot) { BOLT_VISIT(Dot); } - BOLT_VISIT(N->Name); } - - void visitEachChild(ArrowTypeExpression* N) { - for (auto PT: N->ParamTypes) { - BOLT_VISIT(PT); - } - BOLT_VISIT(N->ReturnType); - } - - void visitEachChild(AppTypeExpression* N) { - BOLT_VISIT(N->Op); - for (auto Arg: N->Args) { - BOLT_VISIT(Arg); + BOLT_VISIT(N->Name); + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); } } + BOLT_VISIT(N->LBrace); + BOLT_VISIT(N->RBrace); + } - void visitEachChild(VarTypeExpression* N) { - BOLT_VISIT(N->Name); + void visitEachChild(NamedTuplePattern* N) { + BOLT_VISIT(N->Name); + for (auto P: N->Patterns) { + BOLT_VISIT(P); } + } - void visitEachChild(NestedTypeExpression* N) { - BOLT_VISIT(N->LParen); - BOLT_VISIT(N->TE); - BOLT_VISIT(N->RParen); - } - - void visitEachChild(TupleTypeExpression* N) { - BOLT_VISIT(N->LParen); - for (auto [TE, Comma]: N->Elements) { - if (Comma) { - BOLT_VISIT(Comma); - } - BOLT_VISIT(TE); - } - BOLT_VISIT(N->RParen); - } - - void visitEachChild(BindPattern* N) { - BOLT_VISIT(N->Name); - } - - void visitEachChild(LiteralPattern* N) { - BOLT_VISIT(N->Literal); - } - - void visitEachChild(RecordPatternField* N) { - if (N->DotDot) { - BOLT_VISIT(N->DotDot); - } - if (N->Name) { - BOLT_VISIT(N->Name); - } - if (N->Equals) { - BOLT_VISIT(N->Equals); - } - if (N->Pattern) { - BOLT_VISIT(N->Pattern); + void visitEachChild(TuplePattern* N) { + BOLT_VISIT(N->LParen); + for (auto [P, Comma]: N->Elements) { + BOLT_VISIT(P); + if (Comma) { + BOLT_VISIT(Comma); } } + BOLT_VISIT(N->RParen); + } - void visitEachChild(RecordPattern* N) { - BOLT_VISIT(N->LBrace); - for (auto [Field, Comma]: N->Fields) { - BOLT_VISIT(Field); - if (Comma) { - BOLT_VISIT(Comma); - } - } - BOLT_VISIT(N->RBrace); - } + void visitEachChild(NestedPattern* N) { + BOLT_VISIT(N->LParen); + BOLT_VISIT(N->P); + BOLT_VISIT(N->RParen); + } - void visitEachChild(NamedRecordPattern* N) { - for (auto [Name, Dot]: N->ModulePath) { - BOLT_VISIT(Name); - if (Dot) { - BOLT_VISIT(Dot); - } - } - BOLT_VISIT(N->Name); - BOLT_VISIT(N->LBrace); - for (auto [Field, Comma]: N->Fields) { - BOLT_VISIT(Field); - if (Comma) { - BOLT_VISIT(Comma); - } - } - BOLT_VISIT(N->LBrace); - BOLT_VISIT(N->RBrace); - } - - void visitEachChild(NamedTuplePattern* N) { - BOLT_VISIT(N->Name); - for (auto P: N->Patterns) { - BOLT_VISIT(P); + void visitEachChild(ListPattern* N) { + BOLT_VISIT(N->LBracket); + for (auto [Element, Separator]: N->Elements) { + BOLT_VISIT(Element); + if (Separator) { + BOLT_VISIT(Separator); } } + BOLT_VISIT(N->RBracket); + } - void visitEachChild(TuplePattern* N) { - BOLT_VISIT(N->LParen); - for (auto [P, Comma]: N->Elements) { - BOLT_VISIT(P); - if (Comma) { - BOLT_VISIT(Comma); - } - } - BOLT_VISIT(N->RParen); + void visitEachChild(ReferenceExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } - - void visitEachChild(NestedPattern* N) { - BOLT_VISIT(N->LParen); - BOLT_VISIT(N->P); - BOLT_VISIT(N->RParen); + for (auto [Name, Dot]: N->ModulePath) { + BOLT_VISIT(Name); + BOLT_VISIT(Dot); } + BOLT_VISIT(N->Name); + } - void visitEachChild(ListPattern* N) { - BOLT_VISIT(N->LBracket); - for (auto [Element, Separator]: N->Elements) { - BOLT_VISIT(Element); - if (Separator) { - BOLT_VISIT(Separator); - } - } - BOLT_VISIT(N->RBracket); + void visitEachChild(MatchCase* N) { + BOLT_VISIT(N->Pattern); + BOLT_VISIT(N->RArrowAlt); + BOLT_VISIT(N->Expression); + } + + void visitEachChild(MatchExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } - - void visitEachChild(ReferenceExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - for (auto [Name, Dot]: N->ModulePath) { - BOLT_VISIT(Name); - BOLT_VISIT(Dot); - } - BOLT_VISIT(N->Name); + BOLT_VISIT(N->MatchKeyword); + if (N->Value) { + BOLT_VISIT(N->Value); } - - void visitEachChild(MatchCase* N) { - BOLT_VISIT(N->Pattern); - BOLT_VISIT(N->RArrowAlt); - BOLT_VISIT(N->Expression); + BOLT_VISIT(N->BlockStart); + for (auto Case: N->Cases) { + BOLT_VISIT(Case); } + } - void visitEachChild(MatchExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->MatchKeyword); - if (N->Value) { - BOLT_VISIT(N->Value); - } - BOLT_VISIT(N->BlockStart); - for (auto Case: N->Cases) { - BOLT_VISIT(Case); + void visitEachChild(MemberExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + BOLT_VISIT(N->getExpression()); + BOLT_VISIT(N->Dot); + BOLT_VISIT(N->Name); + } + + void visitEachChild(TupleExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + BOLT_VISIT(N->LParen); + for (auto [E, Comma]: N->Elements) { + BOLT_VISIT(E); + if (Comma) { + BOLT_VISIT(Comma); } } + BOLT_VISIT(N->RParen); + } - void visitEachChild(MemberExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->getExpression()); - BOLT_VISIT(N->Dot); - BOLT_VISIT(N->Name); + void visitEachChild(NestedExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } + BOLT_VISIT(N->LParen); + BOLT_VISIT(N->Inner); + BOLT_VISIT(N->RParen); + } - void visitEachChild(TupleExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->LParen); - for (auto [E, Comma]: N->Elements) { - BOLT_VISIT(E); - if (Comma) { - BOLT_VISIT(Comma); - } - } - BOLT_VISIT(N->RParen); + void visitEachChild(LiteralExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } + BOLT_VISIT(N->Token); + } - void visitEachChild(NestedExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->LParen); - BOLT_VISIT(N->Inner); - BOLT_VISIT(N->RParen); + void visitEachChild(CallExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } - - void visitEachChild(LiteralExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->Token); + BOLT_VISIT(N->Function); + for (auto Arg: N->Args) { + BOLT_VISIT(Arg); } + } - void visitEachChild(CallExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->Function); - for (auto Arg: N->Args) { - BOLT_VISIT(Arg); + void visitEachChild(InfixExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + BOLT_VISIT(N->Left); + BOLT_VISIT(N->Operator); + BOLT_VISIT(N->Right); + } + + void visitEachChild(PrefixExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + BOLT_VISIT(N->Operator); + BOLT_VISIT(N->Argument); + } + + void visitEachChild(RecordExpressionField* N) { + BOLT_VISIT(N->Name); + BOLT_VISIT(N->Equals); + BOLT_VISIT(N->E); + } + + void visitEachChild(RecordExpression* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); } } + BOLT_VISIT(N->RBrace); + } - void visitEachChild(InfixExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->Left); - BOLT_VISIT(N->Operator); - BOLT_VISIT(N->Right); + void visitEachChild(ExpressionStatement* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } + BOLT_VISIT(N->Expression); + } - void visitEachChild(PrefixExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->Operator); - BOLT_VISIT(N->Argument); + void visitEachChild(ReturnStatement* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } + BOLT_VISIT(N->ReturnKeyword); + BOLT_VISIT(N->Expression); + } - void visitEachChild(RecordExpressionField* N) { - BOLT_VISIT(N->Name); - BOLT_VISIT(N->Equals); - BOLT_VISIT(N->E); + void visitEachChild(IfStatement* N) { + for (auto Part: N->Parts) { + BOLT_VISIT(Part); } + } - void visitEachChild(RecordExpression* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->LBrace); - for (auto [Field, Comma]: N->Fields) { - BOLT_VISIT(Field); - if (Comma) { - BOLT_VISIT(Comma); - } - } - BOLT_VISIT(N->RBrace); + void visitEachChild(IfStatementPart* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } - - void visitEachChild(ExpressionStatement* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->Expression); + BOLT_VISIT(N->Keyword); + if (N->Test != nullptr) { + BOLT_VISIT(N->Test); } - - void visitEachChild(ReturnStatement* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->ReturnKeyword); - BOLT_VISIT(N->Expression); + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); } + } - void visitEachChild(IfStatement* N) { - for (auto Part: N->Parts) { - BOLT_VISIT(Part); - } + void visitEachChild(TypeAssert* N) { + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TypeExpression); + } + + void visitEachChild(Parameter* N) { + BOLT_VISIT(N->Pattern); + if (N->TypeAssert != nullptr) { + BOLT_VISIT(N->TypeAssert); } + } - void visitEachChild(IfStatementPart* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - BOLT_VISIT(N->Keyword); - if (N->Test != nullptr) { - BOLT_VISIT(N->Test); - } - BOLT_VISIT(N->BlockStart); - for (auto Element: N->Elements) { - BOLT_VISIT(Element); - } + void visitEachChild(LetBlockBody* N) { + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); } + } - void visitEachChild(TypeAssert* N) { - BOLT_VISIT(N->Colon); - BOLT_VISIT(N->TypeExpression); + void visitEachChild(LetExprBody* N) { + BOLT_VISIT(N->Equals); + BOLT_VISIT(N->Expression); + } + + void visitEachChild(LetDeclaration* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); } - - void visitEachChild(Parameter* N) { - BOLT_VISIT(N->Pattern); - if (N->TypeAssert != nullptr) { - BOLT_VISIT(N->TypeAssert); - } + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); } - - void visitEachChild(LetBlockBody* N) { - BOLT_VISIT(N->BlockStart); - for (auto Element: N->Elements) { - BOLT_VISIT(Element); - } + if (N->ForeignKeyword) { + BOLT_VISIT(N->ForeignKeyword); } - - void visitEachChild(LetExprBody* N) { - BOLT_VISIT(N->Equals); - BOLT_VISIT(N->Expression); + BOLT_VISIT(N->LetKeyword); + BOLT_VISIT(N->Pattern); + for (auto Param: N->Params) { + BOLT_VISIT(Param); } - - void visitEachChild(LetDeclaration* N) { - for (auto A: N->Annotations) { - BOLT_VISIT(A); - } - if (N->PubKeyword) { - BOLT_VISIT(N->PubKeyword); - } - if (N->ForeignKeyword) { - BOLT_VISIT(N->ForeignKeyword); - } - BOLT_VISIT(N->LetKeyword); - BOLT_VISIT(N->Pattern); - for (auto Param: N->Params) { - BOLT_VISIT(Param); - } - if (N->TypeAssert) { - BOLT_VISIT(N->TypeAssert); - } - if (N->Body) { - BOLT_VISIT(N->Body); - } + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); } - - void visitEachChild(RecordDeclarationField* N) { - BOLT_VISIT(N->Name); - BOLT_VISIT(N->Colon); - BOLT_VISIT(N->TypeExpression); + if (N->Body) { + BOLT_VISIT(N->Body); } + } - void visitEachChild(RecordDeclaration* N) { - if (N->PubKeyword) { - BOLT_VISIT(N->PubKeyword); - } - BOLT_VISIT(N->StructKeyword); - BOLT_VISIT(N->Name); - BOLT_VISIT(N->StructKeyword); - for (auto Field: N->Fields) { - BOLT_VISIT(Field); - } + void visitEachChild(RecordDeclarationField* N) { + BOLT_VISIT(N->Name); + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TypeExpression); + } + + void visitEachChild(RecordDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); } - - void visitEachChild(VariantDeclaration* N) { - if (N->PubKeyword) { - BOLT_VISIT(N->PubKeyword); - } - BOLT_VISIT(N->EnumKeyword); - BOLT_VISIT(N->Name); - for (auto TV: N->TVs) { - BOLT_VISIT(TV); - } - BOLT_VISIT(N->BlockStart); - for (auto Member: N->Members) { - BOLT_VISIT(Member); - } + BOLT_VISIT(N->StructKeyword); + BOLT_VISIT(N->Name); + BOLT_VISIT(N->StructKeyword); + for (auto Field: N->Fields) { + BOLT_VISIT(Field); } + } - void visitEachChild(TupleVariantDeclarationMember* N) { - BOLT_VISIT(N->Name); - for (auto Element: N->Elements) { - BOLT_VISIT(Element); - } + void visitEachChild(VariantDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); } - - void visitEachChild(RecordVariantDeclarationMember* N) { - BOLT_VISIT(N->Name); - BOLT_VISIT(N->BlockStart); - for (auto Field: N->Fields) { - BOLT_VISIT(Field); - } + BOLT_VISIT(N->EnumKeyword); + BOLT_VISIT(N->Name); + for (auto TV: N->TVs) { + BOLT_VISIT(TV); } - - void visitEachChild(ClassDeclaration* N) { - if (N->PubKeyword) { - BOLT_VISIT(N->PubKeyword); - } - BOLT_VISIT(N->ClassKeyword); - BOLT_VISIT(N->Name); - for (auto Name: N->TypeVars) { - BOLT_VISIT(Name); - } - BOLT_VISIT(N->BlockStart); - for (auto Element: N->Elements) { - BOLT_VISIT(Element); - } + BOLT_VISIT(N->BlockStart); + for (auto Member: N->Members) { + BOLT_VISIT(Member); } + } - void visitEachChild(InstanceDeclaration* N) { - BOLT_VISIT(N->InstanceKeyword); - BOLT_VISIT(N->Name); - for (auto TE: N->TypeExps) { - BOLT_VISIT(TE); - } - BOLT_VISIT(N->BlockStart); - for (auto Element: N->Elements) { - BOLT_VISIT(Element); - } + void visitEachChild(TupleVariantDeclarationMember* N) { + BOLT_VISIT(N->Name); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); } + } - void visitEachChild(SourceFile* N) { - for (auto Element: N->Elements) { - BOLT_VISIT(Element); - } + void visitEachChild(RecordVariantDeclarationMember* N) { + BOLT_VISIT(N->Name); + BOLT_VISIT(N->BlockStart); + for (auto Field: N->Fields) { + BOLT_VISIT(Field); } + } - }; + void visitEachChild(ClassDeclaration* N) { + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + BOLT_VISIT(N->ClassKeyword); + BOLT_VISIT(N->Name); + for (auto Name: N->TypeVars) { + BOLT_VISIT(Name); + } + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + void visitEachChild(InstanceDeclaration* N) { + BOLT_VISIT(N->InstanceKeyword); + BOLT_VISIT(N->Name); + for (auto TE: N->TypeExps) { + BOLT_VISIT(TE); + } + BOLT_VISIT(N->BlockStart); + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + + void visitEachChild(SourceFile* N) { + for (auto Element: N->Elements) { + BOLT_VISIT(Element); + } + } + +}; } diff --git a/bootstrap/cxx/include/bolt/Checker.hpp b/bootstrap/cxx/include/bolt/Checker.hpp index 985e8e6f5..279a78034 100644 --- a/bootstrap/cxx/include/bolt/Checker.hpp +++ b/bootstrap/cxx/include/bolt/Checker.hpp @@ -16,328 +16,328 @@ namespace bolt { - std::string describe(const Type* Ty); // For debugging only +std::string describe(const Type* Ty); // For debugging only - enum class SymKind { - Type, - Var, - }; +enum class SymKind { + Type, + Var, +}; - class DiagnosticEngine; +class DiagnosticEngine; - class Constraint; +class Constraint; - using ConstraintSet = std::vector; +using ConstraintSet = std::vector; - enum class SchemeKind : unsigned char { - Forall, - }; +enum class SchemeKind : unsigned char { + Forall, +}; - class Scheme { +class Scheme { - const SchemeKind Kind; + const SchemeKind Kind; - protected: +protected: - inline Scheme(SchemeKind Kind): - Kind(Kind) {} + inline Scheme(SchemeKind Kind): + Kind(Kind) {} - public: +public: - inline SchemeKind getKind() const noexcept { - return Kind; + inline SchemeKind getKind() const noexcept { + return Kind; + } + + virtual ~Scheme() {} + +}; + +class Forall : public Scheme { +public: + + TVSet* TVs; + ConstraintSet* Constraints; + class Type* Type; + + inline Forall(class Type* Type): + Scheme(SchemeKind::Forall), TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {} + + inline Forall( + TVSet* TVs, + ConstraintSet* Constraints, + class Type* Type + ): Scheme(SchemeKind::Forall), + TVs(TVs), + Constraints(Constraints), + Type(Type) {} + + static bool classof(const Scheme* Scm) { + return Scm->getKind() == SchemeKind::Forall; + } + +}; + +class TypeEnv { + + std::unordered_map, Scheme*> Mapping; + +public: + + Scheme* lookup(ByteString Name, SymKind Kind) { + auto Key = std::make_tuple(Name, Kind); + auto Match = Mapping.find(Key); + if (Match == Mapping.end()) { + return nullptr; } + return Match->second; + } - virtual ~Scheme() {} + void add(ByteString Name, Scheme* Scm, SymKind Kind) { + auto Key = std::make_tuple(Name, Kind); + ZEN_ASSERT(!Mapping.count(Key)) + // auto F = static_cast(Scm); + // std::cerr << Name << " : forall "; + // for (auto TV: *F->TVs) { + // std::cerr << describe(TV) << " "; + // } + // std::cerr << ". " << describe(F->Type) << "\n"; + Mapping.emplace(Key, Scm); + } - }; +}; - class Forall : public Scheme { - public: - TVSet* TVs; - ConstraintSet* Constraints; - class Type* Type; +enum class ConstraintKind { + Equal, + Field, + Many, + Empty, +}; - inline Forall(class Type* Type): - Scheme(SchemeKind::Forall), TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {} +class Constraint { - inline Forall( - TVSet* TVs, - ConstraintSet* Constraints, - class Type* Type - ): Scheme(SchemeKind::Forall), - TVs(TVs), - Constraints(Constraints), - Type(Type) {} + const ConstraintKind Kind; - static bool classof(const Scheme* Scm) { - return Scm->getKind() == SchemeKind::Forall; - } +public: - }; + inline Constraint(ConstraintKind Kind): + Kind(Kind) {} - class TypeEnv { + inline ConstraintKind getKind() const noexcept { + return Kind; + } - std::unordered_map, Scheme*> Mapping; + Constraint* substitute(const TVSub& Sub); - public: + virtual ~Constraint() {} - Scheme* lookup(ByteString Name, SymKind Kind) { - auto Key = std::make_tuple(Name, Kind); - auto Match = Mapping.find(Key); - if (Match == Mapping.end()) { - return nullptr; - } - return Match->second; - } +}; - void add(ByteString Name, Scheme* Scm, SymKind Kind) { - auto Key = std::make_tuple(Name, Kind); - ZEN_ASSERT(!Mapping.count(Key)) - // auto F = static_cast(Scm); - // std::cerr << Name << " : forall "; - // for (auto TV: *F->TVs) { - // std::cerr << describe(TV) << " "; - // } - // std::cerr << ". " << describe(F->Type) << "\n"; - Mapping.emplace(Key, Scm); - } +class CEqual : public Constraint { +public: - }; + Type* Left; + Type* Right; + Node* Source; + inline CEqual(Type* Left, Type* Right, Node* Source = nullptr): + Constraint(ConstraintKind::Equal), Left(Left), Right(Right), Source(Source) {} - enum class ConstraintKind { - Equal, - Field, - Many, - Empty, - }; +}; - class Constraint { +class CField : public Constraint { +public: - const ConstraintKind Kind; + Type* TupleTy; + size_t I; + Type* FieldTy; + Node* Source; - public: + inline CField(Type* TupleTy, size_t I, Type* FieldTy, Node* Source = nullptr): + Constraint(ConstraintKind::Field), TupleTy(TupleTy), I(I), FieldTy(FieldTy), Source(Source) {} - inline Constraint(ConstraintKind Kind): - Kind(Kind) {} +}; - inline ConstraintKind getKind() const noexcept { - return Kind; - } +class CMany : public Constraint { +public: - Constraint* substitute(const TVSub& Sub); + ConstraintSet& Elements; - virtual ~Constraint() {} + inline CMany(ConstraintSet& Elements): + Constraint(ConstraintKind::Many), Elements(Elements) {} - }; +}; - class CEqual : public Constraint { - public: +class CEmpty : public Constraint { +public: - Type* Left; - Type* Right; - Node* Source; + inline CEmpty(): + Constraint(ConstraintKind::Empty) {} - inline CEqual(Type* Left, Type* Right, Node* Source = nullptr): - Constraint(ConstraintKind::Equal), Left(Left), Right(Right), Source(Source) {} +}; - }; +using InferContextFlagsMask = unsigned; - class CField : public Constraint { - public: +class InferContext { +public: - Type* TupleTy; - size_t I; - Type* FieldTy; - Node* Source; + /** + * A heap-allocated list of type variables that eventually will become part of a Forall scheme. + */ + TVSet* TVs; - inline CField(Type* TupleTy, size_t I, Type* FieldTy, Node* Source = nullptr): - Constraint(ConstraintKind::Field), TupleTy(TupleTy), I(I), FieldTy(FieldTy), Source(Source) {} + /** + * A heap-allocated list of constraints that eventually will become part of a Forall scheme. + */ + ConstraintSet* Constraints; - }; + TypeEnv Env; - class CMany : public Constraint { - public: + Type* ReturnType = nullptr; - ConstraintSet& Elements; + InferContext* Parent = nullptr; - inline CMany(ConstraintSet& Elements): - Constraint(ConstraintKind::Many), Elements(Elements) {} +}; - }; +class Checker { - class CEmpty : public Constraint { - public: + friend class Unifier; + friend class UnificationFrame; - inline CEmpty(): - Constraint(ConstraintKind::Empty) {} + const LanguageConfig& Config; + DiagnosticEngine& DE; - }; + size_t NextConTypeId = 0; + size_t NextTypeVarId = 0; - using InferContextFlagsMask = unsigned; + Type* BoolType; + Type* ListType; + Type* IntType; + Type* StringType; + Type* UnitType; - class InferContext { - public: + Graph RefGraph; - /** - * A heap-allocated list of type variables that eventually will become part of a Forall scheme. - */ - TVSet* TVs; + std::unordered_map> InstanceMap; - /** - * A heap-allocated list of constraints that eventually will become part of a Forall scheme. - */ - ConstraintSet* Constraints; + /// Inference context management - TypeEnv Env; + InferContext* ActiveContext; - Type* ReturnType = nullptr; + InferContext& getContext(); + void setContext(InferContext* Ctx); + void popContext(); - InferContext* Parent = nullptr; + void makeEqual(Type* A, Type* B, Node* Source); - }; + void addConstraint(Constraint* Constraint); - class Checker { + /** + * Get the return type for the current context. If none could be found, the + * program will abort. + */ + Type* getReturnType(); - friend class Unifier; - friend class UnificationFrame; + /// Type inference - const LanguageConfig& Config; - DiagnosticEngine& DE; + void forwardDeclare(Node* Node); + void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); - size_t NextConTypeId = 0; - size_t NextTypeVarId = 0; + Type* inferExpression(Expression* Expression); + Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true); + Type* inferLiteral(Literal* Lit); + Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); - Type* BoolType; - Type* ListType; - Type* IntType; - Type* StringType; - Type* UnitType; + void infer(Node* node); + void inferFunctionDeclaration(LetDeclaration* N); + void inferConstraintExpression(ConstraintExpression* C); - Graph RefGraph; + /// Factory methods - std::unordered_map> InstanceMap; + Type* createConType(ByteString Name); + Type* createTypeVar(); + Type* createRigidVar(ByteString Name); + InferContext* createInferContext( + InferContext* Parent = nullptr, + TVSet* TVs = new TVSet, + ConstraintSet* Constraints = new ConstraintSet + ); - /// Inference context management + /// Environment manipulation - InferContext* ActiveContext; + Scheme* lookup(ByteString Name, SymKind Kind); - InferContext& getContext(); - void setContext(InferContext* Ctx); - void popContext(); + /** + * Looks up a type/variable and ensures that it is a monomorphic type. + * + * This method is mainly syntactic sugar to make it clear in the code when a + * monomorphic type is expected. + * + * Note that if the type is not monomorphic the program will abort with a + * stack trace. It wil **not** print a user-friendly error message. + * + * \returns If the type/variable could not be found `nullptr` is returned. + * Otherwise, a [Type] is returned. + */ + Type* lookupMono(ByteString Name, SymKind Kind); - void makeEqual(Type* A, Type* B, Node* Source); + void addBinding(ByteString Name, Scheme* Scm, SymKind Kind); - void addConstraint(Constraint* Constraint); + /// Constraint solving - /** - * Get the return type for the current context. If none could be found, the - * program will abort. - */ - Type* getReturnType(); + /** + * The queue that is used during solving to store any unsolved constraints. + */ + std::deque Queue; - /// Type inference + /** + * Unify two types, using `Source` as source location. + * + * \returns Whether a type variable was assigned a type or not. + */ + bool unify(Type* Left, Type* Right, Node* Source); - void forwardDeclare(Node* Node); - void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); + void solve(Constraint* Constraint); - Type* inferExpression(Expression* Expression); - Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true); - Type* inferLiteral(Literal* Lit); - Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); + /// Helpers - void infer(Node* node); - void inferFunctionDeclaration(LetDeclaration* N); - void inferConstraintExpression(ConstraintExpression* C); + void populate(SourceFile* SF); - /// Factory methods + /** + * Verifies that type class signatures on type asserts in let-declarations + * correctly declare the right type classes. + */ + void checkTypeclassSigs(Node* N); - Type* createConType(ByteString Name); - Type* createTypeVar(); - Type* createRigidVar(ByteString Name); - InferContext* createInferContext( - InferContext* Parent = nullptr, - TVSet* TVs = new TVSet, - ConstraintSet* Constraints = new ConstraintSet - ); + Type* instantiate(Scheme* S, Node* Source); - /// Environment manipulation + void initialize(Node* N); - Scheme* lookup(ByteString Name, SymKind Kind); +public: - /** - * Looks up a type/variable and ensures that it is a monomorphic type. - * - * This method is mainly syntactic sugar to make it clear in the code when a - * monomorphic type is expected. - * - * Note that if the type is not monomorphic the program will abort with a - * stack trace. It wil **not** print a user-friendly error message. - * - * \returns If the type/variable could not be found `nullptr` is returned. - * Otherwise, a [Type] is returned. - */ - Type* lookupMono(ByteString Name, SymKind Kind); + Checker(const LanguageConfig& Config, DiagnosticEngine& DE); - void addBinding(ByteString Name, Scheme* Scm, SymKind Kind); + /** + * \internal + */ + Type* solveType(Type* Ty); - /// Constraint solving + void check(SourceFile* SF); - /** - * The queue that is used during solving to store any unsolved constraints. - */ - std::deque Queue; + inline Type* getBoolType() const { + return BoolType; + } - /** - * Unify two types, using `Source` as source location. - * - * \returns Whether a type variable was assigned a type or not. - */ - bool unify(Type* Left, Type* Right, Node* Source); + inline Type* getStringType() const { + return StringType; + } - void solve(Constraint* Constraint); + inline Type* getIntType() const { + return IntType; + } - /// Helpers + Type* getType(TypedNode* Node); - void populate(SourceFile* SF); - - /** - * Verifies that type class signatures on type asserts in let-declarations - * correctly declare the right type classes. - */ - void checkTypeclassSigs(Node* N); - - Type* instantiate(Scheme* S, Node* Source); - - void initialize(Node* N); - - public: - - Checker(const LanguageConfig& Config, DiagnosticEngine& DE); - - /** - * \internal - */ - Type* solveType(Type* Ty); - - void check(SourceFile* SF); - - inline Type* getBoolType() const { - return BoolType; - } - - inline Type* getStringType() const { - return StringType; - } - - inline Type* getIntType() const { - return IntType; - } - - Type* getType(TypedNode* Node); - - }; +}; } diff --git a/bootstrap/cxx/include/bolt/Common.hpp b/bootstrap/cxx/include/bolt/Common.hpp index 1c360c532..930ae4517 100644 --- a/bootstrap/cxx/include/bolt/Common.hpp +++ b/bootstrap/cxx/include/bolt/Common.hpp @@ -3,50 +3,50 @@ namespace bolt { - class LanguageConfig { - - enum ConfigFlags { - ConfigFlags_TypeVarsRequireForall = 1 << 0, - }; - - unsigned Flags = 0; - - public: - - void setTypeVarsRequireForall(bool Enable) { - if (Enable) { - Flags |= ConfigFlags_TypeVarsRequireForall; - } else { - Flags |= ~ConfigFlags_TypeVarsRequireForall; - } - } - - bool typeVarsRequireForall() const noexcept { - return Flags & ConfigFlags_TypeVarsRequireForall; - } - - bool hasImmediateDiagnostics() const noexcept { - // TODO make this a configuration flag - return true; - } +class LanguageConfig { + enum ConfigFlags { + ConfigFlags_TypeVarsRequireForall = 1 << 0, }; - template - D* cast(B* base) { - ZEN_ASSERT(D::classof(base)); - return static_cast(base); + unsigned Flags = 0; + +public: + + void setTypeVarsRequireForall(bool Enable) { + if (Enable) { + Flags |= ConfigFlags_TypeVarsRequireForall; + } else { + Flags |= ~ConfigFlags_TypeVarsRequireForall; + } } - template - const D* cast(const B* base) { - ZEN_ASSERT(D::classof(base)); - return static_cast(base); + bool typeVarsRequireForall() const noexcept { + return Flags & ConfigFlags_TypeVarsRequireForall; } - template - bool isa(const T* value) { - return D::classof(value); + bool hasImmediateDiagnostics() const noexcept { + // TODO make this a configuration flag + return true; } +}; + +template +D* cast(B* base) { + ZEN_ASSERT(D::classof(base)); + return static_cast(base); +} + +template +const D* cast(const B* base) { + ZEN_ASSERT(D::classof(base)); + return static_cast(base); +} + +template +bool isa(const T* value) { + return D::classof(value); +} + } diff --git a/bootstrap/cxx/include/bolt/ConsolePrinter.hpp b/bootstrap/cxx/include/bolt/ConsolePrinter.hpp index 40d3c549b..e98bd6cef 100644 --- a/bootstrap/cxx/include/bolt/ConsolePrinter.hpp +++ b/bootstrap/cxx/include/bolt/ConsolePrinter.hpp @@ -9,181 +9,181 @@ namespace bolt { - class Node; - class Type; - class TypeclassSignature; - class Diagnostic; +class Node; +class Type; +class TypeclassSignature; +class Diagnostic; - enum class Color { - None, - Black, - White, - Red, - Yellow, - Green, - Blue, - Cyan, - Magenta, - }; +enum class Color { + None, + Black, + White, + Red, + Yellow, + Green, + Blue, + Cyan, + Magenta, +}; - enum StyleFlags : unsigned { - StyleFlags_None = 0, - StyleFlags_Bold = 1 << 0, - StyleFlags_Underline = 1 << 1, - StyleFlags_Italic = 1 << 2, - }; +enum StyleFlags : unsigned { + StyleFlags_None = 0, + StyleFlags_Bold = 1 << 0, + StyleFlags_Underline = 1 << 1, + StyleFlags_Italic = 1 << 2, +}; - class Style { +class Style { - unsigned Flags = StyleFlags_None; + unsigned Flags = StyleFlags_None; - Color FgColor = Color::None; - Color BgColor = Color::None; + Color FgColor = Color::None; + Color BgColor = Color::None; - public: +public: - Color getForegroundColor() const noexcept { - return FgColor; + Color getForegroundColor() const noexcept { + return FgColor; + } + + Color getBackgroundColor() const noexcept { + return BgColor; + } + + void setForegroundColor(Color NewColor) noexcept { + FgColor = NewColor; + } + + void setBackgroundColor(Color NewColor) noexcept { + BgColor = NewColor; + } + + bool hasForegroundColor() const noexcept { + return FgColor != Color::None; + } + + bool hasBackgroundColor() const noexcept { + return BgColor != Color::None; + } + + void clearForegroundColor() noexcept { + FgColor = Color::None; + } + + void clearBackgroundColor() noexcept { + BgColor = Color::None; + } + + bool isUnderline() const noexcept { + return Flags & StyleFlags_Underline; + } + + bool isItalic() const noexcept { + return Flags & StyleFlags_Italic; + } + + bool isBold() const noexcept { + return Flags & StyleFlags_Bold; + } + + void setUnderline(bool Enable) noexcept { + if (Enable) { + Flags |= StyleFlags_Underline; + } else { + Flags &= ~StyleFlags_Underline; } + } - Color getBackgroundColor() const noexcept { - return BgColor; + void setItalic(bool Enable) noexcept { + if (Enable) { + Flags |= StyleFlags_Italic; + } else { + Flags &= ~StyleFlags_Italic; } + } - void setForegroundColor(Color NewColor) noexcept { - FgColor = NewColor; + void setBold(bool Enable) noexcept { + if (Enable) { + Flags |= StyleFlags_Bold; + } else { + Flags &= ~StyleFlags_Bold; } + } - void setBackgroundColor(Color NewColor) noexcept { - BgColor = NewColor; - } + void reset() noexcept { + FgColor = Color::None; + BgColor = Color::None; + Flags = 0; + } - bool hasForegroundColor() const noexcept { - return FgColor != Color::None; - } +}; - bool hasBackgroundColor() const noexcept { - return BgColor != Color::None; - } +/** + * Prints any diagnostic message that was added to it to the console. + */ +class ConsolePrinter { - void clearForegroundColor() noexcept { - FgColor = Color::None; - } + std::ostream& Out; - void clearBackgroundColor() noexcept { - BgColor = Color::None; - } + Style ActiveStyle; - bool isUnderline() const noexcept { - return Flags & StyleFlags_Underline; - } + void setForegroundColor(Color C); + void setBackgroundColor(Color C); + void applyStyles(); - bool isItalic() const noexcept { - return Flags & StyleFlags_Italic; - } + void setBold(bool Enable); + void setItalic(bool Enable); + void setUnderline(bool Enable); + void resetStyles(); - bool isBold() const noexcept { - return Flags & StyleFlags_Bold; - } + void writeGutter( + std::size_t GutterWidth, + std::string Text + ); - void setUnderline(bool Enable) noexcept { - if (Enable) { - Flags |= StyleFlags_Underline; - } else { - Flags &= ~StyleFlags_Underline; - } - } + void writeHighlight( + std::size_t GutterWidth, + TextRange Range, + Color HighlightColor, + std::size_t Line, + std::size_t LineLength + ); - void setItalic(bool Enable) noexcept { - if (Enable) { - Flags |= StyleFlags_Italic; - } else { - Flags &= ~StyleFlags_Italic; - } - } + void writeExcerpt( + const TextFile& File, + TextRange ToPrint, + TextRange ToHighlight, + Color HighlightColor + ); - void setBold(bool Enable) noexcept { - if (Enable) { - Flags |= StyleFlags_Bold; - } else { - Flags &= ~StyleFlags_Bold; - } - } + void writeNode(const Node* N); - void reset() noexcept { - FgColor = Color::None; - BgColor = Color::None; - Flags = 0; - } + void writePrefix(const Diagnostic& D); + void writeBinding(const ByteString& Name); + void writeType(std::size_t I); + void writeType(const Type* Ty, const TypePath& Underline); + void writeType(const Type* Ty); + void writeLoc(const TextFile& File, const TextLoc& Loc); + void writeTypeclassName(const ByteString& Name); + void writeTypeclassSignature(const TypeclassSignature& Sig); - }; + void write(const std::string_view& S); + void write(std::size_t N); + void write(char C); - /** - * Prints any diagnostic message that was added to it to the console. - */ - class ConsolePrinter { +public: - std::ostream& Out; + unsigned ExcerptLinesPre = 2; + unsigned ExcerptLinesPost = 2; + std::size_t MaxTypeSubsitutionCount = 0; + bool PrintFilePosition = true; + bool PrintExcerpts = true; + bool EnableColors = true; - Style ActiveStyle; + ConsolePrinter(std::ostream& Out = std::cerr); - void setForegroundColor(Color C); - void setBackgroundColor(Color C); - void applyStyles(); + void writeDiagnostic(const Diagnostic& D); - void setBold(bool Enable); - void setItalic(bool Enable); - void setUnderline(bool Enable); - void resetStyles(); - - void writeGutter( - std::size_t GutterWidth, - std::string Text - ); - - void writeHighlight( - std::size_t GutterWidth, - TextRange Range, - Color HighlightColor, - std::size_t Line, - std::size_t LineLength - ); - - void writeExcerpt( - const TextFile& File, - TextRange ToPrint, - TextRange ToHighlight, - Color HighlightColor - ); - - void writeNode(const Node* N); - - void writePrefix(const Diagnostic& D); - void writeBinding(const ByteString& Name); - void writeType(std::size_t I); - void writeType(const Type* Ty, const TypePath& Underline); - void writeType(const Type* Ty); - void writeLoc(const TextFile& File, const TextLoc& Loc); - void writeTypeclassName(const ByteString& Name); - void writeTypeclassSignature(const TypeclassSignature& Sig); - - void write(const std::string_view& S); - void write(std::size_t N); - void write(char C); - - public: - - unsigned ExcerptLinesPre = 2; - unsigned ExcerptLinesPost = 2; - std::size_t MaxTypeSubsitutionCount = 0; - bool PrintFilePosition = true; - bool PrintExcerpts = true; - bool EnableColors = true; - - ConsolePrinter(std::ostream& Out = std::cerr); - - void writeDiagnostic(const Diagnostic& D); - - }; +}; } diff --git a/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp b/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp index 867b35844..fcc927753 100644 --- a/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp +++ b/bootstrap/cxx/include/bolt/DiagnosticEngine.hpp @@ -7,78 +7,78 @@ namespace bolt { - class ConsolePrinter; - class Diagnostic; - class TypeclassSignature; - class Type; - class Node; +class ConsolePrinter; +class Diagnostic; +class TypeclassSignature; +class Type; +class Node; - class DiagnosticEngine { - protected: +class DiagnosticEngine { +protected: - bool HasError = false; + bool HasError = false; - virtual void addDiagnostic(Diagnostic* Diagnostic) = 0; + virtual void addDiagnostic(Diagnostic* Diagnostic) = 0; - public: +public: - bool FailOnError = false; + bool FailOnError = false; - inline bool hasError() const noexcept { - return HasError; - } + inline bool hasError() const noexcept { + return HasError; + } - template - void add(Ts&&... Args) { - // if (FailOnError) { - // ZEN_PANIC("An error diagnostic caused the program to abort."); - // } - HasError = true; - addDiagnostic(new D { std::forward(Args)... }); - } + template + void add(Ts&&... Args) { + // if (FailOnError) { + // ZEN_PANIC("An error diagnostic caused the program to abort."); + // } + HasError = true; + addDiagnostic(new D { std::forward(Args)... }); + } - virtual ~DiagnosticEngine() {} + virtual ~DiagnosticEngine() {} - }; +}; - /** - * Keeps diagnostics alive in-memory until a seperate procedure processes them. - */ - class DiagnosticStore : public DiagnosticEngine { - public: +/** + * Keeps diagnostics alive in-memory until a seperate procedure processes them. + */ +class DiagnosticStore : public DiagnosticEngine { +public: - std::vector Diagnostics; + std::vector Diagnostics; - void addDiagnostic(Diagnostic* Diagnostic) { - Diagnostics.push_back(Diagnostic); - } + void addDiagnostic(Diagnostic* Diagnostic) { + Diagnostics.push_back(Diagnostic); + } - void clear() { - Diagnostics.clear(); - } + void clear() { + Diagnostics.clear(); + } - void sort(); + void sort(); - std::size_t countDiagnostics() const noexcept { - return Diagnostics.size(); - } + std::size_t countDiagnostics() const noexcept { + return Diagnostics.size(); + } - ~DiagnosticStore(); + ~DiagnosticStore(); - }; +}; - class ConsoleDiagnostics : public DiagnosticEngine { +class ConsoleDiagnostics : public DiagnosticEngine { - ConsolePrinter& ThePrinter; + ConsolePrinter& ThePrinter; - protected: +protected: - void addDiagnostic(Diagnostic* Diagnostic) override; + void addDiagnostic(Diagnostic* Diagnostic) override; - public: +public: - ConsoleDiagnostics(ConsolePrinter& ThePrinter); + ConsoleDiagnostics(ConsolePrinter& ThePrinter); - }; +}; } diff --git a/bootstrap/cxx/include/bolt/Diagnostics.hpp b/bootstrap/cxx/include/bolt/Diagnostics.hpp index 51361b020..56b1dd5fa 100644 --- a/bootstrap/cxx/include/bolt/Diagnostics.hpp +++ b/bootstrap/cxx/include/bolt/Diagnostics.hpp @@ -10,237 +10,237 @@ namespace bolt { - enum class DiagnosticKind : unsigned char { - BindingNotFound, - FieldNotFound, - InstanceNotFound, - InvalidTypeToTypeclass, - NotATuple, - TupleIndexOutOfRange, - TypeclassMissing, - UnexpectedString, - UnexpectedToken, - UnificationError, - }; +enum class DiagnosticKind : unsigned char { + BindingNotFound, + FieldNotFound, + InstanceNotFound, + InvalidTypeToTypeclass, + NotATuple, + TupleIndexOutOfRange, + TypeclassMissing, + UnexpectedString, + UnexpectedToken, + UnificationError, +}; - class Diagnostic { +class Diagnostic { - const DiagnosticKind Kind; + const DiagnosticKind Kind; - protected: +protected: - Diagnostic(DiagnosticKind Kind); + Diagnostic(DiagnosticKind Kind); - public: +public: - inline DiagnosticKind getKind() const noexcept { - return Kind; - } + inline DiagnosticKind getKind() const noexcept { + return Kind; + } - virtual Node* getNode() const { - return nullptr; - } + virtual Node* getNode() const { + return nullptr; + } - virtual unsigned getCode() const noexcept = 0; + virtual unsigned getCode() const noexcept = 0; - virtual ~Diagnostic() {} + virtual ~Diagnostic() {} - }; +}; - class UnexpectedStringDiagnostic : public Diagnostic { - public: +class UnexpectedStringDiagnostic : public Diagnostic { +public: - TextFile& File; - TextLoc Location; - String Actual; + TextFile& File; + TextLoc Location; + String Actual; - inline UnexpectedStringDiagnostic(TextFile& File, TextLoc Location, String Actual): - Diagnostic(DiagnosticKind::UnexpectedString), File(File), Location(Location), Actual(Actual) {} + inline UnexpectedStringDiagnostic(TextFile& File, TextLoc Location, String Actual): + Diagnostic(DiagnosticKind::UnexpectedString), File(File), Location(Location), Actual(Actual) {} - unsigned getCode() const noexcept override { - return 1001; - } + unsigned getCode() const noexcept override { + return 1001; + } - }; +}; - class UnexpectedTokenDiagnostic : public Diagnostic { - public: +class UnexpectedTokenDiagnostic : public Diagnostic { +public: - TextFile& File; - Token* Actual; - std::vector Expected; + TextFile& File; + Token* Actual; + std::vector Expected; - inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector Expected): - Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {} + inline UnexpectedTokenDiagnostic(TextFile& File, Token* Actual, std::vector Expected): + Diagnostic(DiagnosticKind::UnexpectedToken), File(File), Actual(Actual), Expected(Expected) {} - unsigned getCode() const noexcept override { - return 1101; - } + unsigned getCode() const noexcept override { + return 1101; + } - }; +}; - class BindingNotFoundDiagnostic : public Diagnostic { - public: +class BindingNotFoundDiagnostic : public Diagnostic { +public: - ByteString Name; - Node* Initiator; + ByteString Name; + Node* Initiator; - inline BindingNotFoundDiagnostic(ByteString Name, Node* Initiator): - Diagnostic(DiagnosticKind::BindingNotFound), Name(Name), Initiator(Initiator) {} + inline BindingNotFoundDiagnostic(ByteString Name, Node* Initiator): + Diagnostic(DiagnosticKind::BindingNotFound), Name(Name), Initiator(Initiator) {} - inline Node* getNode() const override { - return Initiator; - } + inline Node* getNode() const override { + return Initiator; + } - unsigned getCode() const noexcept override { - return 2005; - } + unsigned getCode() const noexcept override { + return 2005; + } - }; +}; - class UnificationErrorDiagnostic : public Diagnostic { - public: +class UnificationErrorDiagnostic : public Diagnostic { +public: - Type* OrigLeft; - Type* OrigRight; - TypePath LeftPath; - TypePath RightPath; - Node* Source; + Type* OrigLeft; + Type* OrigRight; + TypePath LeftPath; + TypePath RightPath; + Node* Source; - inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source): - Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} + inline UnificationErrorDiagnostic(Type* OrigLeft, Type* OrigRight, TypePath LeftPath, TypePath RightPath, Node* Source): + Diagnostic(DiagnosticKind::UnificationError), OrigLeft(OrigLeft), OrigRight(OrigRight), LeftPath(LeftPath), RightPath(RightPath), Source(Source) {} - inline Type* getLeft() const { - return OrigLeft->resolve(LeftPath); - } + inline Type* getLeft() const { + return OrigLeft->resolve(LeftPath); + } - inline Type* getRight() const { - return OrigRight->resolve(RightPath); - } + inline Type* getRight() const { + return OrigRight->resolve(RightPath); + } - inline Node* getNode() const override { - return Source; - } + inline Node* getNode() const override { + return Source; + } - unsigned getCode() const noexcept override { - return 2010; - } + unsigned getCode() const noexcept override { + return 2010; + } - }; +}; - class TypeclassMissingDiagnostic : public Diagnostic { - public: +class TypeclassMissingDiagnostic : public Diagnostic { +public: - TypeclassSignature Sig; - Node* Decl; + TypeclassSignature Sig; + Node* Decl; - inline TypeclassMissingDiagnostic(TypeclassSignature Sig, Node* Decl): - Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} + inline TypeclassMissingDiagnostic(TypeclassSignature Sig, Node* Decl): + Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} - inline Node* getNode() const override { - return Decl; - } + inline Node* getNode() const override { + return Decl; + } - unsigned getCode() const noexcept override { - return 2201; - } + unsigned getCode() const noexcept override { + return 2201; + } - }; +}; - class InstanceNotFoundDiagnostic : public Diagnostic { - public: +class InstanceNotFoundDiagnostic : public Diagnostic { +public: - ByteString TypeclassName; - Type* Ty; - Node* Source; + ByteString TypeclassName; + Type* Ty; + Node* Source; - inline InstanceNotFoundDiagnostic(ByteString TypeclassName, Type* Ty, Node* Source): - Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {} + inline InstanceNotFoundDiagnostic(ByteString TypeclassName, Type* Ty, Node* Source): + Diagnostic(DiagnosticKind::InstanceNotFound), TypeclassName(TypeclassName), Ty(Ty), Source(Source) {} - inline Node* getNode() const override { - return Source; - } + inline Node* getNode() const override { + return Source; + } - unsigned getCode() const noexcept override { - return 2251; - } + unsigned getCode() const noexcept override { + return 2251; + } - }; +}; - class TupleIndexOutOfRangeDiagnostic : public Diagnostic { - public: +class TupleIndexOutOfRangeDiagnostic : public Diagnostic { +public: - Type* Tuple; - std::size_t I; - Node* Source; + Type* Tuple; + std::size_t I; + Node* Source; - inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I, Node* Source): - Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I), Source(Source) {} + inline TupleIndexOutOfRangeDiagnostic(Type* Tuple, std::size_t I, Node* Source): + Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I), Source(Source) {} - inline Node * getNode() const override { - return Source; - } + inline Node * getNode() const override { + return Source; + } - unsigned getCode() const noexcept override { - return 2015; - } + unsigned getCode() const noexcept override { + return 2015; + } - }; +}; - class InvalidTypeToTypeclassDiagnostic : public Diagnostic { - public: +class InvalidTypeToTypeclassDiagnostic : public Diagnostic { +public: - Type* Actual; - std::vector Classes; - Node* Source; + Type* Actual; + std::vector Classes; + Node* Source; - inline InvalidTypeToTypeclassDiagnostic(Type* Actual, std::vector Classes, Node* Source): - Diagnostic(DiagnosticKind::InvalidTypeToTypeclass), Actual(Actual), Classes(Classes), Source(Source) {} + inline InvalidTypeToTypeclassDiagnostic(Type* Actual, std::vector Classes, Node* Source): + Diagnostic(DiagnosticKind::InvalidTypeToTypeclass), Actual(Actual), Classes(Classes), Source(Source) {} - inline Node* getNode() const override { - return Source; - } + inline Node* getNode() const override { + return Source; + } - unsigned getCode() const noexcept override { - return 2060; - } + unsigned getCode() const noexcept override { + return 2060; + } - }; +}; - class FieldNotFoundDiagnostic : public Diagnostic { - public: +class FieldNotFoundDiagnostic : public Diagnostic { +public: - ByteString Name; - Type* Ty; - TypePath Path; - Node* Source; + ByteString Name; + Type* Ty; + TypePath Path; + Node* Source; - inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source): - Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} + inline FieldNotFoundDiagnostic(ByteString Name, Type* Ty, TypePath Path, Node* Source): + Diagnostic(DiagnosticKind::FieldNotFound), Name(Name), Ty(Ty), Path(Path), Source(Source) {} - unsigned getCode() const noexcept override { - return 2017; - } + unsigned getCode() const noexcept override { + return 2017; + } - }; +}; - class NotATupleDiagnostic : public Diagnostic { - public: +class NotATupleDiagnostic : public Diagnostic { +public: - Type* Ty; - Node* Source; + Type* Ty; + Node* Source; - inline NotATupleDiagnostic(Type* Ty, Node* Source): - Diagnostic(DiagnosticKind::NotATuple), Ty(Ty), Source(Source) {} + inline NotATupleDiagnostic(Type* Ty, Node* Source): + Diagnostic(DiagnosticKind::NotATuple), Ty(Ty), Source(Source) {} - inline Node * getNode() const override { - return Source; - } + inline Node * getNode() const override { + return Source; + } - unsigned getCode() const noexcept override { - return 2016; - } + unsigned getCode() const noexcept override { + return 2016; + } - }; +}; } diff --git a/bootstrap/cxx/include/bolt/Evaluator.hpp b/bootstrap/cxx/include/bolt/Evaluator.hpp index ae5127951..f984a5489 100644 --- a/bootstrap/cxx/include/bolt/Evaluator.hpp +++ b/bootstrap/cxx/include/bolt/Evaluator.hpp @@ -9,179 +9,180 @@ namespace bolt { - enum class ValueKind { - Empty, - String, - Integer, - Tuple, - SourceFunction, - NativeFunction, +enum class ValueKind { + Empty, + String, + Integer, + Tuple, + SourceFunction, + NativeFunction, +}; + +class Value { + + using NativeFunction = std::function)>; + + using Tuple = std::vector; + + ValueKind Kind; + + union { + ByteString S; + Integer I; + LetDeclaration* D; + NativeFunction F; + Tuple T; }; - class Value { +public: - using NativeFunction = std::function)>; + Value(): + Kind(ValueKind::Empty) {} - using Tuple = std::vector; + Value(ByteString S): + Kind(ValueKind::String), S(S) {} - ValueKind Kind; + Value(Integer I): + Kind(ValueKind::Integer), I(I) {} - union { - ByteString S; - Integer I; - LetDeclaration* D; - NativeFunction F; - Tuple T; - }; + Value(LetDeclaration* D): + Kind(ValueKind::SourceFunction), D(D) {} - public: + Value(NativeFunction F): + Kind(ValueKind::NativeFunction), F(F) {} - Value(): - Kind(ValueKind::Empty) {} + Value(std::vector T): + Kind(ValueKind::Tuple), T(T) {} - Value(ByteString S): - Kind(ValueKind::String), S(S) {} - - Value(Integer I): - Kind(ValueKind::Integer), I(I) {} - - Value(LetDeclaration* D): - Kind(ValueKind::SourceFunction), D(D) {} - - Value(NativeFunction F): - Kind(ValueKind::NativeFunction), F(F) {} - - Value(std::vector T): - Kind(ValueKind::Tuple), T(T) {} - - Value(const Value& V): - Kind(V.Kind) { - switch (Kind) { - case ValueKind::String: - new (&S) ByteString(V.S); - break; - case ValueKind::Integer: - new (&I) Integer(V.I); - break; - case ValueKind::Tuple: - new (&I) Tuple(V.T); - break; - case ValueKind::SourceFunction: - new (&D) LetDeclaration*(V.D); - break; - case ValueKind::NativeFunction: - new (&F) NativeFunction(V.F); - break; - case ValueKind::Empty: - break; - } - } - - Value& operator=(const Value& Other) noexcept { - Kind = Other.Kind; + Value(const Value& V): + Kind(V.Kind) { switch (Kind) { case ValueKind::String: - new (&S) ByteString(Other.S); + new (&S) ByteString(V.S); break; case ValueKind::Integer: - new (&I) Integer(Other.I); + new (&I) Integer(V.I); break; case ValueKind::Tuple: - new (&I) Tuple(Other.T); + new (&I) Tuple(V.T); break; case ValueKind::SourceFunction: - new (&D) LetDeclaration*(Other.D); + new (&D) LetDeclaration*(V.D); break; case ValueKind::NativeFunction: - new (&F) NativeFunction(Other.F); - break; - case ValueKind::Empty: - break; - } - return *this; - } - - // Add move constructor and move assignment methods - - inline ValueKind getKind() const noexcept { - return Kind; - } - - inline ByteString& asString() { - ZEN_ASSERT(Kind == ValueKind::String); - return S; - } - - inline LetDeclaration* getDeclaration() { - ZEN_ASSERT(Kind == ValueKind::SourceFunction); - return D; - } - - inline NativeFunction getBinding() { - ZEN_ASSERT(Kind == ValueKind::NativeFunction); - return F; - } - - static Value binding(NativeFunction F) { - return Value(F); - } - - static Value unit() { - return Value(Tuple {}); - } - - ~Value() { - switch (Kind) { - case ValueKind::String: - S.~ByteString(); - break; - case ValueKind::Integer: - I.~Integer(); - break; - case ValueKind::Tuple: - T.~Tuple(); - break; - case ValueKind::SourceFunction: - break; - case ValueKind::NativeFunction: - F.~NativeFunction(); + new (&F) NativeFunction(V.F); break; case ValueKind::Empty: break; } } - }; - - class Env { - - std::unordered_map Bindings; - - public: - - void add(const ByteString& Name, Value V) { - Bindings.emplace(Name, V); + Value& operator=(const Value& Other) noexcept { + Kind = Other.Kind; + switch (Kind) { + case ValueKind::String: + new (&S) ByteString(Other.S); + break; + case ValueKind::Integer: + new (&I) Integer(Other.I); + break; + case ValueKind::Tuple: + new (&I) Tuple(Other.T); + break; + case ValueKind::SourceFunction: + new (&D) LetDeclaration*(Other.D); + break; + case ValueKind::NativeFunction: + new (&F) NativeFunction(Other.F); + break; + case ValueKind::Empty: + break; } + return *this; + } - Value& lookup(const ByteString& Name) { - auto Match = Bindings.find(Name); - ZEN_ASSERT(Match != Bindings.end()); - return Match->second; + // Add move constructor and move assignment methods + + inline ValueKind getKind() const noexcept { + return Kind; + } + + inline ByteString& asString() { + ZEN_ASSERT(Kind == ValueKind::String); + return S; + } + + inline LetDeclaration* getDeclaration() { + ZEN_ASSERT(Kind == ValueKind::SourceFunction); + return D; + } + + inline NativeFunction getBinding() { + ZEN_ASSERT(Kind == ValueKind::NativeFunction); + return F; + } + + static Value binding(NativeFunction F) { + return Value(F); + } + + static Value unit() { + return Value(Tuple {}); + } + + ~Value() { + switch (Kind) { + case ValueKind::String: + S.~ByteString(); + break; + case ValueKind::Integer: + I.~Integer(); + break; + case ValueKind::Tuple: + T.~Tuple(); + break; + case ValueKind::SourceFunction: + break; + case ValueKind::NativeFunction: + F.~NativeFunction(); + break; + case ValueKind::Empty: + break; } + } - }; +}; - class Evaluator { +class Env { - public: + std::unordered_map Bindings; - void assignPattern(Pattern* P, Value& V, Env& E); +public: - Value apply(Value Op, std::vector Args); + void add(const ByteString& Name, Value V) { + Bindings.emplace(Name, V); + } - Value evaluateExpression(Expression* N, Env& E); + Value& lookup(const ByteString& Name) { + auto Match = Bindings.find(Name); + ZEN_ASSERT(Match != Bindings.end()); + return Match->second; + } - void evaluate(Node* N, Env& E); +}; + +class Evaluator { + +public: + + void assignPattern(Pattern* P, Value& V, Env& E); + + Value apply(Value Op, std::vector Args); + + Value evaluateExpression(Expression* N, Env& E); + + void evaluate(Node* N, Env& E); + +}; - }; } diff --git a/bootstrap/cxx/include/bolt/Integer.hpp b/bootstrap/cxx/include/bolt/Integer.hpp index 729e80459..f1a8a09d5 100644 --- a/bootstrap/cxx/include/bolt/Integer.hpp +++ b/bootstrap/cxx/include/bolt/Integer.hpp @@ -3,7 +3,7 @@ namespace bolt { - using Integer = long long; +using Integer = long long; } diff --git a/bootstrap/cxx/include/bolt/Parser.hpp b/bootstrap/cxx/include/bolt/Parser.hpp index db2882bd7..478c425fc 100644 --- a/bootstrap/cxx/include/bolt/Parser.hpp +++ b/bootstrap/cxx/include/bolt/Parser.hpp @@ -9,144 +9,144 @@ namespace bolt { - class DiagnosticEngine; - class Scanner; +class DiagnosticEngine; +class Scanner; - enum OperatorFlags { - OperatorFlags_Prefix = 1, - OperatorFlags_Suffix = 2, - OperatorFlags_InfixL = 4, - OperatorFlags_InfixR = 8, - }; +enum OperatorFlags { + OperatorFlags_Prefix = 1, + OperatorFlags_Suffix = 2, + OperatorFlags_InfixL = 4, + OperatorFlags_InfixR = 8, +}; - struct OperatorInfo { +struct OperatorInfo { - int Precedence; - unsigned Flags; + int Precedence; + unsigned Flags; - inline bool isPrefix() const noexcept { - return Flags & OperatorFlags_Prefix; - } + inline bool isPrefix() const noexcept { + return Flags & OperatorFlags_Prefix; + } - inline bool isSuffix() const noexcept { - return Flags & OperatorFlags_Suffix; - } + inline bool isSuffix() const noexcept { + return Flags & OperatorFlags_Suffix; + } - inline bool isInfix() const noexcept { - return Flags & (OperatorFlags_InfixL | OperatorFlags_InfixR); - } + inline bool isInfix() const noexcept { + return Flags & (OperatorFlags_InfixL | OperatorFlags_InfixR); + } - inline bool isRightAssoc() const noexcept { - return Flags & OperatorFlags_InfixR; - } + inline bool isRightAssoc() const noexcept { + return Flags & OperatorFlags_InfixR; + } - }; +}; - class OperatorTable { +class OperatorTable { - std::unordered_map Mapping; + std::unordered_map Mapping; - public: +public: - void add(std::string Name, unsigned Flags, int Precedence); + void add(std::string Name, unsigned Flags, int Precedence); - std::optional getInfix(Token* T); + std::optional getInfix(Token* T); - bool isInfix(Token* T); - bool isPrefix(Token* T); - bool isSuffix(Token* T); + bool isInfix(Token* T); + bool isPrefix(Token* T); + bool isSuffix(Token* T); - }; +}; - class Parser { +class Parser { - TextFile& File; - DiagnosticEngine& DE; + TextFile& File; + DiagnosticEngine& DE; - Stream& Tokens; + Stream& Tokens; - OperatorTable ExprOperators; + OperatorTable ExprOperators; - Token* peekFirstTokenAfterAnnotationsAndModifiers(); + Token* peekFirstTokenAfterAnnotationsAndModifiers(); - Token* expectToken(NodeKind Ty); + Token* expectToken(NodeKind Ty); - std::vector parseRecordDeclarationFields(); - std::optional>> parseRecordPatternFields(); + std::vector parseRecordDeclarationFields(); + std::optional>> parseRecordPatternFields(); - template - T* expectToken() { - return static_cast(expectToken(getNodeType())); - } + template + T* expectToken() { + return static_cast(expectToken(getNodeType())); + } - Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence); + Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence); - MatchExpression* parseMatchExpression(); - Expression* parseMemberExpression(); - RecordExpression* parseRecordExpression(); - Expression* parsePrimitiveExpression(); + MatchExpression* parseMatchExpression(); + Expression* parseMemberExpression(); + RecordExpression* parseRecordExpression(); + Expression* parsePrimitiveExpression(); - ConstraintExpression* parseConstraintExpression(); + ConstraintExpression* parseConstraintExpression(); - TypeExpression* parseAppTypeExpression(); - TypeExpression* parsePrimitiveTypeExpression(); - TypeExpression* parseQualifiedTypeExpression(); - TypeExpression* parseArrowTypeExpression(); - VarTypeExpression* parseVarTypeExpression(); - ReferenceTypeExpression* parseReferenceTypeExpression(); + TypeExpression* parseAppTypeExpression(); + TypeExpression* parsePrimitiveTypeExpression(); + TypeExpression* parseQualifiedTypeExpression(); + TypeExpression* parseArrowTypeExpression(); + VarTypeExpression* parseVarTypeExpression(); + ReferenceTypeExpression* parseReferenceTypeExpression(); - std::vector parseAnnotations(); + std::vector parseAnnotations(); - void checkLineFoldEnd(); - void skipPastLineFoldEnd(); - void skipToRBrace(); + void checkLineFoldEnd(); + void skipPastLineFoldEnd(); + void skipToRBrace(); - public: +public: - Parser(TextFile& File, Stream& S, DiagnosticEngine& DE); + Parser(TextFile& File, Stream& S, DiagnosticEngine& DE); - TypeExpression* parseTypeExpression(); + TypeExpression* parseTypeExpression(); - ListPattern* parseListPattern(); - Pattern* parsePrimitivePattern(bool IsNarrow); - Pattern* parseWidePattern(); - Pattern* parseNarrowPattern(); + ListPattern* parseListPattern(); + Pattern* parsePrimitivePattern(bool IsNarrow); + Pattern* parseWidePattern(); + Pattern* parseNarrowPattern(); - Parameter* parseParam(); + Parameter* parseParam(); - ReferenceExpression* parseReferenceExpression(); + ReferenceExpression* parseReferenceExpression(); - Expression* parseUnaryExpression(); + Expression* parseUnaryExpression(); - Expression* parseExpression(); + Expression* parseExpression(); - Expression* parseCallExpression(); + Expression* parseCallExpression(); - IfStatement* parseIfStatement(); + IfStatement* parseIfStatement(); - ReturnStatement* parseReturnStatement(); + ReturnStatement* parseReturnStatement(); - ExpressionStatement* parseExpressionStatement(); + ExpressionStatement* parseExpressionStatement(); - Node* parseLetBodyElement(); + Node* parseLetBodyElement(); - LetDeclaration* parseLetDeclaration(); + LetDeclaration* parseLetDeclaration(); - Node* parseClassElement(); + Node* parseClassElement(); - ClassDeclaration* parseClassDeclaration(); + ClassDeclaration* parseClassDeclaration(); - InstanceDeclaration* parseInstanceDeclaration(); + InstanceDeclaration* parseInstanceDeclaration(); - RecordDeclaration* parseRecordDeclaration(); + RecordDeclaration* parseRecordDeclaration(); - VariantDeclaration* parseVariantDeclaration(); + VariantDeclaration* parseVariantDeclaration(); - Node* parseSourceElement(); + Node* parseSourceElement(); - SourceFile* parseSourceFile(); + SourceFile* parseSourceFile(); - }; +}; } diff --git a/bootstrap/cxx/include/bolt/Scanner.hpp b/bootstrap/cxx/include/bolt/Scanner.hpp index ba16e276c..abaaf9902 100644 --- a/bootstrap/cxx/include/bolt/Scanner.hpp +++ b/bootstrap/cxx/include/bolt/Scanner.hpp @@ -12,73 +12,73 @@ namespace bolt { - class Token; - class DiagnosticEngine; +class Token; +class DiagnosticEngine; - class Scanner : public BufferedStream { +class Scanner : public BufferedStream { - DiagnosticEngine& DE; + DiagnosticEngine& DE; - TextFile& File; + TextFile& File; - Stream& Chars; + Stream& Chars; - TextLoc CurrLoc; + TextLoc CurrLoc; - inline TextLoc getCurrentLoc() const { - return CurrLoc; + inline TextLoc getCurrentLoc() const { + return CurrLoc; + } + + inline Char getChar() { + auto Chr = Chars.get(); + if (Chr == '\n') { + CurrLoc.Line += 1; + CurrLoc.Column = 1; + } else { + CurrLoc.Column += 1; } + return Chr; + } - inline Char getChar() { - auto Chr = Chars.get(); - if (Chr == '\n') { - CurrLoc.Line += 1; - CurrLoc.Column = 1; - } else { - CurrLoc.Column += 1; - } - return Chr; - } + inline Char peekChar(std::size_t Offset = 0) { + return Chars.peek(Offset); + } - inline Char peekChar(std::size_t Offset = 0) { - return Chars.peek(Offset); - } + std::string scanIdentifier(); - std::string scanIdentifier(); + Token* readNullable(); - Token* readNullable(); +protected: - protected: + Token* read() override; - Token* read() override; +public: - public: + Scanner(DiagnosticEngine& DE, TextFile& File, Stream& Chars); - Scanner(DiagnosticEngine& DE, TextFile& File, Stream& Chars); +}; - }; +enum class FrameType { + Block, + LineFold, + Fallthrough, +}; - enum class FrameType { - Block, - LineFold, - Fallthrough, - }; +class Punctuator : public BufferedStream { - class Punctuator : public BufferedStream { + Stream& Tokens; - Stream& Tokens; + std::stack Frames; + std::stack Locations; - std::stack Frames; - std::stack Locations; +protected: - protected: + virtual Token* read() override; - virtual Token* read() override; +public: - public: + Punctuator(Stream& Tokens); - Punctuator(Stream& Tokens); - - }; +}; } diff --git a/bootstrap/cxx/include/bolt/Stream.hpp b/bootstrap/cxx/include/bolt/Stream.hpp index b0659fc45..58cbaec56 100644 --- a/bootstrap/cxx/include/bolt/Stream.hpp +++ b/bootstrap/cxx/include/bolt/Stream.hpp @@ -8,74 +8,74 @@ namespace bolt { - template - class Stream { - public: +template +class Stream { +public: - virtual T get() = 0; - virtual T peek(std::size_t Offset = 0) = 0; + virtual T get() = 0; + virtual T peek(std::size_t Offset = 0) = 0; - virtual ~Stream() {} + virtual ~Stream() {} - }; +}; - template - class VectorStream : public Stream { - public: +template +class VectorStream : public Stream { +public: - using value_type = T; + using value_type = T; - ContainerT& Data; - value_type Sentry; - std::size_t Offset; + ContainerT& Data; + value_type Sentry; + std::size_t Offset; - VectorStream(ContainerT& Data, value_type Sentry, std::size_t Offset = 0): - Data(Data), Sentry(Sentry), Offset(Offset) {} + VectorStream(ContainerT& Data, value_type Sentry, std::size_t Offset = 0): + Data(Data), Sentry(Sentry), Offset(Offset) {} - value_type get() override { - return Offset < Data.size() ? Data[Offset++] : Sentry; + value_type get() override { + return Offset < Data.size() ? Data[Offset++] : Sentry; + } + + value_type peek(std::size_t Offset2) override { + auto I = Offset + Offset2; + return I < Data.size() ? Data[I] : Sentry; + } + +}; + +template +class BufferedStream : public Stream { + + std::deque Buffer; + +protected: + + virtual T read() = 0; + +public: + + using value_type = T; + + value_type get() override { + if (Buffer.empty()) { + return read(); + } else { + auto Keep = Buffer.front(); + Buffer.pop_front(); + Keep->unref(); + return Keep; } + } - value_type peek(std::size_t Offset2) override { - auto I = Offset + Offset2; - return I < Data.size() ? Data[I] : Sentry; + value_type peek(std::size_t Offset = 0) override { + while (Buffer.size() <= Offset) { + auto Item = read(); + Item->ref(); + Buffer.push_back(Item); } + return Buffer[Offset]; + } - }; - - template - class BufferedStream : public Stream { - - std::deque Buffer; - - protected: - - virtual T read() = 0; - - public: - - using value_type = T; - - value_type get() override { - if (Buffer.empty()) { - return read(); - } else { - auto Keep = Buffer.front(); - Buffer.pop_front(); - Keep->unref(); - return Keep; - } - } - - value_type peek(std::size_t Offset = 0) override { - while (Buffer.size() <= Offset) { - auto Item = read(); - Item->ref(); - Buffer.push_back(Item); - } - return Buffer[Offset]; - } - - }; +}; } diff --git a/bootstrap/cxx/include/bolt/String.hpp b/bootstrap/cxx/include/bolt/String.hpp index e62376bc9..121ea1c91 100644 --- a/bootstrap/cxx/include/bolt/String.hpp +++ b/bootstrap/cxx/include/bolt/String.hpp @@ -7,9 +7,9 @@ namespace bolt { - using Char = char; +using Char = char; - using String = std::basic_string; +using String = std::basic_string; } diff --git a/bootstrap/cxx/include/bolt/Support/Graph.hpp b/bootstrap/cxx/include/bolt/Support/Graph.hpp index 4dc3659e8..e20022f67 100644 --- a/bootstrap/cxx/include/bolt/Support/Graph.hpp +++ b/bootstrap/cxx/include/bolt/Support/Graph.hpp @@ -11,135 +11,136 @@ namespace bolt { - template - class Graph { +template +class Graph { - std::unordered_set Vertices; - std::unordered_multimap Edges; + std::unordered_set Vertices; + std::unordered_multimap Edges; - public: +public: - void addVertex(V Vert) { - Vertices.emplace(Vert); - } + void addVertex(V Vert) { + Vertices.emplace(Vert); + } - void addEdge(V A, V B) { - Vertices.emplace(A); - Vertices.emplace(B); - Edges.emplace(A, B); - } + void addEdge(V A, V B) { + Vertices.emplace(A); + Vertices.emplace(B); + Edges.emplace(A, B); + } - std::size_t countVertices() const { - return Vertices.size(); - } + std::size_t countVertices() const { + return Vertices.size(); + } - bool hasVertex(const V& Vert) const { - return Vertices.count(Vert); - } + bool hasVertex(const V& Vert) const { + return Vertices.count(Vert); + } - bool hasEdge(const V& From) const { - return Edges.count(From); - } + bool hasEdge(const V& From) const { + return Edges.count(From); + } - bool hasEdge(const V& From, const V& To) const { - for (auto X: Edges.equal_range(From)) { - if (X == To) { - return true; - } + bool hasEdge(const V& From, const V& To) const { + for (auto X: Edges.equal_range(From)) { + if (X == To) { + return true; } } + } - auto getTargetVertices(const V& From) const { - return zen::make_iterator_range(Edges.equal_range(From)).map_second(); - } + auto getTargetVertices(const V& From) const { + return zen::make_iterator_range(Edges.equal_range(From)).map_second(); + } - auto getVertices() const { - return zen::make_iterator_range(Vertices); - } + auto getVertices() const { + return zen::make_iterator_range(Vertices); + } + +private: + + struct TarjanVertexData { + std::optional Index; + std::size_t LowLink; + bool OnStack = false; + }; + + class TarjanSolver { + public: + + std::vector> SCCs; private: - struct TarjanVertexData { - std::optional Index; - std::size_t LowLink; - bool OnStack = false; - }; + const Graph& G; + std::unordered_map Map; + std::size_t Index = 0; + std::stack Stack; - class TarjanSolver { - public: + TarjanVertexData& getData(V From) { + return Map.emplace(From, TarjanVertexData {}).first->second; + } - std::vector> SCCs; + void visitCycle(const V& From) { - private: + auto& DataFrom = getData(From); + DataFrom.Index = Index; + DataFrom.LowLink = Index; + Index++; + Stack.push(From); + DataFrom.OnStack = true; - const Graph& G; - std::unordered_map Map; - std::size_t Index = 0; - std::stack Stack; - - TarjanVertexData& getData(V From) { - return Map.emplace(From, TarjanVertexData {}).first->second; - } - - void visitCycle(const V& From) { - - auto& DataFrom = getData(From); - DataFrom.Index = Index; - DataFrom.LowLink = Index; - Index++; - Stack.push(From); - DataFrom.OnStack = true; - - for (const auto& To: G.getTargetVertices(From)) { - auto& DataTo = getData(To); - if (!DataTo.Index) { - visitCycle(To); - DataFrom.LowLink = std::min(DataFrom.LowLink, DataTo.LowLink); - } else if (DataTo.OnStack) { - DataFrom.LowLink = std::min(DataFrom.LowLink, *DataTo.Index); - } - } - - if (DataFrom.LowLink == DataFrom.Index) { - std::vector SCC; - for (;;) { - auto& X = Stack.top(); - Stack.pop(); - auto& DataX = getData(X); - DataX.OnStack = false; - SCC.push_back(X); - if (X == From) { - break; - } - } - SCCs.push_back(SCC); - } - - } - - public: - - TarjanSolver(const Graph& G): - G(G) {} - - void solve() { - for (auto From: G.Vertices) { - if (!Map.count(From)) { - visitCycle(From); - } + for (const auto& To: G.getTargetVertices(From)) { + auto& DataTo = getData(To); + if (!DataTo.Index) { + visitCycle(To); + DataFrom.LowLink = std::min(DataFrom.LowLink, DataTo.LowLink); + } else if (DataTo.OnStack) { + DataFrom.LowLink = std::min(DataFrom.LowLink, *DataTo.Index); } } - }; + if (DataFrom.LowLink == DataFrom.Index) { + std::vector SCC; + for (;;) { + auto& X = Stack.top(); + Stack.pop(); + auto& DataX = getData(X); + DataX.OnStack = false; + SCC.push_back(X); + if (X == From) { + break; + } + } + SCCs.push_back(SCC); + } - public: + } - std::vector> strongconnect() const { - TarjanSolver S { *this }; - S.solve(); - return S.SCCs; + public: + + TarjanSolver(const Graph& G): + G(G) {} + + void solve() { + for (auto From: G.Vertices) { + if (!Map.count(From)) { + visitCycle(From); + } + } } }; + +public: + + std::vector> strongconnect() const { + TarjanSolver S { *this }; + S.solve(); + return S.SCCs; + } + +}; + } diff --git a/bootstrap/cxx/include/bolt/Type.hpp b/bootstrap/cxx/include/bolt/Type.hpp index 091cbff16..e51ac3eb4 100644 --- a/bootstrap/cxx/include/bolt/Type.hpp +++ b/bootstrap/cxx/include/bolt/Type.hpp @@ -15,741 +15,741 @@ namespace bolt { - class Type; - class TCon; +class Type; +class TCon; + +using TypeclassId = ByteString; + +using TypeclassContext = std::unordered_set; + +struct TypeclassSignature { using TypeclassId = ByteString; + TypeclassId Id; + std::vector Params; - using TypeclassContext = std::unordered_set; + bool operator<(const TypeclassSignature& Other) const; + bool operator==(const TypeclassSignature& Other) const; - struct TypeclassSignature { +}; - using TypeclassId = ByteString; - TypeclassId Id; - std::vector Params; +struct TypeSig { + Type* Orig; + Type* Op; + std::vector Args; +}; - bool operator<(const TypeclassSignature& Other) const; - bool operator==(const TypeclassSignature& Other) const; +enum class TypeIndexKind { + AppOpType, + AppArgType, + ArrowParamType, + ArrowReturnType, + TupleElement, + FieldType, + FieldRestType, + PresentType, + End, +}; +class TypeIndex { +protected: + + friend class Type; + friend class TypeIterator; + + TypeIndexKind Kind; + + union { + std::size_t I; }; - struct TypeSig { - Type* Orig; - Type* Op; - std::vector Args; + TypeIndex(TypeIndexKind Kind): + Kind(Kind) {} + + TypeIndex(TypeIndexKind Kind, std::size_t I): + Kind(Kind), I(I) {} + +public: + + bool operator==(const TypeIndex& Other) const noexcept; + + void advance(const Type* Ty); + + static TypeIndex forFieldType() { + return { TypeIndexKind::FieldType }; + } + + static TypeIndex forFieldRest() { + return { TypeIndexKind::FieldRestType }; + } + + static TypeIndex forArrowParamType() { + return { TypeIndexKind::ArrowParamType }; + } + + static TypeIndex forArrowReturnType() { + return { TypeIndexKind::ArrowReturnType }; + } + + static TypeIndex forTupleElement(std::size_t I) { + return { TypeIndexKind::TupleElement, I }; + } + + static TypeIndex forAppOpType() { + return { TypeIndexKind::AppOpType }; + } + + static TypeIndex forAppArgType() { + return { TypeIndexKind::AppArgType }; + } + + static TypeIndex forPresentType() { + return { TypeIndexKind::PresentType }; + } + +}; + +class TypeIterator { + + friend class Type; + + Type* Ty; + TypeIndex Index; + + TypeIterator(Type* Ty, TypeIndex Index): + Ty(Ty), Index(Index) {} + +public: + + TypeIterator& operator++() noexcept { + Index.advance(Ty); + return *this; + } + + bool operator==(const TypeIterator& Other) const noexcept { + return Ty == Other.Ty && Index == Other.Index; + } + + Type* operator*() { + return Ty; + } + + TypeIndex getIndex() const noexcept { + return Index; + } + +}; + +using TypePath = std::vector; + +using TVSub = std::unordered_map; +using TVSet = std::unordered_set; + +enum class TypeKind : unsigned char { + Var, + Con, + App, + Arrow, + Tuple, + Field, + Nil, + Absent, + Present, +}; + +class Type; + +struct TCon { + size_t Id; + ByteString DisplayName; + + bool operator==(const TCon& Other) const; + +}; + +struct TApp { + Type* Op; + Type* Arg; + + bool operator==(const TApp& Other) const; + +}; + +enum class VarKind { + Rigid, + Unification, +}; + +struct TVar { + VarKind VK; + size_t Id; + TypeclassContext Context; + std::optional Name; + std::optional Provided; + + VarKind getKind() const { + return VK; + } + + bool isUni() const { + return VK == VarKind::Unification; + } + + bool isRigid() const { + return VK == VarKind::Rigid; + } + + bool operator==(const TVar& Other) const; + +}; + +struct TArrow { + Type* ParamType; + Type* ReturnType; + + bool operator==(const TArrow& Other) const; + +}; + +struct TTuple { + std::vector ElementTypes; + + bool operator==(const TTuple& Other) const; + +}; + +struct TNil { + bool operator==(const TNil& Other) const; +}; + +struct TField { + ByteString Name; + Type* Ty; + Type* RestTy; + bool operator==(const TField& Other) const; +}; + +struct TAbsent { + bool operator==(const TAbsent& Other) const; +}; + +struct TPresent { + Type* Ty; + bool operator==(const TPresent& Other) const; +}; + +struct Type { + + TypeKind Kind; + + Type* Parent = this; + + union { + TCon Con; + TApp App; + TVar Var; + TArrow Arrow; + TTuple Tuple; + TNil Nil; + TField Field; + TAbsent Absent; + TPresent Present; }; - enum class TypeIndexKind { - AppOpType, - AppArgType, - ArrowParamType, - ArrowReturnType, - TupleElement, - FieldType, - FieldRestType, - PresentType, - End, - }; + Type(TCon&& Con): + Kind(TypeKind::Con), Con(std::move(Con)) {}; - class TypeIndex { - protected: + Type(TApp&& App): + Kind(TypeKind::App), App(std::move(App)) {}; - friend class Type; - friend class TypeIterator; + Type(TVar&& Var): + Kind(TypeKind::Var), Var(std::move(Var)) {}; - TypeIndexKind Kind; + Type(TArrow&& Arrow): + Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {}; - union { - std::size_t I; - }; + Type(TTuple&& Tuple): + Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {}; - TypeIndex(TypeIndexKind Kind): - Kind(Kind) {} + Type(TNil&& Nil): + Kind(TypeKind::Nil), Nil(std::move(Nil)) {}; - TypeIndex(TypeIndexKind Kind, std::size_t I): - Kind(Kind), I(I) {} + Type(TField&& Field): + Kind(TypeKind::Field), Field(std::move(Field)) {}; - public: + Type(TAbsent&& Absent): + Kind(TypeKind::Absent), Absent(std::move(Absent)) {}; - bool operator==(const TypeIndex& Other) const noexcept; + Type(TPresent&& Present): + Kind(TypeKind::Present), Present(std::move(Present)) {}; - void advance(const Type* Ty); - - static TypeIndex forFieldType() { - return { TypeIndexKind::FieldType }; + Type(const Type& Other): Kind(Other.Kind) { + switch (Kind) { + case TypeKind::Con: + new (&Con)TCon(Other.Con); + break; + case TypeKind::App: + new (&App)TApp(Other.App); + break; + case TypeKind::Var: + new (&Var)TVar(Other.Var); + break; + case TypeKind::Arrow: + new (&Arrow)TArrow(Other.Arrow); + break; + case TypeKind::Tuple: + new (&Tuple)TTuple(Other.Tuple); + break; + case TypeKind::Nil: + new (&Nil)TNil(Other.Nil); + break; + case TypeKind::Field: + new (&Field)TField(Other.Field); + break; + case TypeKind::Absent: + new (&Absent)TAbsent(Other.Absent); + break; + case TypeKind::Present: + new (&Present)TPresent(Other.Present); + break; } + } - static TypeIndex forFieldRest() { - return { TypeIndexKind::FieldRestType }; + Type(Type&& Other): Kind(std::move(Other.Kind)) { + switch (Kind) { + case TypeKind::Con: + new (&Con)TCon(std::move(Other.Con)); + break; + case TypeKind::App: + new (&App)TApp(std::move(Other.App)); + break; + case TypeKind::Var: + new (&Var)TVar(std::move(Other.Var)); + break; + case TypeKind::Arrow: + new (&Arrow)TArrow(std::move(Other.Arrow)); + break; + case TypeKind::Tuple: + new (&Tuple)TTuple(std::move(Other.Tuple)); + break; + case TypeKind::Nil: + new (&Nil)TNil(std::move(Other.Nil)); + break; + case TypeKind::Field: + new (&Field)TField(std::move(Other.Field)); + break; + case TypeKind::Absent: + new (&Absent)TAbsent(std::move(Other.Absent)); + break; + case TypeKind::Present: + new (&Present)TPresent(std::move(Other.Present)); + break; } + } - static TypeIndex forArrowParamType() { - return { TypeIndexKind::ArrowParamType }; + TypeKind getKind() const { + return Kind; + } + + bool isVarRigid() const { + return Kind == TypeKind::Var + && asVar().getKind() == VarKind::Rigid; + } + + bool isVar() const { + return Kind == TypeKind::Var; + } + + TVar& asVar() { + ZEN_ASSERT(Kind == TypeKind::Var); + return Var; + } + + const TVar& asVar() const { + ZEN_ASSERT(Kind == TypeKind::Var); + return Var; + } + + bool isApp() const { + return Kind == TypeKind::App; + } + + TApp& asApp() { + ZEN_ASSERT(Kind == TypeKind::App); + return App; + } + + const TApp& asApp() const { + ZEN_ASSERT(Kind == TypeKind::App); + return App; + } + + bool isCon() const { + return Kind == TypeKind::Con; + } + + TCon& asCon() { + ZEN_ASSERT(Kind == TypeKind::Con); + return Con; + } + + const TCon& asCon() const { + ZEN_ASSERT(Kind == TypeKind::Con); + return Con; + } + + bool isArrow() const { + return Kind == TypeKind::Arrow; + } + + TArrow& asArrow() { + ZEN_ASSERT(Kind == TypeKind::Arrow); + return Arrow; + } + + const TArrow& asArrow() const { + ZEN_ASSERT(Kind == TypeKind::Arrow); + return Arrow; + } + + bool isTuple() const { + return Kind == TypeKind::Tuple; + } + + TTuple& asTuple() { + ZEN_ASSERT(Kind == TypeKind::Tuple); + return Tuple; + } + + const TTuple& asTuple() const { + ZEN_ASSERT(Kind == TypeKind::Tuple); + return Tuple; + } + + bool isField() const { + return Kind == TypeKind::Field; + } + + TField& asField() { + ZEN_ASSERT(Kind == TypeKind::Field); + return Field; + } + + const TField& asField() const { + ZEN_ASSERT(Kind == TypeKind::Field); + return Field; + } + + bool isAbsent() const { + return Kind == TypeKind::Absent; + } + + TAbsent& asAbsent() { + ZEN_ASSERT(Kind == TypeKind::Absent); + return Absent; + } + const TAbsent& asAbsent() const { + ZEN_ASSERT(Kind == TypeKind::Absent); + return Absent; + } + + bool isPresent() const { + return Kind == TypeKind::Present; + } + + TPresent& asPresent() { + ZEN_ASSERT(Kind == TypeKind::Present); + return Present; + } + const TPresent& asPresent() const { + ZEN_ASSERT(Kind == TypeKind::Present); + return Present; + } + + bool isNil() const { + return Kind == TypeKind::Nil; + } + + TNil& asNil() { + ZEN_ASSERT(Kind == TypeKind::Nil); + return Nil; + } + const TNil& asNil() const { + ZEN_ASSERT(Kind == TypeKind::Nil); + return Nil; + } + + Type* rewrite(std::function Fn, bool Recursive = true); + + Type* resolve(const TypeIndex& Index) const noexcept; + + Type* resolve(const TypePath& Path) noexcept { + Type* Ty = this; + for (auto El: Path) { + Ty = Ty->resolve(El); } + return Ty; + } - static TypeIndex forArrowReturnType() { - return { TypeIndexKind::ArrowReturnType }; + void set(Type* Ty) { + auto Root = find(); + // It is not possible to set a solution twice. + if (isVar()) { + ZEN_ASSERT(Root->isVar()); } + Root->Parent = Ty; + } - static TypeIndex forTupleElement(std::size_t I) { - return { TypeIndexKind::TupleElement, I }; - } - - static TypeIndex forAppOpType() { - return { TypeIndexKind::AppOpType }; - } - - static TypeIndex forAppArgType() { - return { TypeIndexKind::AppArgType }; - } - - static TypeIndex forPresentType() { - return { TypeIndexKind::PresentType }; - } - - }; - - class TypeIterator { - - friend class Type; - - Type* Ty; - TypeIndex Index; - - TypeIterator(Type* Ty, TypeIndex Index): - Ty(Ty), Index(Index) {} - - public: - - TypeIterator& operator++() noexcept { - Index.advance(Ty); - return *this; - } - - bool operator==(const TypeIterator& Other) const noexcept { - return Ty == Other.Ty && Index == Other.Index; - } - - Type* operator*() { - return Ty; - } - - TypeIndex getIndex() const noexcept { - return Index; - } - - }; - - using TypePath = std::vector; - - using TVSub = std::unordered_map; - using TVSet = std::unordered_set; - - enum class TypeKind : unsigned char { - Var, - Con, - App, - Arrow, - Tuple, - Field, - Nil, - Absent, - Present, - }; - - class Type; - - struct TCon { - size_t Id; - ByteString DisplayName; - - bool operator==(const TCon& Other) const; - - }; - - struct TApp { - Type* Op; - Type* Arg; - - bool operator==(const TApp& Other) const; - - }; - - enum class VarKind { - Rigid, - Unification, - }; - - struct TVar { - VarKind VK; - size_t Id; - TypeclassContext Context; - std::optional Name; - std::optional Provided; - - VarKind getKind() const { - return VK; - } - - bool isUni() const { - return VK == VarKind::Unification; - } - - bool isRigid() const { - return VK == VarKind::Rigid; - } - - bool operator==(const TVar& Other) const; - - }; - - struct TArrow { - Type* ParamType; - Type* ReturnType; - - bool operator==(const TArrow& Other) const; - - }; - - struct TTuple { - std::vector ElementTypes; - - bool operator==(const TTuple& Other) const; - - }; - - struct TNil { - bool operator==(const TNil& Other) const; - }; - - struct TField { - ByteString Name; - Type* Ty; - Type* RestTy; - bool operator==(const TField& Other) const; - }; - - struct TAbsent { - bool operator==(const TAbsent& Other) const; - }; - - struct TPresent { - Type* Ty; - bool operator==(const TPresent& Other) const; - }; - - struct Type { - - TypeKind Kind; - - Type* Parent = this; - - union { - TCon Con; - TApp App; - TVar Var; - TArrow Arrow; - TTuple Tuple; - TNil Nil; - TField Field; - TAbsent Absent; - TPresent Present; - }; - - Type(TCon&& Con): - Kind(TypeKind::Con), Con(std::move(Con)) {}; - - Type(TApp&& App): - Kind(TypeKind::App), App(std::move(App)) {}; - - Type(TVar&& Var): - Kind(TypeKind::Var), Var(std::move(Var)) {}; - - Type(TArrow&& Arrow): - Kind(TypeKind::Arrow), Arrow(std::move(Arrow)) {}; - - Type(TTuple&& Tuple): - Kind(TypeKind::Tuple), Tuple(std::move(Tuple)) {}; - - Type(TNil&& Nil): - Kind(TypeKind::Nil), Nil(std::move(Nil)) {}; - - Type(TField&& Field): - Kind(TypeKind::Field), Field(std::move(Field)) {}; - - Type(TAbsent&& Absent): - Kind(TypeKind::Absent), Absent(std::move(Absent)) {}; - - Type(TPresent&& Present): - Kind(TypeKind::Present), Present(std::move(Present)) {}; - - Type(const Type& Other): Kind(Other.Kind) { - switch (Kind) { - case TypeKind::Con: - new (&Con)TCon(Other.Con); - break; - case TypeKind::App: - new (&App)TApp(Other.App); - break; - case TypeKind::Var: - new (&Var)TVar(Other.Var); - break; - case TypeKind::Arrow: - new (&Arrow)TArrow(Other.Arrow); - break; - case TypeKind::Tuple: - new (&Tuple)TTuple(Other.Tuple); - break; - case TypeKind::Nil: - new (&Nil)TNil(Other.Nil); - break; - case TypeKind::Field: - new (&Field)TField(Other.Field); - break; - case TypeKind::Absent: - new (&Absent)TAbsent(Other.Absent); - break; - case TypeKind::Present: - new (&Present)TPresent(Other.Present); - break; + Type* find() const { + Type* Curr = const_cast(this); + for (;;) { + auto Keep = Curr->Parent; + if (Keep == Curr) { + return Keep; } + Curr->Parent = Keep->Parent; + Curr = Keep; } + } - Type(Type&& Other): Kind(std::move(Other.Kind)) { - switch (Kind) { - case TypeKind::Con: - new (&Con)TCon(std::move(Other.Con)); - break; - case TypeKind::App: - new (&App)TApp(std::move(Other.App)); - break; - case TypeKind::Var: - new (&Var)TVar(std::move(Other.Var)); - break; - case TypeKind::Arrow: - new (&Arrow)TArrow(std::move(Other.Arrow)); - break; - case TypeKind::Tuple: - new (&Tuple)TTuple(std::move(Other.Tuple)); - break; - case TypeKind::Nil: - new (&Nil)TNil(std::move(Other.Nil)); - break; - case TypeKind::Field: - new (&Field)TField(std::move(Other.Field)); - break; - case TypeKind::Absent: - new (&Absent)TAbsent(std::move(Other.Absent)); - break; - case TypeKind::Present: - new (&Present)TPresent(std::move(Other.Present)); - break; + bool operator==(const Type& Other) const; + + void destroy() { + switch (Kind) { + case TypeKind::Con: + App.~TApp(); + break; + case TypeKind::App: + App.~TApp(); + break; + case TypeKind::Var: + Var.~TVar(); + break; + case TypeKind::Arrow: + Arrow.~TArrow(); + break; + case TypeKind::Tuple: + Tuple.~TTuple(); + break; + case TypeKind::Nil: + Nil.~TNil(); + break; + case TypeKind::Field: + Field.~TField(); + break; + case TypeKind::Absent: + Absent.~TAbsent(); + break; + case TypeKind::Present: + Present.~TPresent(); + break; + } + } + + Type& operator=(Type& Other) { + destroy(); + Kind = Other.Kind; + switch (Kind) { + case TypeKind::Con: + App = Other.App; + break; + case TypeKind::App: + App = Other.App; + break; + case TypeKind::Var: + Var = Other.Var; + break; + case TypeKind::Arrow: + Arrow = Other.Arrow; + break; + case TypeKind::Tuple: + Tuple = Other.Tuple; + break; + case TypeKind::Nil: + Nil = Other.Nil; + break; + case TypeKind::Field: + Field = Other.Field; + break; + case TypeKind::Absent: + Absent = Other.Absent; + break; + case TypeKind::Present: + Present = Other.Present; + break; + } + return *this; + } + + bool hasTypeVar(Type* TV) const; + + TypeIterator begin(); + TypeIterator end(); + + TypeIndex getStartIndex() const; + TypeIndex getEndIndex() const; + + Type* substitute(const TVSub& Sub); + + void visitEachChild(std::function Proc); + + TVSet getTypeVars(); + + ~Type() { + destroy(); + } + + static Type* buildArrow(std::vector ParamTypes, Type* ReturnType) { + Type* Curr = ReturnType; + for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) { + Curr = new Type(TArrow(*Iter, Curr)); + } + return Curr; + } + +}; + +template +class TypeVisitorBase { +protected: + + template + using C = std::conditional::type; + + virtual void enterType(C* Ty) {} + virtual void exitType(C* Ty) {} + + // virtual void visitType(C* Ty) { + // visitEachChild(Ty); + // } + + virtual void visitVarType(C& Ty) { + } + + virtual void visitAppType(C& Ty) { + visit(Ty.Op); + visit(Ty.Arg); + } + + virtual void visitPresentType(C& Ty) { + visit(Ty.Ty); + } + + virtual void visitConType(C& Ty) { + } + + virtual void visitArrowType(C& Ty) { + visit(Ty.ParamType); + visit(Ty.ReturnType); + } + + virtual void visitTupleType(C& Ty) { + for (auto ElTy: Ty.ElementTypes) { + visit(ElTy); + } + } + + virtual void visitAbsentType(C& Ty) { + } + + virtual void visitFieldType(C& Ty) { + visit(Ty.Ty); + visit(Ty.RestTy); + } + + virtual void visitNilType(C& Ty) { + } + +public: + + void visitEachChild(C* Ty) { + switch (Ty->getKind()) { + case TypeKind::Var: + case TypeKind::Absent: + case TypeKind::Nil: + case TypeKind::Con: + break; + case TypeKind::Arrow: + { + auto& Arrow = Ty->asArrow(); + visit(Arrow->ParamType); + visit(Arrow->ReturnType); + break; } - } - - TypeKind getKind() const { - return Kind; - } - - bool isVarRigid() const { - return Kind == TypeKind::Var - && asVar().getKind() == VarKind::Rigid; - } - - bool isVar() const { - return Kind == TypeKind::Var; - } - - TVar& asVar() { - ZEN_ASSERT(Kind == TypeKind::Var); - return Var; - } - - const TVar& asVar() const { - ZEN_ASSERT(Kind == TypeKind::Var); - return Var; - } - - bool isApp() const { - return Kind == TypeKind::App; - } - - TApp& asApp() { - ZEN_ASSERT(Kind == TypeKind::App); - return App; - } - - const TApp& asApp() const { - ZEN_ASSERT(Kind == TypeKind::App); - return App; - } - - bool isCon() const { - return Kind == TypeKind::Con; - } - - TCon& asCon() { - ZEN_ASSERT(Kind == TypeKind::Con); - return Con; - } - - const TCon& asCon() const { - ZEN_ASSERT(Kind == TypeKind::Con); - return Con; - } - - bool isArrow() const { - return Kind == TypeKind::Arrow; - } - - TArrow& asArrow() { - ZEN_ASSERT(Kind == TypeKind::Arrow); - return Arrow; - } - - const TArrow& asArrow() const { - ZEN_ASSERT(Kind == TypeKind::Arrow); - return Arrow; - } - - bool isTuple() const { - return Kind == TypeKind::Tuple; - } - - TTuple& asTuple() { - ZEN_ASSERT(Kind == TypeKind::Tuple); - return Tuple; - } - - const TTuple& asTuple() const { - ZEN_ASSERT(Kind == TypeKind::Tuple); - return Tuple; - } - - bool isField() const { - return Kind == TypeKind::Field; - } - - TField& asField() { - ZEN_ASSERT(Kind == TypeKind::Field); - return Field; - } - - const TField& asField() const { - ZEN_ASSERT(Kind == TypeKind::Field); - return Field; - } - - bool isAbsent() const { - return Kind == TypeKind::Absent; - } - - TAbsent& asAbsent() { - ZEN_ASSERT(Kind == TypeKind::Absent); - return Absent; - } - const TAbsent& asAbsent() const { - ZEN_ASSERT(Kind == TypeKind::Absent); - return Absent; - } - - bool isPresent() const { - return Kind == TypeKind::Present; - } - - TPresent& asPresent() { - ZEN_ASSERT(Kind == TypeKind::Present); - return Present; - } - const TPresent& asPresent() const { - ZEN_ASSERT(Kind == TypeKind::Present); - return Present; - } - - bool isNil() const { - return Kind == TypeKind::Nil; - } - - TNil& asNil() { - ZEN_ASSERT(Kind == TypeKind::Nil); - return Nil; - } - const TNil& asNil() const { - ZEN_ASSERT(Kind == TypeKind::Nil); - return Nil; - } - - Type* rewrite(std::function Fn, bool Recursive = true); - - Type* resolve(const TypeIndex& Index) const noexcept; - - Type* resolve(const TypePath& Path) noexcept { - Type* Ty = this; - for (auto El: Path) { - Ty = Ty->resolve(El); - } - return Ty; - } - - void set(Type* Ty) { - auto Root = find(); - // It is not possible to set a solution twice. - if (isVar()) { - ZEN_ASSERT(Root->isVar()); - } - Root->Parent = Ty; - } - - Type* find() const { - Type* Curr = const_cast(this); - for (;;) { - auto Keep = Curr->Parent; - if (Keep == Curr) { - return Keep; + case TypeKind::Tuple: + { + auto& Tuple = Ty->asTuple(); + for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) { + visit(Tuple->ElementTypes[I]); } - Curr->Parent = Keep->Parent; - Curr = Keep; + break; + } + case TypeKind::App: + { + auto& App = Ty->asApp(); + visit(App->Op); + visit(App->Arg); + break; + } + case TypeKind::Field: + { + auto& Field = Ty->asField(); + visit(Field->Ty); + visit(Field->RestTy); + break; + } + case TypeKind::Present: + { + auto& Present = Ty->asPresent(); + visit(Present->Ty); + break; } } + } - bool operator==(const Type& Other) const; + void visit(C* Ty) { - void destroy() { - switch (Kind) { - case TypeKind::Con: - App.~TApp(); - break; - case TypeKind::App: - App.~TApp(); - break; - case TypeKind::Var: - Var.~TVar(); - break; - case TypeKind::Arrow: - Arrow.~TArrow(); - break; - case TypeKind::Tuple: - Tuple.~TTuple(); - break; - case TypeKind::Nil: - Nil.~TNil(); - break; - case TypeKind::Field: - Field.~TField(); - break; - case TypeKind::Absent: - Absent.~TAbsent(); - break; - case TypeKind::Present: - Present.~TPresent(); - break; - } + // Always look at the most solved solution + Ty = Ty->find(); + + enterType(Ty); + switch (Ty->getKind()) { + case TypeKind::Present: + visitPresentType(Ty->asPresent()); + break; + case TypeKind::Absent: + visitAbsentType(Ty->asAbsent()); + break; + case TypeKind::Nil: + visitNilType(Ty->asNil()); + break; + case TypeKind::Field: + visitFieldType(Ty->asField()); + break; + case TypeKind::Con: + visitConType(Ty->asCon()); + break; + case TypeKind::Arrow: + visitArrowType(Ty->asArrow()); + break; + case TypeKind::Var: + visitVarType(Ty->asVar()); + break; + case TypeKind::Tuple: + visitTupleType(Ty->asTuple()); + break; + case TypeKind::App: + visitAppType(Ty->asApp()); + break; } + exitType(Ty); + } - Type& operator=(Type& Other) { - destroy(); - Kind = Other.Kind; - switch (Kind) { - case TypeKind::Con: - App = Other.App; - break; - case TypeKind::App: - App = Other.App; - break; - case TypeKind::Var: - Var = Other.Var; - break; - case TypeKind::Arrow: - Arrow = Other.Arrow; - break; - case TypeKind::Tuple: - Tuple = Other.Tuple; - break; - case TypeKind::Nil: - Nil = Other.Nil; - break; - case TypeKind::Field: - Field = Other.Field; - break; - case TypeKind::Absent: - Absent = Other.Absent; - break; - case TypeKind::Present: - Present = Other.Present; - break; - } - return *this; - } + virtual ~TypeVisitorBase() {} - bool hasTypeVar(Type* TV) const; +}; - TypeIterator begin(); - TypeIterator end(); - - TypeIndex getStartIndex() const; - TypeIndex getEndIndex() const; - - Type* substitute(const TVSub& Sub); - - void visitEachChild(std::function Proc); - - TVSet getTypeVars(); - - ~Type() { - destroy(); - } - - static Type* buildArrow(std::vector ParamTypes, Type* ReturnType) { - Type* Curr = ReturnType; - for (auto Iter = ParamTypes.rbegin(); Iter != ParamTypes.rend(); ++Iter) { - Curr = new Type(TArrow(*Iter, Curr)); - } - return Curr; - } - - }; - - template - class TypeVisitorBase { - protected: - - template - using C = std::conditional::type; - - virtual void enterType(C* Ty) {} - virtual void exitType(C* Ty) {} - - // virtual void visitType(C* Ty) { - // visitEachChild(Ty); - // } - - virtual void visitVarType(C& Ty) { - } - - virtual void visitAppType(C& Ty) { - visit(Ty.Op); - visit(Ty.Arg); - } - - virtual void visitPresentType(C& Ty) { - visit(Ty.Ty); - } - - virtual void visitConType(C& Ty) { - } - - virtual void visitArrowType(C& Ty) { - visit(Ty.ParamType); - visit(Ty.ReturnType); - } - - virtual void visitTupleType(C& Ty) { - for (auto ElTy: Ty.ElementTypes) { - visit(ElTy); - } - } - - virtual void visitAbsentType(C& Ty) { - } - - virtual void visitFieldType(C& Ty) { - visit(Ty.Ty); - visit(Ty.RestTy); - } - - virtual void visitNilType(C& Ty) { - } - - public: - - void visitEachChild(C* Ty) { - switch (Ty->getKind()) { - case TypeKind::Var: - case TypeKind::Absent: - case TypeKind::Nil: - case TypeKind::Con: - break; - case TypeKind::Arrow: - { - auto& Arrow = Ty->asArrow(); - visit(Arrow->ParamType); - visit(Arrow->ReturnType); - break; - } - case TypeKind::Tuple: - { - auto& Tuple = Ty->asTuple(); - for (auto I = 0; I < Tuple->ElementTypes.size(); ++I) { - visit(Tuple->ElementTypes[I]); - } - break; - } - case TypeKind::App: - { - auto& App = Ty->asApp(); - visit(App->Op); - visit(App->Arg); - break; - } - case TypeKind::Field: - { - auto& Field = Ty->asField(); - visit(Field->Ty); - visit(Field->RestTy); - break; - } - case TypeKind::Present: - { - auto& Present = Ty->asPresent(); - visit(Present->Ty); - break; - } - } - } - - void visit(C* Ty) { - - // Always look at the most solved solution - Ty = Ty->find(); - - enterType(Ty); - switch (Ty->getKind()) { - case TypeKind::Present: - visitPresentType(Ty->asPresent()); - break; - case TypeKind::Absent: - visitAbsentType(Ty->asAbsent()); - break; - case TypeKind::Nil: - visitNilType(Ty->asNil()); - break; - case TypeKind::Field: - visitFieldType(Ty->asField()); - break; - case TypeKind::Con: - visitConType(Ty->asCon()); - break; - case TypeKind::Arrow: - visitArrowType(Ty->asArrow()); - break; - case TypeKind::Var: - visitVarType(Ty->asVar()); - break; - case TypeKind::Tuple: - visitTupleType(Ty->asTuple()); - break; - case TypeKind::App: - visitAppType(Ty->asApp()); - break; - } - exitType(Ty); - } - - virtual ~TypeVisitorBase() {} - - }; - - using TypeVisitor = TypeVisitorBase; - using ConstTypeVisitor = TypeVisitorBase; +using TypeVisitor = TypeVisitorBase; +using ConstTypeVisitor = TypeVisitorBase; } diff --git a/bootstrap/cxx/src/CST.cc b/bootstrap/cxx/src/CST.cc index 82643d7b9..9623124a7 100644 --- a/bootstrap/cxx/src/CST.cc +++ b/bootstrap/cxx/src/CST.cc @@ -6,1125 +6,1125 @@ namespace bolt { - TextFile::TextFile(ByteString Path, ByteString Text): - Path(Path), Text(Text) { - LineOffsets.push_back(0); - for (size_t I = 0; I < Text.size(); I++) { - auto Chr = Text[I]; - if (Chr == '\n') { - LineOffsets.push_back(I+1); - } - } - LineOffsets.push_back(Text.size()); - } - - size_t TextFile::getLineCount() const { - return LineOffsets.size()-1; - } - - size_t TextFile::getStartOffsetOfLine(size_t Line) const { - ZEN_ASSERT(Line-1 < LineOffsets.size()); - return LineOffsets[Line-1]; - } - - size_t TextFile::getEndOffsetOfLine(size_t Line) const { - ZEN_ASSERT(Line <= LineOffsets.size()); - if (Line == LineOffsets.size()) { - return Text.size(); - } - return LineOffsets[Line]; - } - - size_t TextFile::getLine(size_t Offset) const { - ZEN_ASSERT(Offset < Text.size()); - for (size_t I = 0; I < LineOffsets.size(); ++I) { - if (LineOffsets[I] > Offset) { - return I; +TextFile::TextFile(ByteString Path, ByteString Text): + Path(Path), Text(Text) { + LineOffsets.push_back(0); + for (size_t I = 0; I < Text.size(); I++) { + auto Chr = Text[I]; + if (Chr == '\n') { + LineOffsets.push_back(I+1); } } - ZEN_UNREACHABLE + LineOffsets.push_back(Text.size()); } - size_t TextFile::getColumn(size_t Offset) const { - auto Line = getLine(Offset); - auto StartOffset = getStartOffsetOfLine(Line); - return Offset - StartOffset + 1 ; +size_t TextFile::getLineCount() const { + return LineOffsets.size()-1; +} + +size_t TextFile::getStartOffsetOfLine(size_t Line) const { + ZEN_ASSERT(Line-1 < LineOffsets.size()); + return LineOffsets[Line-1]; +} + +size_t TextFile::getEndOffsetOfLine(size_t Line) const { + ZEN_ASSERT(Line <= LineOffsets.size()); + if (Line == LineOffsets.size()) { + return Text.size(); } + return LineOffsets[Line]; +} - ByteString TextFile::getPath() const { - return Path; - } - - ByteString TextFile::getText() const { - return Text; - } - - Scope::Scope(Node* Source): - Source(Source) { - scan(Source); - } - - void Scope::addSymbol(ByteString Name, Node* Decl, SymbolKind Kind) { - Mapping.emplace(Name, std::make_tuple(Decl, Kind)); - } - - void Scope::scan(Node* X) { - switch (X->getKind()) { - case NodeKind::SourceFile: - { - auto File = static_cast(X); - for (auto Element: File->Elements) { - scanChild(Element); - } - break; - } - case NodeKind::MatchCase: - { - auto Case = static_cast(X); - visitPattern(Case->Pattern, Case); - break; - } - case NodeKind::LetDeclaration: - { - auto Decl = static_cast(X); - ZEN_ASSERT(Decl->isFunction()); - for (auto Param: Decl->Params) { - visitPattern(Param->Pattern, Param); - } - if (Decl->Body) { - scanChild(Decl->Body); - } - break; - } - default: - ZEN_UNREACHABLE +size_t TextFile::getLine(size_t Offset) const { + ZEN_ASSERT(Offset < Text.size()); + for (size_t I = 0; I < LineOffsets.size(); ++I) { + if (LineOffsets[I] > Offset) { + return I; } } + ZEN_UNREACHABLE +} - void Scope::scanChild(Node* X) { - switch (X->getKind()) { - case NodeKind::LetExprBody: - case NodeKind::ExpressionStatement: - case NodeKind::IfStatement: - case NodeKind::ReturnStatement: - break; - case NodeKind::LetBlockBody: - { - auto Block = static_cast(X); - for (auto Element: Block->Elements) { - scanChild(Element); - } - break; +size_t TextFile::getColumn(size_t Offset) const { + auto Line = getLine(Offset); + auto StartOffset = getStartOffsetOfLine(Line); + return Offset - StartOffset + 1 ; +} + +ByteString TextFile::getPath() const { + return Path; +} + +ByteString TextFile::getText() const { + return Text; +} + +Scope::Scope(Node* Source): + Source(Source) { + scan(Source); + } + +void Scope::addSymbol(ByteString Name, Node* Decl, SymbolKind Kind) { + Mapping.emplace(Name, std::make_tuple(Decl, Kind)); +} + +void Scope::scan(Node* X) { + switch (X->getKind()) { + case NodeKind::SourceFile: + { + auto File = static_cast(X); + for (auto Element: File->Elements) { + scanChild(Element); } - case NodeKind::InstanceDeclaration: - // We ignore let-declarations inside instance-declarations for now - break; - case NodeKind::ClassDeclaration: - { - auto Decl = static_cast(X); - addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Class); - for (auto Element: Decl->Elements) { - scanChild(Element); - } - break; + break; + } + case NodeKind::MatchCase: + { + auto Case = static_cast(X); + visitPattern(Case->Pattern, Case); + break; + } + case NodeKind::LetDeclaration: + { + auto Decl = static_cast(X); + ZEN_ASSERT(Decl->isFunction()); + for (auto Param: Decl->Params) { + visitPattern(Param->Pattern, Param); } - case NodeKind::LetDeclaration: - { - auto Decl = static_cast(X); - // No matter if it is a function or a variable, by visiting the pattern - // we add all relevant bindings to the current scope. - visitPattern(Decl->Pattern, Decl); - break; + if (Decl->Body) { + scanChild(Decl->Body); } - case NodeKind::RecordDeclaration: - { - auto Decl = static_cast(X); - addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type); - break; + break; + } + default: + ZEN_UNREACHABLE + } +} + +void Scope::scanChild(Node* X) { + switch (X->getKind()) { + case NodeKind::LetExprBody: + case NodeKind::ExpressionStatement: + case NodeKind::IfStatement: + case NodeKind::ReturnStatement: + break; + case NodeKind::LetBlockBody: + { + auto Block = static_cast(X); + for (auto Element: Block->Elements) { + scanChild(Element); } - case NodeKind::VariantDeclaration: - { - auto Decl = static_cast(X); - addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type); - for (auto Member: Decl->Members) { - switch (Member->getKind()) { - case NodeKind::TupleVariantDeclarationMember: - { - auto T = static_cast(Member); - addSymbol(getCanonicalText(T->Name), Decl, SymbolKind::Constructor); - break; - } - case NodeKind::RecordVariantDeclarationMember: - { - auto R = static_cast(Member); - addSymbol(getCanonicalText(R->Name), Decl, SymbolKind::Constructor); - break; - } - default: - ZEN_UNREACHABLE + break; + } + case NodeKind::InstanceDeclaration: + // We ignore let-declarations inside instance-declarations for now + break; + case NodeKind::ClassDeclaration: + { + auto Decl = static_cast(X); + addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Class); + for (auto Element: Decl->Elements) { + scanChild(Element); + } + break; + } + case NodeKind::LetDeclaration: + { + auto Decl = static_cast(X); + // No matter if it is a function or a variable, by visiting the pattern + // we add all relevant bindings to the current scope. + visitPattern(Decl->Pattern, Decl); + break; + } + case NodeKind::RecordDeclaration: + { + auto Decl = static_cast(X); + addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type); + break; + } + case NodeKind::VariantDeclaration: + { + auto Decl = static_cast(X); + addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type); + for (auto Member: Decl->Members) { + switch (Member->getKind()) { + case NodeKind::TupleVariantDeclarationMember: + { + auto T = static_cast(Member); + addSymbol(getCanonicalText(T->Name), Decl, SymbolKind::Constructor); + break; } - } - break; - } - default: - ZEN_UNREACHABLE - } - } - - void Scope::visitPattern(Pattern* X, Node* Decl) { - switch (X->getKind()) { - case NodeKind::BindPattern: - { - auto Y = static_cast(X); - addSymbol(getCanonicalText(Y->Name), Decl, SymbolKind::Var); - break; - } - case NodeKind::RecordPattern: - { - auto Y = static_cast(X); - for (auto [Field, Comma]: Y->Fields) { - if (Field->Pattern) { - visitPattern(Field->Pattern, Decl); - } else if (Field->Name) { - addSymbol(Field->Name->Text, Decl, SymbolKind::Var); + case NodeKind::RecordVariantDeclarationMember: + { + auto R = static_cast(Member); + addSymbol(getCanonicalText(R->Name), Decl, SymbolKind::Constructor); + break; } + default: + ZEN_UNREACHABLE } - break; } - case NodeKind::NamedRecordPattern: - { - auto Y = static_cast(X); - for (auto [Field, Comma]: Y->Fields) { - if (Field->Pattern) { - visitPattern(Field->Pattern, Decl); - } else if (Field->Name) { - addSymbol(Field->Name->Text, Decl, SymbolKind::Var); - } - } - break; - } - case NodeKind::NamedTuplePattern: - { - auto Y = static_cast(X); - for (auto P: Y->Patterns) { - visitPattern(P, Decl); - } - break; - } - case NodeKind::NestedPattern: - { - auto Y = static_cast(X); - visitPattern(Y->P, Decl); - break; - } - case NodeKind::TuplePattern: - { - auto Y = static_cast(X); - for (auto [Element, Comma]: Y->Elements) { - visitPattern(Element, Decl); - } - break; - } - case NodeKind::ListPattern: - { - auto Y = static_cast(X); - for (auto [Element, Separator]: Y->Elements) { - visitPattern(Element, Decl); - } - break; - } - case NodeKind::LiteralPattern: - break; - default: - ZEN_UNREACHABLE + break; } + default: + ZEN_UNREACHABLE } +} - Node* Scope::lookupDirect(SymbolPath Path, SymbolKind Kind) { - ZEN_ASSERT(Path.Modules.empty()); - auto Match = Mapping.find(Path.Name); - if (Match != Mapping.end() && std::get<1>(Match->second) == Kind) { - return std::get<0>(Match->second); +void Scope::visitPattern(Pattern* X, Node* Decl) { + switch (X->getKind()) { + case NodeKind::BindPattern: + { + auto Y = static_cast(X); + addSymbol(getCanonicalText(Y->Name), Decl, SymbolKind::Var); + break; } + case NodeKind::RecordPattern: + { + auto Y = static_cast(X); + for (auto [Field, Comma]: Y->Fields) { + if (Field->Pattern) { + visitPattern(Field->Pattern, Decl); + } else if (Field->Name) { + addSymbol(Field->Name->Text, Decl, SymbolKind::Var); + } + } + break; + } + case NodeKind::NamedRecordPattern: + { + auto Y = static_cast(X); + for (auto [Field, Comma]: Y->Fields) { + if (Field->Pattern) { + visitPattern(Field->Pattern, Decl); + } else if (Field->Name) { + addSymbol(Field->Name->Text, Decl, SymbolKind::Var); + } + } + break; + } + case NodeKind::NamedTuplePattern: + { + auto Y = static_cast(X); + for (auto P: Y->Patterns) { + visitPattern(P, Decl); + } + break; + } + case NodeKind::NestedPattern: + { + auto Y = static_cast(X); + visitPattern(Y->P, Decl); + break; + } + case NodeKind::TuplePattern: + { + auto Y = static_cast(X); + for (auto [Element, Comma]: Y->Elements) { + visitPattern(Element, Decl); + } + break; + } + case NodeKind::ListPattern: + { + auto Y = static_cast(X); + for (auto [Element, Separator]: Y->Elements) { + visitPattern(Element, Decl); + } + break; + } + case NodeKind::LiteralPattern: + break; + default: + ZEN_UNREACHABLE + } +} + +Node* Scope::lookupDirect(SymbolPath Path, SymbolKind Kind) { + ZEN_ASSERT(Path.Modules.empty()); + auto Match = Mapping.find(Path.Name); + if (Match != Mapping.end() && std::get<1>(Match->second) == Kind) { + return std::get<0>(Match->second); + } + return nullptr; +} + +Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) { + ZEN_ASSERT(Path.Modules.empty()); + auto Curr = this; + do { + auto Found = Curr->lookupDirect(Path, Kind); + if (Found) { + return Found; + } + Curr = Curr->getParentScope(); + } while (Curr != nullptr); + return nullptr; +} + +Scope* Scope::getParentScope() { + if (Source->Parent == nullptr) { return nullptr; } + return Source->Parent->getScope(); +} - Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) { - ZEN_ASSERT(Path.Modules.empty()); - auto Curr = this; - do { - auto Found = Curr->lookupDirect(Path, Kind); - if (Found) { - return Found; - } - Curr = Curr->getParentScope(); - } while (Curr != nullptr); - return nullptr; - } - - Scope* Scope::getParentScope() { - if (Source->Parent == nullptr) { - return nullptr; +const SourceFile* Node::getSourceFile() const { + const Node* CurrNode = this; + for (;;) { + if (CurrNode->Kind == NodeKind::SourceFile) { + return static_cast(CurrNode); } - return Source->Parent->getScope(); + CurrNode = CurrNode->Parent; + ZEN_ASSERT(CurrNode != nullptr); } - - const SourceFile* Node::getSourceFile() const { - const Node* CurrNode = this; - for (;;) { - if (CurrNode->Kind == NodeKind::SourceFile) { - return static_cast(CurrNode); - } - CurrNode = CurrNode->Parent; - ZEN_ASSERT(CurrNode != nullptr); +} +SourceFile* Node::getSourceFile() { + Node* CurrNode = this; + for (;;) { + if (CurrNode->Kind == NodeKind::SourceFile) { + return static_cast(CurrNode); } + CurrNode = CurrNode->Parent; + ZEN_ASSERT(CurrNode != nullptr); } - SourceFile* Node::getSourceFile() { - Node* CurrNode = this; - for (;;) { - if (CurrNode->Kind == NodeKind::SourceFile) { - return static_cast(CurrNode); - } - CurrNode = CurrNode->Parent; - ZEN_ASSERT(CurrNode != nullptr); +} + +std::size_t Node::getStartLine() const { + return getFirstToken()->getStartLine(); +} + +std::size_t Node::getStartColumn() const { + return getFirstToken()->getStartColumn(); +} + +std::size_t Node::getEndLine() const { + return getLastToken()->getEndLine(); +} + +std::size_t Node::getEndColumn() const { + return getLastToken()->getEndColumn(); +} + +TextRange Node::getRange() const { + return TextRange { + getFirstToken()->getStartLoc(), + getLastToken()->getEndLoc(), + }; +} + +Scope* Node::getScope() { + return Parent->getScope(); +} + +TextLoc Token::getEndLoc() const { + auto Loc = StartLoc; + Loc.advance(getText()); + return Loc; +} + +void Node::setParents() { + + struct SetParentsVisitor : public CSTVisitor { + + std::vector Parents { nullptr }; + + void visit(Node* N) { + N->Parent = Parents.back(); + Parents.push_back(N); + visitEachChild(N); + Parents.pop_back(); } - } - std::size_t Node::getStartLine() const { - return getFirstToken()->getStartLine(); - } + }; - std::size_t Node::getStartColumn() const { - return getFirstToken()->getStartColumn(); - } + SetParentsVisitor V; + V.visit(this); - std::size_t Node::getEndLine() const { - return getLastToken()->getEndLine(); - } +} - std::size_t Node::getEndColumn() const { - return getLastToken()->getEndColumn(); - } +void Node::unref() { - TextRange Node::getRange() const { - return TextRange { - getFirstToken()->getStartLoc(), - getLastToken()->getEndLoc(), - }; - } + --RefCount; - Scope* Node::getScope() { - return Parent->getScope(); - } - - TextLoc Token::getEndLoc() const { - auto Loc = StartLoc; - Loc.advance(getText()); - return Loc; - } - - void Node::setParents() { - - struct SetParentsVisitor : public CSTVisitor { - - std::vector Parents { nullptr }; + if (RefCount == 0) { + // You may be wondering why we aren't unreffing the children in the + // destructor. This is due to a behaviour in Clang where a top-level + // destructor ~Node() wont get access to the fields in derived classes + // because they may already have been destroyed. + struct UnrefVisitor : public CSTVisitor { void visit(Node* N) { - N->Parent = Parents.back(); - Parents.push_back(N); - visitEachChild(N); - Parents.pop_back(); + N->unref(); } - }; + UnrefVisitor V; + V.visitEachChild(this); - SetParentsVisitor V; - V.visit(this); - - } - - void Node::unref() { - - --RefCount; - - if (RefCount == 0) { - - // You may be wondering why we aren't unreffing the children in the - // destructor. This is due to a behaviour in Clang where a top-level - // destructor ~Node() wont get access to the fields in derived classes - // because they may already have been destroyed. - struct UnrefVisitor : public CSTVisitor { - void visit(Node* N) { - N->unref(); - } - }; - UnrefVisitor V; - V.visitEachChild(this); - - delete this; - } - - } - - bool Identifier::isTypeVar() const { - for (auto C: Text) { - if (!((C >= 97 && C <= 122) || C == '_')) { - return false; - } - } - return true; - } - - Token* ExpressionAnnotation::getFirstToken() const { - return At; - } - - Token* ExpressionAnnotation::getLastToken() const { - return Expression->getLastToken(); - } - - Token* TypeAssertAnnotation::getFirstToken() const { - return At; - } - - Token* TypeAssertAnnotation::getLastToken() const { - return TE->getLastToken(); - } - - Token* TypeclassConstraintExpression::getFirstToken() const { - return Name; - } - - Token* TypeclassConstraintExpression::getLastToken() const { - if (!TEs.empty()) { - return TEs.back()->getLastToken(); - } - return Name; - } - - Token* EqualityConstraintExpression::getFirstToken() const { - return Left->getFirstToken(); - } - - Token* EqualityConstraintExpression::getLastToken() const { - return Left->getLastToken(); - } - - Token* RecordTypeExpressionField::getFirstToken() const { - return Name; - } - - Token* RecordTypeExpressionField::getLastToken() const { - return TE->getLastToken(); - } - - Token* RecordTypeExpression::getFirstToken() const { - return LBrace; - } - - Token* RecordTypeExpression::getLastToken() const { - return RBrace; - } - - Token* QualifiedTypeExpression::getFirstToken() const { - if (!Constraints.empty()) { - return std::get<0>(Constraints.front())->getFirstToken(); - } - return TE->getFirstToken(); - } - - Token* QualifiedTypeExpression::getLastToken() const { - return TE->getLastToken(); - } - - Token* ReferenceTypeExpression::getFirstToken() const { - if (!ModulePath.empty()) { - return std::get<0>(ModulePath.front()); - } - return Name; - } - - Token* ReferenceTypeExpression::getLastToken() const { - return Name; - } - - Token* ArrowTypeExpression::getFirstToken() const { - if (ParamTypes.size()) { - return ParamTypes.front()->getFirstToken(); - } - return ReturnType->getFirstToken(); - } - - Token* ArrowTypeExpression::getLastToken() const { - return ReturnType->getLastToken(); - } - - Token* AppTypeExpression::getFirstToken() const { - return Op->getFirstToken(); - } - - Token* AppTypeExpression::getLastToken() const { - if (Args.size()) { - return Args.back()->getLastToken(); - } - return Op->getLastToken(); - } - - Token* VarTypeExpression::getLastToken() const { - return Name; - } - - Token* VarTypeExpression::getFirstToken() const { - return Name; - } - - Token* NestedTypeExpression::getLastToken() const { - return LParen; - } - - Token* NestedTypeExpression::getFirstToken() const { - return RParen; - } - - Token* TupleTypeExpression::getLastToken() const { - return LParen; - } - - Token* TupleTypeExpression::getFirstToken() const { - return RParen; - } - - Token* WrappedOperator::getFirstToken() const { - return LParen; - } - - Token* WrappedOperator::getLastToken() const { - return RParen; - } - - Token* BindPattern::getFirstToken() const { - switch (Name->getKind()) { - case NodeKind::Identifier: - return static_cast(Name); - case NodeKind::IdentifierAlt: - return static_cast(Name); - case NodeKind::WrappedOperator: - return static_cast(Name)->LParen; - default: - ZEN_UNREACHABLE - } - } - - Token* BindPattern::getLastToken() const { - switch (Name->getKind()) { - case NodeKind::Identifier: - return static_cast(Name); - case NodeKind::IdentifierAlt: - return static_cast(Name); - case NodeKind::WrappedOperator: - return static_cast(Name)->RParen; - default: - ZEN_UNREACHABLE - } - } - - Token* LiteralPattern::getFirstToken() const { - return Literal; - } - - Token* LiteralPattern::getLastToken() const { - return Literal; - } - - Token* RecordPatternField::getFirstToken() const { - return Name; - } - - Token* RecordPatternField::getLastToken() const { - if (Pattern) { - return Pattern->getLastToken(); - } - if (Equals) { - return Equals; - } - return Name; - } - - Token* RecordPattern::getFirstToken() const { - return LBrace; - } - - Token* RecordPattern::getLastToken() const { - return RBrace; - } - - Token* NamedRecordPattern::getFirstToken() const { - if (!ModulePath.empty()) { - return std::get<0>(ModulePath.back()); - } - return Name; - } - - Token* NamedRecordPattern::getLastToken() const { - return RBrace; - } - - Token* NamedTuplePattern::getFirstToken() const { - return Name; - } - - Token* NamedTuplePattern::getLastToken() const { - if (Patterns.size()) { - return Patterns.back()->getLastToken(); - } - return Name; - } - - Token* TuplePattern::getFirstToken() const { - return LParen; - } - - Token* TuplePattern::getLastToken() const { - return RParen; - } - - Token* NestedPattern::getFirstToken() const { - return LParen; - } - - Token* NestedPattern::getLastToken() const { - return RParen; - } - - Token* ListPattern::getFirstToken() const { - return LBracket; - } - - Token* ListPattern::getLastToken() const { - return RBracket; - } - - Token* ReferenceExpression::getFirstToken() const { - if (!ModulePath.empty()) { - return std::get<0>(ModulePath.front()); - } - switch (Name->getKind()) { - case NodeKind::Identifier: - return static_cast(Name); - case NodeKind::IdentifierAlt: - return static_cast(Name); - case NodeKind::WrappedOperator: - return static_cast(Name)->LParen; - default: - ZEN_UNREACHABLE - } - } - - Token* ReferenceExpression::getLastToken() const { - switch (Name->getKind()) { - case NodeKind::Identifier: - return static_cast(Name); - case NodeKind::IdentifierAlt: - return static_cast(Name); - case NodeKind::WrappedOperator: - return static_cast(Name)->RParen; - default: - ZEN_UNREACHABLE - } - } - - Token* MatchCase::getFirstToken() const { - return Pattern->getFirstToken(); - } - - Token* MatchCase::getLastToken() const { - return Expression->getLastToken(); - } - - Token* MatchExpression::getFirstToken() const { - return MatchKeyword; - } - - Token* MatchExpression::getLastToken() const { - if (!Cases.empty()) { - return Cases.back()->getLastToken(); - } - return BlockStart; - } - - Token* RecordExpressionField::getFirstToken() const { - return Name; - } - - Token* RecordExpressionField::getLastToken() const { - return E->getLastToken(); - } - - Token* RecordExpression::getFirstToken() const { - return LBrace; - } - - Token* RecordExpression::getLastToken() const { - return RBrace; - } - - Token* MemberExpression::getFirstToken() const { - return E->getFirstToken(); - } - - Token* MemberExpression::getLastToken() const { - return Name; - } - - Token* TupleExpression::getFirstToken() const { - return LParen; - } - - Token* TupleExpression::getLastToken() const { - return RParen; - } - - Token* NestedExpression::getFirstToken() const { - return LParen; - } - - Token* NestedExpression::getLastToken() const { - return RParen; - } - - Token* LiteralExpression::getFirstToken() const { - return Token; - } - - Token* LiteralExpression::getLastToken() const { - return Token; - } - - Token* CallExpression::getFirstToken() const { - return Function->getFirstToken(); - } - - Token* CallExpression::getLastToken() const { - if (Args.size()) { - return Args.back()->getLastToken(); - } - return Function->getLastToken(); - } - - Token* InfixExpression::getFirstToken() const { - return Left->getFirstToken(); - } - - Token* InfixExpression::getLastToken() const { - return Right->getLastToken(); - } - - Token* PrefixExpression::getFirstToken() const { - return Operator; - } - - Token* PrefixExpression::getLastToken() const { - return Argument->getLastToken(); - } - - Token* ExpressionStatement::getFirstToken() const { - return Expression->getFirstToken(); - } - - Token* ExpressionStatement::getLastToken() const { - return Expression->getLastToken(); - } - - Token* ReturnStatement::getFirstToken() const { - return ReturnKeyword; - } - - Token* ReturnStatement::getLastToken() const { - if (Expression) { - return Expression->getLastToken(); - } - return ReturnKeyword; - } - - Token* IfStatementPart::getFirstToken() const { - return Keyword; - } - - Token* IfStatementPart::getLastToken() const { - if (Elements.size()) { - return Elements.back()->getLastToken(); - } - return BlockStart; - } - - Token* IfStatement::getFirstToken() const { - ZEN_ASSERT(Parts.size()); - return Parts.front()->getFirstToken(); - } - - Token* IfStatement::getLastToken() const { - ZEN_ASSERT(Parts.size()); - return Parts.back()->getLastToken(); - } - - Token* TypeAssert::getFirstToken() const { - return Colon; - } - - Token* TypeAssert::getLastToken() const { - return TypeExpression->getLastToken(); - } - - Token* Parameter::getFirstToken() const { - return Pattern->getFirstToken(); - } - - Token* Parameter::getLastToken() const { - if (TypeAssert) { - return TypeAssert->getLastToken(); - } - return Pattern->getLastToken(); - } - - Token* LetBlockBody::getFirstToken() const { - return BlockStart; - } - - Token* LetBlockBody::getLastToken() const { - if (Elements.size()) { - return Elements.back()->getLastToken(); - } - return BlockStart; - } - - Token* LetExprBody::getFirstToken() const { - return Equals; - } - - Token* LetExprBody::getLastToken() const { - return Expression->getLastToken(); - } - - Token* LetDeclaration::getFirstToken() const { - if (PubKeyword) { - return PubKeyword; - } - if (ForeignKeyword) { - return ForeignKeyword; - } - return LetKeyword; - } - - Token* LetDeclaration::getLastToken() const { - if (Body) { - return Body->getLastToken(); - } - if (TypeAssert) { - return TypeAssert->getLastToken(); - } - if (Params.size()) { - return Params.back()->getLastToken(); - } - return Pattern->getLastToken(); - } - - Token* RecordDeclarationField::getFirstToken() const { - return Name; - } - - Token* RecordDeclarationField::getLastToken() const { - return TypeExpression->getLastToken(); - } - - Token* RecordDeclaration::getFirstToken() const { - if (PubKeyword) { - return PubKeyword; - } - return StructKeyword; - } - - Token* RecordDeclaration::getLastToken() const { - if (Fields.size()) { - return Fields.back()->getLastToken(); - } - return BlockStart; - } - - Token* VariantDeclaration::getFirstToken() const { - if (PubKeyword) { - return PubKeyword; - } - return EnumKeyword; - } - - Token* VariantDeclaration::getLastToken() const { - if (Members.size()) { - return Members.back()->getLastToken(); - } - return BlockStart; - } - - Token* TupleVariantDeclarationMember::getFirstToken() const { - return Name; - } - - Token* TupleVariantDeclarationMember::getLastToken() const { - if (Elements.size()) { - return Elements.back()->getLastToken(); - } - return Name; - } - - Token* RecordVariantDeclarationMember::getFirstToken() const { - return Name; - } - - Token* RecordVariantDeclarationMember::getLastToken() const { - if (Fields.size()) { - return Fields.back()->getLastToken(); - } - return BlockStart; - } - - Token* InstanceDeclaration::getFirstToken() const { - return InstanceKeyword; - } - - Token* InstanceDeclaration::getLastToken() const { - if (!Elements.empty()) { - return Elements.back()->getLastToken(); - } - return BlockStart; - } - - Token* ClassDeclaration::getFirstToken() const { - if (PubKeyword != nullptr) { - return PubKeyword; - } - return ClassKeyword; - } - - Token* ClassDeclaration::getLastToken() const { - if (!Elements.empty()) { - return Elements.back()->getLastToken(); - } - return BlockStart; - } - - Token* SourceFile::getFirstToken() const { - if (Elements.size()) { - return Elements.front()->getFirstToken(); - } - return nullptr; - } - - Token* SourceFile::getLastToken() const { - if (Elements.size()) { - return Elements.back()->getLastToken(); - } - return nullptr; - } - - std::string VBar::getText() const { - return "|"; - } - - std::string Equals::getText() const { - return "="; - } - - std::string Colon::getText() const { - return ":"; - } - - std::string Comma::getText() const { - return ","; - } - - std::string RArrow::getText() const { - return "->"; - } - - std::string RArrowAlt::getText() const { - return "=>"; - } - - std::string Dot::getText() const { - return "."; - } - - std::string LParen::getText() const { - return "("; - } - - std::string RParen::getText() const { - return ")"; - } - - std::string LBracket::getText() const { - return "["; - } - - std::string RBracket::getText() const { - return "]"; - } - - std::string LBrace::getText() const { - return "{"; - } - - std::string RBrace::getText() const { - return "}"; - } - - std::string LetKeyword::getText() const { - return "let"; - } - - std::string ForeignKeyword::getText() const { - return "foreign"; - } - - std::string MutKeyword::getText() const { - return "mut"; - } - - std::string PubKeyword::getText() const { - return "pub"; - } - - std::string TypeKeyword::getText() const { - return "type"; - } - - std::string ReturnKeyword::getText() const { - return "return"; - } - - std::string IfKeyword::getText() const { - return "if"; - } - - std::string ElseKeyword::getText() const { - return "else"; - } - - std::string ElifKeyword::getText() const { - return "elif"; - } - - std::string MatchKeyword::getText() const { - return "match"; - } - - std::string ModKeyword::getText() const { - return "mod"; - } - - std::string StructKeyword::getText() const { - return "struct"; - } - - std::string EnumKeyword::getText() const { - return "enum"; - } - - std::string Invalid::getText() const { - return ""; - } - - std::string EndOfFile::getText() const { - return ""; - } - - std::string BlockStart::getText() const { - return "."; - } - - std::string BlockEnd::getText() const { - return ""; - } - - std::string LineFoldEnd::getText() const { - return ""; - } - - std::string CustomOperator::getText() const { - return Text; - } - - std::string Assignment::getText() const { - return Text + "="; - } - - std::string Identifier::getText() const { - return Text; - } - - std::string IdentifierAlt::getText() const { - return Text; - } - - std::string StringLiteral::getText() const { - return "\"" + Text + "\""; - } - - std::string IntegerLiteral::getText() const { - return std::to_string(V); - } - - std::string DotDot::getText() const { - return ".."; - } - - std::string Tilde::getText() const { - return "~"; - } - - std::string At::getText() const { - return "@"; - } - - std::string ClassKeyword::getText() const { - return "class"; - } - - std::string InstanceKeyword::getText() const { - return "instance"; - } - - ByteString getCanonicalText(const Symbol* N) { - switch (N->getKind()) { - case NodeKind::Identifier: - return static_cast(N)->Text; - case NodeKind::IdentifierAlt: - return static_cast(N)->Text; - case NodeKind::CustomOperator: - return static_cast(N)->Text; - case NodeKind::VBar: - return static_cast(N)->getText(); - case NodeKind::WrappedOperator: - return static_cast(N)->getOperator()->getText(); - default: - ZEN_UNREACHABLE - } - } - - LiteralValue StringLiteral::getValue() { - return Text; - } - - LiteralValue IntegerLiteral::getValue() { - return V; - } - - SymbolPath ReferenceExpression::getSymbolPath() const { - std::vector ModuleNames; - for (auto [Name, Dot]: ModulePath) { - ModuleNames.push_back(getCanonicalText(Name)); - } - return SymbolPath { ModuleNames, getCanonicalText(Name) }; + delete this; } } +bool Identifier::isTypeVar() const { + for (auto C: Text) { + if (!((C >= 97 && C <= 122) || C == '_')) { + return false; + } + } + return true; +} + +Token* ExpressionAnnotation::getFirstToken() const { + return At; +} + +Token* ExpressionAnnotation::getLastToken() const { + return Expression->getLastToken(); +} + +Token* TypeAssertAnnotation::getFirstToken() const { + return At; +} + +Token* TypeAssertAnnotation::getLastToken() const { + return TE->getLastToken(); +} + +Token* TypeclassConstraintExpression::getFirstToken() const { + return Name; +} + +Token* TypeclassConstraintExpression::getLastToken() const { + if (!TEs.empty()) { + return TEs.back()->getLastToken(); + } + return Name; +} + +Token* EqualityConstraintExpression::getFirstToken() const { + return Left->getFirstToken(); +} + +Token* EqualityConstraintExpression::getLastToken() const { + return Left->getLastToken(); +} + +Token* RecordTypeExpressionField::getFirstToken() const { + return Name; +} + +Token* RecordTypeExpressionField::getLastToken() const { + return TE->getLastToken(); +} + +Token* RecordTypeExpression::getFirstToken() const { + return LBrace; +} + +Token* RecordTypeExpression::getLastToken() const { + return RBrace; +} + +Token* QualifiedTypeExpression::getFirstToken() const { + if (!Constraints.empty()) { + return std::get<0>(Constraints.front())->getFirstToken(); + } + return TE->getFirstToken(); +} + +Token* QualifiedTypeExpression::getLastToken() const { + return TE->getLastToken(); +} + +Token* ReferenceTypeExpression::getFirstToken() const { + if (!ModulePath.empty()) { + return std::get<0>(ModulePath.front()); + } + return Name; +} + +Token* ReferenceTypeExpression::getLastToken() const { + return Name; +} + +Token* ArrowTypeExpression::getFirstToken() const { + if (ParamTypes.size()) { + return ParamTypes.front()->getFirstToken(); + } + return ReturnType->getFirstToken(); +} + +Token* ArrowTypeExpression::getLastToken() const { + return ReturnType->getLastToken(); +} + +Token* AppTypeExpression::getFirstToken() const { + return Op->getFirstToken(); +} + +Token* AppTypeExpression::getLastToken() const { + if (Args.size()) { + return Args.back()->getLastToken(); + } + return Op->getLastToken(); +} + +Token* VarTypeExpression::getLastToken() const { + return Name; +} + +Token* VarTypeExpression::getFirstToken() const { + return Name; +} + +Token* NestedTypeExpression::getLastToken() const { + return LParen; +} + +Token* NestedTypeExpression::getFirstToken() const { + return RParen; +} + +Token* TupleTypeExpression::getLastToken() const { + return LParen; +} + +Token* TupleTypeExpression::getFirstToken() const { + return RParen; +} + +Token* WrappedOperator::getFirstToken() const { + return LParen; +} + +Token* WrappedOperator::getLastToken() const { + return RParen; +} + +Token* BindPattern::getFirstToken() const { + switch (Name->getKind()) { + case NodeKind::Identifier: + return static_cast(Name); + case NodeKind::IdentifierAlt: + return static_cast(Name); + case NodeKind::WrappedOperator: + return static_cast(Name)->LParen; + default: + ZEN_UNREACHABLE + } +} + +Token* BindPattern::getLastToken() const { + switch (Name->getKind()) { + case NodeKind::Identifier: + return static_cast(Name); + case NodeKind::IdentifierAlt: + return static_cast(Name); + case NodeKind::WrappedOperator: + return static_cast(Name)->RParen; + default: + ZEN_UNREACHABLE + } +} + +Token* LiteralPattern::getFirstToken() const { + return Literal; +} + +Token* LiteralPattern::getLastToken() const { + return Literal; +} + +Token* RecordPatternField::getFirstToken() const { + return Name; +} + +Token* RecordPatternField::getLastToken() const { + if (Pattern) { + return Pattern->getLastToken(); + } + if (Equals) { + return Equals; + } + return Name; +} + +Token* RecordPattern::getFirstToken() const { + return LBrace; +} + +Token* RecordPattern::getLastToken() const { + return RBrace; +} + +Token* NamedRecordPattern::getFirstToken() const { + if (!ModulePath.empty()) { + return std::get<0>(ModulePath.back()); + } + return Name; +} + +Token* NamedRecordPattern::getLastToken() const { + return RBrace; +} + +Token* NamedTuplePattern::getFirstToken() const { + return Name; +} + +Token* NamedTuplePattern::getLastToken() const { + if (Patterns.size()) { + return Patterns.back()->getLastToken(); + } + return Name; +} + +Token* TuplePattern::getFirstToken() const { + return LParen; +} + +Token* TuplePattern::getLastToken() const { + return RParen; +} + +Token* NestedPattern::getFirstToken() const { + return LParen; +} + +Token* NestedPattern::getLastToken() const { + return RParen; +} + +Token* ListPattern::getFirstToken() const { + return LBracket; +} + +Token* ListPattern::getLastToken() const { + return RBracket; +} + +Token* ReferenceExpression::getFirstToken() const { + if (!ModulePath.empty()) { + return std::get<0>(ModulePath.front()); + } + switch (Name->getKind()) { + case NodeKind::Identifier: + return static_cast(Name); + case NodeKind::IdentifierAlt: + return static_cast(Name); + case NodeKind::WrappedOperator: + return static_cast(Name)->LParen; + default: + ZEN_UNREACHABLE + } +} + +Token* ReferenceExpression::getLastToken() const { + switch (Name->getKind()) { + case NodeKind::Identifier: + return static_cast(Name); + case NodeKind::IdentifierAlt: + return static_cast(Name); + case NodeKind::WrappedOperator: + return static_cast(Name)->RParen; + default: + ZEN_UNREACHABLE + } +} + +Token* MatchCase::getFirstToken() const { + return Pattern->getFirstToken(); +} + +Token* MatchCase::getLastToken() const { + return Expression->getLastToken(); +} + +Token* MatchExpression::getFirstToken() const { + return MatchKeyword; +} + +Token* MatchExpression::getLastToken() const { + if (!Cases.empty()) { + return Cases.back()->getLastToken(); + } + return BlockStart; +} + +Token* RecordExpressionField::getFirstToken() const { + return Name; +} + +Token* RecordExpressionField::getLastToken() const { + return E->getLastToken(); +} + +Token* RecordExpression::getFirstToken() const { + return LBrace; +} + +Token* RecordExpression::getLastToken() const { + return RBrace; +} + +Token* MemberExpression::getFirstToken() const { + return E->getFirstToken(); +} + +Token* MemberExpression::getLastToken() const { + return Name; +} + +Token* TupleExpression::getFirstToken() const { + return LParen; +} + +Token* TupleExpression::getLastToken() const { + return RParen; +} + +Token* NestedExpression::getFirstToken() const { + return LParen; +} + +Token* NestedExpression::getLastToken() const { + return RParen; +} + +Token* LiteralExpression::getFirstToken() const { + return Token; +} + +Token* LiteralExpression::getLastToken() const { + return Token; +} + +Token* CallExpression::getFirstToken() const { + return Function->getFirstToken(); +} + +Token* CallExpression::getLastToken() const { + if (Args.size()) { + return Args.back()->getLastToken(); + } + return Function->getLastToken(); +} + +Token* InfixExpression::getFirstToken() const { + return Left->getFirstToken(); +} + +Token* InfixExpression::getLastToken() const { + return Right->getLastToken(); +} + +Token* PrefixExpression::getFirstToken() const { + return Operator; +} + +Token* PrefixExpression::getLastToken() const { + return Argument->getLastToken(); +} + +Token* ExpressionStatement::getFirstToken() const { + return Expression->getFirstToken(); +} + +Token* ExpressionStatement::getLastToken() const { + return Expression->getLastToken(); +} + +Token* ReturnStatement::getFirstToken() const { + return ReturnKeyword; +} + +Token* ReturnStatement::getLastToken() const { + if (Expression) { + return Expression->getLastToken(); + } + return ReturnKeyword; +} + +Token* IfStatementPart::getFirstToken() const { + return Keyword; +} + +Token* IfStatementPart::getLastToken() const { + if (Elements.size()) { + return Elements.back()->getLastToken(); + } + return BlockStart; +} + +Token* IfStatement::getFirstToken() const { + ZEN_ASSERT(Parts.size()); + return Parts.front()->getFirstToken(); +} + +Token* IfStatement::getLastToken() const { + ZEN_ASSERT(Parts.size()); + return Parts.back()->getLastToken(); +} + +Token* TypeAssert::getFirstToken() const { + return Colon; +} + +Token* TypeAssert::getLastToken() const { + return TypeExpression->getLastToken(); +} + +Token* Parameter::getFirstToken() const { + return Pattern->getFirstToken(); +} + +Token* Parameter::getLastToken() const { + if (TypeAssert) { + return TypeAssert->getLastToken(); + } + return Pattern->getLastToken(); +} + +Token* LetBlockBody::getFirstToken() const { + return BlockStart; +} + +Token* LetBlockBody::getLastToken() const { + if (Elements.size()) { + return Elements.back()->getLastToken(); + } + return BlockStart; +} + +Token* LetExprBody::getFirstToken() const { + return Equals; +} + +Token* LetExprBody::getLastToken() const { + return Expression->getLastToken(); +} + +Token* LetDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + if (ForeignKeyword) { + return ForeignKeyword; + } + return LetKeyword; +} + +Token* LetDeclaration::getLastToken() const { + if (Body) { + return Body->getLastToken(); + } + if (TypeAssert) { + return TypeAssert->getLastToken(); + } + if (Params.size()) { + return Params.back()->getLastToken(); + } + return Pattern->getLastToken(); +} + +Token* RecordDeclarationField::getFirstToken() const { + return Name; +} + +Token* RecordDeclarationField::getLastToken() const { + return TypeExpression->getLastToken(); +} + +Token* RecordDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + return StructKeyword; +} + +Token* RecordDeclaration::getLastToken() const { + if (Fields.size()) { + return Fields.back()->getLastToken(); + } + return BlockStart; +} + +Token* VariantDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + return EnumKeyword; +} + +Token* VariantDeclaration::getLastToken() const { + if (Members.size()) { + return Members.back()->getLastToken(); + } + return BlockStart; +} + +Token* TupleVariantDeclarationMember::getFirstToken() const { + return Name; +} + +Token* TupleVariantDeclarationMember::getLastToken() const { + if (Elements.size()) { + return Elements.back()->getLastToken(); + } + return Name; +} + +Token* RecordVariantDeclarationMember::getFirstToken() const { + return Name; +} + +Token* RecordVariantDeclarationMember::getLastToken() const { + if (Fields.size()) { + return Fields.back()->getLastToken(); + } + return BlockStart; +} + +Token* InstanceDeclaration::getFirstToken() const { + return InstanceKeyword; +} + +Token* InstanceDeclaration::getLastToken() const { + if (!Elements.empty()) { + return Elements.back()->getLastToken(); + } + return BlockStart; +} + +Token* ClassDeclaration::getFirstToken() const { + if (PubKeyword != nullptr) { + return PubKeyword; + } + return ClassKeyword; +} + +Token* ClassDeclaration::getLastToken() const { + if (!Elements.empty()) { + return Elements.back()->getLastToken(); + } + return BlockStart; +} + +Token* SourceFile::getFirstToken() const { + if (Elements.size()) { + return Elements.front()->getFirstToken(); + } + return nullptr; +} + +Token* SourceFile::getLastToken() const { + if (Elements.size()) { + return Elements.back()->getLastToken(); + } + return nullptr; +} + +std::string VBar::getText() const { + return "|"; +} + +std::string Equals::getText() const { + return "="; +} + +std::string Colon::getText() const { + return ":"; +} + +std::string Comma::getText() const { + return ","; +} + +std::string RArrow::getText() const { + return "->"; +} + +std::string RArrowAlt::getText() const { + return "=>"; +} + +std::string Dot::getText() const { + return "."; +} + +std::string LParen::getText() const { + return "("; +} + +std::string RParen::getText() const { + return ")"; +} + +std::string LBracket::getText() const { + return "["; +} + +std::string RBracket::getText() const { + return "]"; +} + +std::string LBrace::getText() const { + return "{"; +} + +std::string RBrace::getText() const { + return "}"; +} + +std::string LetKeyword::getText() const { + return "let"; +} + +std::string ForeignKeyword::getText() const { + return "foreign"; +} + +std::string MutKeyword::getText() const { + return "mut"; +} + +std::string PubKeyword::getText() const { + return "pub"; +} + +std::string TypeKeyword::getText() const { + return "type"; +} + +std::string ReturnKeyword::getText() const { + return "return"; +} + +std::string IfKeyword::getText() const { + return "if"; +} + +std::string ElseKeyword::getText() const { + return "else"; +} + +std::string ElifKeyword::getText() const { + return "elif"; +} + +std::string MatchKeyword::getText() const { + return "match"; +} + +std::string ModKeyword::getText() const { + return "mod"; +} + +std::string StructKeyword::getText() const { + return "struct"; +} + +std::string EnumKeyword::getText() const { + return "enum"; +} + +std::string Invalid::getText() const { + return ""; +} + +std::string EndOfFile::getText() const { + return ""; +} + +std::string BlockStart::getText() const { + return "."; +} + +std::string BlockEnd::getText() const { + return ""; +} + +std::string LineFoldEnd::getText() const { + return ""; +} + +std::string CustomOperator::getText() const { + return Text; +} + +std::string Assignment::getText() const { + return Text + "="; +} + +std::string Identifier::getText() const { + return Text; +} + +std::string IdentifierAlt::getText() const { + return Text; +} + +std::string StringLiteral::getText() const { + return "\"" + Text + "\""; +} + +std::string IntegerLiteral::getText() const { + return std::to_string(V); +} + +std::string DotDot::getText() const { + return ".."; +} + +std::string Tilde::getText() const { + return "~"; +} + +std::string At::getText() const { + return "@"; +} + +std::string ClassKeyword::getText() const { + return "class"; +} + +std::string InstanceKeyword::getText() const { + return "instance"; +} + +ByteString getCanonicalText(const Symbol* N) { + switch (N->getKind()) { + case NodeKind::Identifier: + return static_cast(N)->Text; + case NodeKind::IdentifierAlt: + return static_cast(N)->Text; + case NodeKind::CustomOperator: + return static_cast(N)->Text; + case NodeKind::VBar: + return static_cast(N)->getText(); + case NodeKind::WrappedOperator: + return static_cast(N)->getOperator()->getText(); + default: + ZEN_UNREACHABLE + } +} + +LiteralValue StringLiteral::getValue() { + return Text; +} + +LiteralValue IntegerLiteral::getValue() { + return V; +} + +SymbolPath ReferenceExpression::getSymbolPath() const { + std::vector ModuleNames; + for (auto [Name, Dot]: ModulePath) { + ModuleNames.push_back(getCanonicalText(Name)); + } + return SymbolPath { ModuleNames, getCanonicalText(Name) }; +} + +} + diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index a0ef8e2c2..f28fc9b46 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -14,1937 +14,1937 @@ namespace bolt { - Constraint* Constraint::substitute(const TVSub &Sub) { - switch (Kind) { - case ConstraintKind::Equal: - { - auto Equal = static_cast(this); - return new CEqual(Equal->Left->substitute(Sub), Equal->Right->substitute(Sub), Equal->Source); - } - case ConstraintKind::Many: - { - auto Many = static_cast(this); - auto NewConstraints = new ConstraintSet(); - for (auto Element: Many->Elements) { - NewConstraints->push_back(Element->substitute(Sub)); - } - return new CMany(*NewConstraints); - } - case ConstraintKind::Field: - { - auto Field = static_cast(this); - auto NewTupleTy = Field->TupleTy->substitute(Sub); - auto NewFieldTy = Field->FieldTy->substitute(Sub); - return new CField(NewTupleTy, Field->I, NewFieldTy, Field->Source); - } - case ConstraintKind::Empty: - return this; +Constraint* Constraint::substitute(const TVSub &Sub) { + switch (Kind) { + case ConstraintKind::Equal: + { + auto Equal = static_cast(this); + return new CEqual(Equal->Left->substitute(Sub), Equal->Right->substitute(Sub), Equal->Source); } - ZEN_UNREACHABLE + case ConstraintKind::Many: + { + auto Many = static_cast(this); + auto NewConstraints = new ConstraintSet(); + for (auto Element: Many->Elements) { + NewConstraints->push_back(Element->substitute(Sub)); + } + return new CMany(*NewConstraints); + } + case ConstraintKind::Field: + { + auto Field = static_cast(this); + auto NewTupleTy = Field->TupleTy->substitute(Sub); + auto NewFieldTy = Field->FieldTy->substitute(Sub); + return new CField(NewTupleTy, Field->I, NewFieldTy, Field->Source); + } + case ConstraintKind::Empty: + return this; + } + ZEN_UNREACHABLE +} + +Type* Checker::solveType(Type* Ty) { + return Ty->rewrite([this](auto Ty) { return Ty->find(); }, true); +} + +Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): + Config(Config), DE(DE) { + BoolType = createConType("Bool"); + IntType = createConType("Int"); + StringType = createConType("String"); + ListType = createConType("List"); + UnitType = new Type(TTuple({})); } - Type* Checker::solveType(Type* Ty) { - return Ty->rewrite([this](auto Ty) { return Ty->find(); }, true); +Scheme* Checker::lookup(ByteString Name, SymKind Kind) { + auto Curr = &getContext(); + for (;;) { + auto Match = Curr->Env.lookup(Name, Kind); + if (Match != nullptr) { + return Match; + } + Curr = Curr->Parent; + if (!Curr) { + break; + } } + return nullptr; +} - Checker::Checker(const LanguageConfig& Config, DiagnosticEngine& DE): - Config(Config), DE(DE) { - BoolType = createConType("Bool"); - IntType = createConType("Int"); - StringType = createConType("String"); - ListType = createConType("List"); - UnitType = new Type(TTuple({})); - } - - Scheme* Checker::lookup(ByteString Name, SymKind Kind) { - auto Curr = &getContext(); - for (;;) { - auto Match = Curr->Env.lookup(Name, Kind); - if (Match != nullptr) { - return Match; - } - Curr = Curr->Parent; - if (!Curr) { - break; - } - } +Type* Checker::lookupMono(ByteString Name, SymKind Kind) { + auto Scm = lookup(Name, Kind); + if (Scm == nullptr) { return nullptr; } + auto F = static_cast(Scm); + ZEN_ASSERT(F->TVs == nullptr || F->TVs->empty()); + return F->Type; +} - Type* Checker::lookupMono(ByteString Name, SymKind Kind) { - auto Scm = lookup(Name, Kind); - if (Scm == nullptr) { - return nullptr; +void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) { + getContext().Env.add(Name, Scm, Kind); +} + +Type* Checker::getReturnType() { + auto Ty = getContext().ReturnType; + ZEN_ASSERT(Ty != nullptr); + return Ty; +} + +static bool hasTypeVar(TVSet& Set, Type* Type) { + for (auto TV: Type->getTypeVars()) { + if (Set.count(TV)) { + return true; } - auto F = static_cast(Scm); - ZEN_ASSERT(F->TVs == nullptr || F->TVs->empty()); - return F->Type; } + return false; +} - void Checker::addBinding(ByteString Name, Scheme* Scm, SymKind Kind) { - getContext().Env.add(Name, Scm, Kind); - } +void Checker::setContext(InferContext* Ctx) { + ActiveContext = Ctx; +} - Type* Checker::getReturnType() { - auto Ty = getContext().ReturnType; - ZEN_ASSERT(Ty != nullptr); - return Ty; - } +void Checker::popContext() { + ZEN_ASSERT(ActiveContext); + ActiveContext = ActiveContext->Parent; +} - static bool hasTypeVar(TVSet& Set, Type* Type) { - for (auto TV: Type->getTypeVars()) { - if (Set.count(TV)) { - return true; - } - } - return false; - } +InferContext& Checker::getContext() { + ZEN_ASSERT(ActiveContext); + return *ActiveContext; +} - void Checker::setContext(InferContext* Ctx) { - ActiveContext = Ctx; - } +void Checker::makeEqual(Type* A, Type* B, Node* Source) { + addConstraint(new CEqual(A, B, Source)); +} - void Checker::popContext() { - ZEN_ASSERT(ActiveContext); - ActiveContext = ActiveContext->Parent; - } +void Checker::addConstraint(Constraint* C) { - InferContext& Checker::getContext() { - ZEN_ASSERT(ActiveContext); - return *ActiveContext; - } + switch (C->getKind()) { - void Checker::makeEqual(Type* A, Type* B, Node* Source) { - addConstraint(new CEqual(A, B, Source)); - } + case ConstraintKind::Field: + // FIXME Check if this is all that needs to be done + getContext().Constraints->push_back(C); + break; - void Checker::addConstraint(Constraint* C) { + case ConstraintKind::Equal: + { + auto Y = static_cast(C); - switch (C->getKind()) { - - case ConstraintKind::Field: - // FIXME Check if this is all that needs to be done - getContext().Constraints->push_back(C); - break; - - case ConstraintKind::Equal: - { - auto Y = static_cast(C); - - // This will store all inference contexts in Contexts, from most local - // one to most general one. Because this order is not ideal, the code - // below will have to handle that. - auto Curr = &getContext(); - std::vector Contexts; - for (;;) { - Contexts.push_back(Curr); - Curr = Curr->Parent; - if (!Curr) { - break; - } - } - - std::size_t Global = Contexts.size()-1; - - // If no MaxLevelLeft was found, that means that not a single - // corresponding type variable was found in the contexts. We set it to - // Contexts.size()-1, which corresponds to the global inference context. - std::size_t MaxLevelLeft = Global; - for (std::size_t I = 0; I < Global; I++) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Left)) { - MaxLevelLeft = I; - break; - } - } - - // Same as above but now mirrored for Y->Right - std::size_t MaxLevelRight = Global; - for (std::size_t I = 0; I < Global; I++) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Right)) { - MaxLevelRight = I; - break; - } - } - - // The lowest index is determined by the one that has no type variables - // in Y->Left AND in Y->Right. This implies max() must be used, so that - // the very first enounter of a type variable matters. - auto UpperLevel = std::max(MaxLevelLeft, MaxLevelRight); - - // Now find the lowest index LowerLevel such that all the contexts that are more - // local do not contain any type variables that are present in the - // equality constraint. - std::size_t LowerLevel = UpperLevel; - for (std::size_t I = Global; I-- > 0; ) { - auto Ctx = Contexts[I]; - if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { - LowerLevel = I; - break; - } - } - - if (UpperLevel == LowerLevel || MaxLevelLeft == Global || MaxLevelRight == Global) { - unify(Y->Left, Y->Right, Y->Source); - } else { - Contexts[UpperLevel]->Constraints->push_back(C); - } - - break; - } - - case ConstraintKind::Many: - { - auto Y = static_cast(C); - for (auto Element: Y->Elements) { - addConstraint(Element); - } - break; - } - - case ConstraintKind::Empty: - break; - - } - - } - - void Checker::forwardDeclare(Node* X) { - - switch (X->getKind()) { - - case NodeKind::ExpressionStatement: - case NodeKind::ReturnStatement: - case NodeKind::IfStatement: - break; - - case NodeKind::SourceFile: - { - auto File = static_cast(X); - for (auto Element: File->Elements) { - forwardDeclare(Element) ; - } - break; - } - - case NodeKind::ClassDeclaration: - { - auto Class = static_cast(X); - // for (auto TE: Class->TypeVars) { - // auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); - // // TV->Contexts.emplace(Class->Name->getCanonicalText()); - // TE->setType(TV); - // } - for (auto Element: Class->Elements) { - forwardDeclare(Element); - } - break; - } - - case NodeKind::InstanceDeclaration: - { - auto Decl = static_cast(X); - - // Needed to set the associated Type on the CST node - for (auto TE: Decl->TypeExps) { - inferTypeExpression(TE); - } - - auto Match = InstanceMap.find(getCanonicalText(Decl->Name)); - if (Match == InstanceMap.end()) { - InstanceMap.emplace(getCanonicalText(Decl->Name), std::vector { Decl }); - } else { - Match->second.push_back(Decl); - } - - for (auto Element: Decl->Elements) { - forwardDeclare(Element); - } - - break; - } - - case NodeKind::LetDeclaration: - { - // Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression() - auto Decl = static_cast(X); - if (!Decl->isVariable()) { + // This will store all inference contexts in Contexts, from most local + // one to most general one. Because this order is not ideal, the code + // below will have to handle that. + auto Curr = &getContext(); + std::vector Contexts; + for (;;) { + Contexts.push_back(Curr); + Curr = Curr->Parent; + if (!Curr) { break; } - Type* Ty; - if (Decl->TypeAssert) { - Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); - } else { - Ty = createTypeVar(); - } - Decl->setType(Ty); - break; } - case NodeKind::VariantDeclaration: - { - auto Decl = static_cast(X); + std::size_t Global = Contexts.size()-1; - setContext(Decl->Ctx); - - std::vector Vars; - for (auto TE: Decl->TVs) { - auto TV = createRigidVar(getCanonicalText(TE->Name)); - Decl->Ctx->TVs->emplace(TV); - Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); - Vars.push_back(TV); + // If no MaxLevelLeft was found, that means that not a single + // corresponding type variable was found in the contexts. We set it to + // Contexts.size()-1, which corresponds to the global inference context. + std::size_t MaxLevelLeft = Global; + for (std::size_t I = 0; I < Global; I++) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Left)) { + MaxLevelLeft = I; + break; } - - Type* Ty = createConType(getCanonicalText(Decl->Name)); - - // Build the type that is actually returned by constructor functions - auto RetTy = Ty; - for (auto Var: Vars) { - RetTy = new Type(TApp(RetTy, Var)); - } - - // Must be added early so we can create recursive types - Decl->Ctx->Parent->Env.add(getCanonicalText(Decl->Name), new Forall(Ty), SymKind::Type); - - for (auto Member: Decl->Members) { - switch (Member->getKind()) { - case NodeKind::TupleVariantDeclarationMember: - { - auto TupleMember = static_cast(Member); - std::vector ParamTypes; - for (auto Element: TupleMember->Elements) { - // inferTypeExpression will look up any TVars that were part of the signature of Decl - ParamTypes.push_back(inferTypeExpression(Element, false)); - } - Decl->Ctx->Parent->Env.add( - getCanonicalText(TupleMember->Name), - new Forall( - Decl->Ctx->TVs, - Decl->Ctx->Constraints, - Type::buildArrow(ParamTypes, RetTy) - ), - SymKind::Var - ); - break; - } - case NodeKind::RecordVariantDeclarationMember: - { - // TODO - break; - } - default: - ZEN_UNREACHABLE - } - } - - popContext(); - - break; } - case NodeKind::RecordDeclaration: - { - auto Decl = static_cast(X); - - setContext(Decl->Ctx); - - std::vector Vars; - for (auto TE: Decl->Vars) { - auto TV = createRigidVar(getCanonicalText(TE->Name)); - Decl->Ctx->TVs->emplace(TV); - Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); - Vars.push_back(TV); + // Same as above but now mirrored for Y->Right + std::size_t MaxLevelRight = Global; + for (std::size_t I = 0; I < Global; I++) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Right)) { + MaxLevelRight = I; + break; } - - auto Name = getCanonicalText(Decl->Name); - auto Ty = createConType(Name); - - // Must be added early so we can create recursive types - Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type); - - Type* RetTy = Ty; - for (auto TV: Vars) { - RetTy = new Type(TApp(RetTy, TV)); - } - - // Corresponds to the logic of one branch of a VariantDeclarationMember - Type* FieldsTy = new Type(TNil()); - for (auto Field: Decl->Fields) { - FieldsTy = new Type( - TField( - getCanonicalText(Field->Name), - new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), - FieldsTy - ) - ); - } - Decl->Ctx->Parent->Env.add( - Name, - new Forall( - Decl->Ctx->TVs, - Decl->Ctx->Constraints, - new Type(TArrow(FieldsTy, RetTy)) - ), - SymKind::Var - ); - - popContext(); - - break; } - default: - ZEN_UNREACHABLE + // The lowest index is determined by the one that has no type variables + // in Y->Left AND in Y->Right. This implies max() must be used, so that + // the very first enounter of a type variable matters. + auto UpperLevel = std::max(MaxLevelLeft, MaxLevelRight); + // Now find the lowest index LowerLevel such that all the contexts that are more + // local do not contain any type variables that are present in the + // equality constraint. + std::size_t LowerLevel = UpperLevel; + for (std::size_t I = Global; I-- > 0; ) { + auto Ctx = Contexts[I]; + if (hasTypeVar(*Ctx->TVs, Y->Left) || hasTypeVar(*Ctx->TVs, Y->Right)) { + LowerLevel = I; + break; + } + } + + if (UpperLevel == LowerLevel || MaxLevelLeft == Global || MaxLevelRight == Global) { + unify(Y->Left, Y->Right, Y->Source); + } else { + Contexts[UpperLevel]->Constraints->push_back(C); + } + + break; } + case ConstraintKind::Many: + { + auto Y = static_cast(C); + for (auto Element: Y->Elements) { + addConstraint(Element); + } + break; + } + + case ConstraintKind::Empty: + break; + } - void Checker::initialize(Node* N) { +} - struct Init : public CSTVisitor { +void Checker::forwardDeclare(Node* X) { - Checker& C; + switch (X->getKind()) { - std::stack Contexts; + case NodeKind::ExpressionStatement: + case NodeKind::ReturnStatement: + case NodeKind::IfStatement: + break; - InferContext* createDerivedContext() { - return C.createInferContext(Contexts.top()); + case NodeKind::SourceFile: + { + auto File = static_cast(X); + for (auto Element: File->Elements) { + forwardDeclare(Element) ; } - - void visitVariantDeclaration(VariantDeclaration* Decl) { - Decl->Ctx = createDerivedContext(); - } - - void visitRecordDeclaration(RecordDeclaration* Decl) { - Decl->Ctx = createDerivedContext(); - } - - void visitMatchCase(MatchCase* C) { - C->Ctx = createDerivedContext(); - Contexts.push(C->Ctx); - visitEachChild(C); - Contexts.pop(); - } - - void visitSourceFile(SourceFile* SF) { - SF->Ctx = C.createInferContext(); - Contexts.push(SF->Ctx); - visitEachChild(SF); - Contexts.pop(); - } - - void visitLetDeclaration(LetDeclaration* Let) { - if (Let->isFunction()) { - Let->Ctx = createDerivedContext(); - Contexts.push(Let->Ctx); - visitEachChild(Let); - Contexts.pop(); - } - } - - // void visitVariableDeclaration(VariableDeclaration* Var) { - // Var->Ctx = Contexts.top(); - // visitEachChild(Var); - // } - - }; - - Init I { {}, *this }; - I.visit(N); - - } - - void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { - - if (!Let->isFunction()) { - return; + break; } - // std::cerr << "declare " << Let->getNameAsString() << std::endl; - - setContext(Let->Ctx); - - auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { - auto Id = getCanonicalText(Class->Name); - auto Ctx = &getContext(); - std::vector Out; - for (auto TE: Class->TypeVars) { - auto Name = getCanonicalText(TE->Name); - auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); - TV->asVar().Context.emplace(Id); - Ctx->Env.add(Name, new Forall(TV), SymKind::Type); - Out.push_back(TV); - } - return Out; - }; - - // If declaring a let-declaration inside a type class declaration, - // we need to mark that the let-declaration requires this class. - // This marking is set on the rigid type variables of the class, which - // are then added to this local type environment. - if (Let->isClass()) { - addClassVars(static_cast(Let->Parent), true); - } - - // Here we infer the primary type of the let declaration. If there's a - // type assert, that assert should be authoritative so we use that. - // Otherwise, the type is not further specified and we create a new - // unification variable. - Type* Ty; - if (Let->TypeAssert) { - Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); - } else { - Ty = createTypeVar(); - } - Let->setType(Ty); - - // If declaring a let-declaration inside a type instance declaration, - // we need to perform some work to make sure the type asserts of the - // corresponding let-declaration in the type class declaration are - // accounted for. - if (Let->isInstance()) { - - auto Instance = static_cast(Let->Parent); - auto Class = cast(Instance->getScope()->lookup({ {}, getCanonicalText(Instance->Name) }, SymbolKind::Class)); - // TODO check if `Class` is nullptr - auto SigLet = cast(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); - - auto Params = addClassVars(Class, false); - - // The type asserts in the type class declaration might make use of - // the type parameters of the type class declaration, so it is - // important to make them available in the type environment. Moreover, - // we will be unifying them with the actual types declared in the - // instance declaration, so we keep track of them. - // std::vector Params; - // TVSub Sub; + case NodeKind::ClassDeclaration: + { + auto Class = static_cast(X); // for (auto TE: Class->TypeVars) { - // auto TV = createTypeVar(); - // Sub.emplace(cast(TE->getType()), TV); - // Params.push_back(TV); + // auto TV = new TVarRigid(NextTypeVarId++, TE->Name->getCanonicalText()); + // // TV->Contexts.emplace(Class->Name->getCanonicalText()); + // TE->setType(TV); // } + for (auto Element: Class->Elements) { + forwardDeclare(Element); + } + break; + } - // Here we do the actual unification of e.g. Eq a with Eq Bool. The - // unification variables we created previously will be unified with - // e.g. Bool, which causes the type assert to also collapse to e.g. - // Bool -> Bool -> Bool. - for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { - makeEqual(Param, TE->getType(), TE); + case NodeKind::InstanceDeclaration: + { + auto Decl = static_cast(X); + + // Needed to set the associated Type on the CST node + for (auto TE: Decl->TypeExps) { + inferTypeExpression(TE); } - // It would be very strange if there was no type assert in the type - // class let-declaration but we rather not let the compiler crash if that happens. - if (SigLet->TypeAssert) { - // Note that we can't do SigLet->TypeAssert->TypeExpression->getType() - // because we need to re-generate the type within the local context of - // this let-declaration. - // TODO make CEqual accept multiple nodes - makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let); + auto Match = InstanceMap.find(getCanonicalText(Decl->Name)); + if (Match == InstanceMap.end()) { + InstanceMap.emplace(getCanonicalText(Decl->Name), std::vector { Decl }); + } else { + Match->second.push_back(Decl); } - } - - if (Let->Body) { - switch (Let->Body->getKind()) { - case NodeKind::LetExprBody: - break; - case NodeKind::LetBlockBody: - { - auto Block = static_cast(Let->Body); - Let->Ctx->ReturnType = createTypeVar(); - for (auto Element: Block->Elements) { - forwardDeclare(Element); - } - break; - } - default: - ZEN_UNREACHABLE + for (auto Element: Decl->Elements) { + forwardDeclare(Element); } + + break; } - if (!Let->isInstance()) { - Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var); - } - - } - - void Checker::inferFunctionDeclaration(LetDeclaration* Decl) { - - if (!Decl->isFunction()) { - return; - } - - // std::cerr << "infer " << Decl->getNameAsString() << std::endl; - - auto OldCtx = ActiveContext; - setContext(Decl->Ctx); - - std::vector ParamTypes; - Type* RetType; - - for (auto Param: Decl->Params) { - ParamTypes.push_back(inferPattern(Param->Pattern)); - } - - if (Decl->Body) { - switch (Decl->Body->getKind()) { - case NodeKind::LetExprBody: - { - auto Expr = static_cast(Decl->Body); - RetType = inferExpression(Expr->Expression); - break; - } - case NodeKind::LetBlockBody: - { - auto Block = static_cast(Decl->Body); - RetType = Decl->Ctx->ReturnType; - for (auto Element: Block->Elements) { - infer(Element); - } - break; - } - default: - ZEN_UNREACHABLE - } - } else { - RetType = createTypeVar(); - } - - makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl); - - setContext(OldCtx); - } - - void Checker::infer(Node* N) { - - switch (N->getKind()) { - - case NodeKind::SourceFile: - { - auto File = static_cast(N); - for (auto Element: File->Elements) { - infer(Element); - } + case NodeKind::LetDeclaration: + { + // Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression() + auto Decl = static_cast(X); + if (!Decl->isVariable()) { break; } - - case NodeKind::ClassDeclaration: - { - auto Decl = static_cast(N); - for (auto Element: Decl->Elements) { - infer(Element); - } - break; - } - - case NodeKind::InstanceDeclaration: - { - auto Decl = static_cast(N); - for (auto Element: Decl->Elements) { - infer(Element); - } - break; - } - - case NodeKind::VariantDeclaration: - case NodeKind::RecordDeclaration: - // Nothing to do for a type-level declaration - break; - - case NodeKind::IfStatement: - { - auto IfStmt = static_cast(N); - for (auto Part: IfStmt->Parts) { - if (Part->Test != nullptr) { - makeEqual(BoolType, inferExpression(Part->Test), Part->Test); - } - for (auto Element: Part->Elements) { - infer(Element); - } - } - break; - } - - case NodeKind::ReturnStatement: - { - auto RetStmt = static_cast(N); - Type* ReturnType; - if (RetStmt->Expression) { - makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression); - } else { - ReturnType = UnitType; - makeEqual(UnitType, getReturnType(), N); - } - break; - } - - case NodeKind::LetDeclaration: - { - // Function declarations are handled separately in inferFunctionDeclaration() - auto Decl = static_cast(N); - if (Decl->Visited) { - break; - } - if (Decl->isFunction()) { - Decl->IsCycleActive = true; - Decl->Visited = true; - inferFunctionDeclaration(Decl); - Decl->IsCycleActive = false; - } else if (Decl->isVariable()) { - auto Ty = Decl->getType(); - if (Decl->Body) { - ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); - auto E = static_cast(Decl->Body); - auto Ty2 = inferExpression(E->Expression); - makeEqual(Ty, Ty2, Decl); - } - auto Ty3 = inferPattern(Decl->Pattern); - makeEqual(Ty, Ty3, Decl); - } - break; - } - - case NodeKind::ExpressionStatement: - { - auto ExprStmt = static_cast(N); - inferExpression(ExprStmt->Expression); - break; - } - - default: - ZEN_UNREACHABLE - - } - - } - - Type* Checker::createConType(ByteString Name) { - return new Type(TCon(NextConTypeId++, Name)); - } - - Type* Checker::createRigidVar(ByteString Name) { - auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}})); - getContext().TVs->emplace(TV); - return TV; - } - - Type* Checker::createTypeVar() { - auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {})); - getContext().TVs->emplace(TV); - return TV; - } - - InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) { - auto Ctx = new InferContext; - Ctx->Parent = Parent; - Ctx->TVs = new TVSet; - Ctx->Constraints = new ConstraintSet; - return Ctx; - } - - Type* Checker::instantiate(Scheme* Scm, Node* Source) { - - switch (Scm->getKind()) { - - case SchemeKind::Forall: - { - auto F = static_cast(Scm); - - TVSub Sub; - for (auto TV: *F->TVs) { - auto Fresh = createTypeVar(); - // std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl; - Fresh->asVar().Context = TV->asVar().Context; - Sub[TV] = Fresh; - } - - for (auto Constraint: *F->Constraints) { - - // FIXME improve this - if (Constraint->getKind() == ConstraintKind::Equal) { - auto Eq = static_cast(Constraint); - Eq->Left = solveType(Eq->Left); - Eq->Right = solveType(Eq->Right); - } - - auto NewConstraint = Constraint->substitute(Sub); - - // This makes error messages prettier by relating the typing failure - // to the call site rather than the definition. - if (NewConstraint->getKind() == ConstraintKind::Equal) { - auto Eq = static_cast(Constraint); - Eq->Source = Source; - } - - addConstraint(NewConstraint); - } - - // This call to solve happens because constraints may have already - // been solved, with some unification variables being erased. To make - // sure we instantiate unification variables that are still in use - // we solve before substituting. - return solveType(F->Type)->substitute(Sub); - } - - } - - ZEN_UNREACHABLE - } - - void Checker::inferConstraintExpression(ConstraintExpression* C) { - switch (C->getKind()) { - case NodeKind::TypeclassConstraintExpression: - { - auto D = static_cast(C); - std::vector Types; - for (auto TE: D->TEs) { - auto Ty = inferTypeExpression(TE); - Ty->asVar().Provided->emplace(getCanonicalText(D->Name)); - Types.push_back(Ty); - } - break; - } - case NodeKind::EqualityConstraintExpression: - { - auto D = static_cast(C); - makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); - break; - } - default: - ZEN_UNREACHABLE - } - } - - Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { - - switch (N->getKind()) { - - case NodeKind::ReferenceTypeExpression: - { - auto RefTE = static_cast(N); - auto Scm = lookup(getCanonicalText(RefTE->Name), SymKind::Type); - Type* Ty; - if (Scm == nullptr) { - DE.add(getCanonicalText(RefTE->Name), RefTE->Name); - Ty = createTypeVar(); - } else { - Ty = instantiate(Scm, RefTE); - } - N->setType(Ty); - return Ty; - } - - case NodeKind::AppTypeExpression: - { - auto AppTE = static_cast(N); - Type* Ty = inferTypeExpression(AppTE->Op, AutoVars); - for (auto Arg: AppTE->Args) { - Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars))); - } - N->setType(Ty); - return Ty; - } - - case NodeKind::VarTypeExpression: - { - auto VarTE = static_cast(N); - auto Ty = lookupMono(getCanonicalText(VarTE->Name), SymKind::Type); - if (Ty == nullptr) { - if (!AutoVars || Config.typeVarsRequireForall()) { - DE.add(getCanonicalText(VarTE->Name), VarTE->Name); - } - Ty = createRigidVar(getCanonicalText(VarTE->Name)); - addBinding(getCanonicalText(VarTE->Name), new Forall(Ty), SymKind::Type); - } - ZEN_ASSERT(Ty->isVar()); - N->setType(Ty); - return Ty; - } - - case NodeKind::RecordTypeExpression: - { - auto RecTE = static_cast(N); - auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); - for (auto [Field, Comma]: RecTE->Fields) { - Ty = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); - } - N->setType(Ty); - return Ty; - } - - case NodeKind::TupleTypeExpression: - { - auto TupleTE = static_cast(N); - std::vector ElementTypes; - for (auto [TE, Comma]: TupleTE->Elements) { - ElementTypes.push_back(inferTypeExpression(TE, AutoVars)); - } - auto Ty = new Type(TTuple(ElementTypes)); - N->setType(Ty); - return Ty; - } - - case NodeKind::NestedTypeExpression: - { - auto NestedTE = static_cast(N); - auto Ty = inferTypeExpression(NestedTE->TE, AutoVars); - N->setType(Ty); - return Ty; - } - - case NodeKind::ArrowTypeExpression: - { - auto ArrowTE = static_cast(N); - std::vector ParamTypes; - for (auto ParamType: ArrowTE->ParamTypes) { - ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars)); - } - auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars); - auto Ty = Type::buildArrow(ParamTypes, ReturnType); - N->setType(Ty); - return Ty; - } - - case NodeKind::QualifiedTypeExpression: - { - auto QTE = static_cast(N); - for (auto [C, Comma]: QTE->Constraints) { - inferConstraintExpression(C); - } - auto Ty = inferTypeExpression(QTE->TE, AutoVars); - N->setType(Ty); - return Ty; - } - - default: - ZEN_UNREACHABLE - - } - } - - Type* sortRow(Type* Ty) { - std::map Fields; - while (Ty->isField()) { - auto& Field = Ty->asField(); - Fields.emplace(Field.Name, Ty); - Ty = Field.RestTy; - } - for (auto [Name, Field]: Fields) { - Ty = new Type(TField(Name, Field->asField().Ty, Ty)); - } - return Ty; - } - - Type* Checker::inferExpression(Expression* X) { - - Type* Ty; - - for (auto A: X->Annotations) { - if (A->getKind() == NodeKind::TypeAssertAnnotation) { - inferTypeExpression(static_cast(A)->TE); - } - } - - switch (X->getKind()) { - - case NodeKind::MatchExpression: - { - auto Match = static_cast(X); - Type* ValTy; - if (Match->Value) { - ValTy = inferExpression(Match->Value); - } else { - ValTy = createTypeVar(); - } + Type* Ty; + if (Decl->TypeAssert) { + Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); + } else { Ty = createTypeVar(); - for (auto Case: Match->Cases) { - auto OldCtx = &getContext(); - setContext(Case->Ctx); - auto PattTy = inferPattern(Case->Pattern); - makeEqual(PattTy, ValTy, Case); - auto ExprTy = inferExpression(Case->Expression); - makeEqual(ExprTy, Ty, Case->Expression); - setContext(OldCtx); - } - if (!Match->Value) { - Ty = new Type(TArrow(ValTy, Ty)); - } - break; + } + Decl->setType(Ty); + break; + } + + case NodeKind::VariantDeclaration: + { + auto Decl = static_cast(X); + + setContext(Decl->Ctx); + + std::vector Vars; + for (auto TE: Decl->TVs) { + auto TV = createRigidVar(getCanonicalText(TE->Name)); + Decl->Ctx->TVs->emplace(TV); + Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); + Vars.push_back(TV); } - case NodeKind::RecordExpression: - { - auto Record = static_cast(X); - Ty = new Type(TNil()); - for (auto [Field, Comma]: Record->Fields) { - Ty = new Type(TField( - getCanonicalText(Field->Name), - new Type(TPresent(inferExpression(Field->getExpression()))), - Ty - )); - } - Ty = sortRow(Ty); - break; + Type* Ty = createConType(getCanonicalText(Decl->Name)); + + // Build the type that is actually returned by constructor functions + auto RetTy = Ty; + for (auto Var: Vars) { + RetTy = new Type(TApp(RetTy, Var)); } - case NodeKind::LiteralExpression: - { - auto Const = static_cast(X); - Ty = inferLiteral(Const->Token); - break; - } + // Must be added early so we can create recursive types + Decl->Ctx->Parent->Env.add(getCanonicalText(Decl->Name), new Forall(Ty), SymKind::Type); - case NodeKind::ReferenceExpression: - { - auto Ref = static_cast(X); - ZEN_ASSERT(Ref->ModulePath.empty()); - if (Ref->Name->is()) { - auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); - if (!Scm) { - DE.add(getCanonicalText(Ref->Name), Ref->Name); - Ty = createTypeVar(); - break; - } - Ty = instantiate(Scm, X); - break; - } - auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); - if (!Target) { - DE.add(getCanonicalText(Ref->Name), Ref->Name); - Ty = createTypeVar(); - break; - } - if (Target->getKind() == NodeKind::LetDeclaration) { - auto Let = static_cast(Target); - if (Let->IsCycleActive) { - Ty = Let->getType(); - break; - } - if (!Let->Visited) { - infer(Let); - } - } - auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); - ZEN_ASSERT(Scm); - Ty = instantiate(Scm, X); - break; - } - - case NodeKind::CallExpression: - { - auto Call = static_cast(X); - auto OpTy = inferExpression(Call->Function); - Ty = createTypeVar(); - std::vector ArgTypes; - for (auto Arg: Call->Args) { - ArgTypes.push_back(inferExpression(Arg)); - } - makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X); - break; - } - - case NodeKind::InfixExpression: - { - auto Infix = static_cast(X); - auto Scm = lookup(Infix->Operator->getText(), SymKind::Var); - if (Scm == nullptr) { - DE.add(Infix->Operator->getText(), Infix->Operator); - Ty = createTypeVar(); - break; - } - auto OpTy = instantiate(Scm, Infix->Operator); - Ty = createTypeVar(); - std::vector ArgTys; - ArgTys.push_back(inferExpression(Infix->Left)); - ArgTys.push_back(inferExpression(Infix->Right)); - makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X); - break; - } - - case NodeKind::TupleExpression: - { - auto Tuple = static_cast(X); - std::vector Types; - for (auto [E, Comma]: Tuple->Elements) { - Types.push_back(inferExpression(E)); - } - Ty = new Type(TTuple(Types)); - break; - } - - case NodeKind::MemberExpression: - { - auto Member = static_cast(X); - auto ExprTy = inferExpression(Member->E); - switch (Member->Name->getKind()) { - case NodeKind::IntegerLiteral: + for (auto Member: Decl->Members) { + switch (Member->getKind()) { + case NodeKind::TupleVariantDeclarationMember: { - auto I = static_cast(Member->Name); - Ty = createTypeVar(); - addConstraint(new CField(ExprTy, I->asInt(), Ty, Member)); + auto TupleMember = static_cast(Member); + std::vector ParamTypes; + for (auto Element: TupleMember->Elements) { + // inferTypeExpression will look up any TVars that were part of the signature of Decl + ParamTypes.push_back(inferTypeExpression(Element, false)); + } + Decl->Ctx->Parent->Env.add( + getCanonicalText(TupleMember->Name), + new Forall( + Decl->Ctx->TVs, + Decl->Ctx->Constraints, + Type::buildArrow(ParamTypes, RetTy) + ), + SymKind::Var + ); break; } - case NodeKind::Identifier: + case NodeKind::RecordVariantDeclarationMember: { - auto K = static_cast(Member->Name); - Ty = createTypeVar(); - auto RestTy = createTypeVar(); - makeEqual(new Type(TField(getCanonicalText(K), Ty, RestTy)), ExprTy, Member); + // TODO break; } default: ZEN_UNREACHABLE } - break; } - case NodeKind::NestedExpression: + popContext(); + + break; + } + + case NodeKind::RecordDeclaration: + { + auto Decl = static_cast(X); + + setContext(Decl->Ctx); + + std::vector Vars; + for (auto TE: Decl->Vars) { + auto TV = createRigidVar(getCanonicalText(TE->Name)); + Decl->Ctx->TVs->emplace(TV); + Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); + Vars.push_back(TV); + } + + auto Name = getCanonicalText(Decl->Name); + auto Ty = createConType(Name); + + // Must be added early so we can create recursive types + Decl->Ctx->Parent->Env.add(Name, new Forall(Ty), SymKind::Type); + + Type* RetTy = Ty; + for (auto TV: Vars) { + RetTy = new Type(TApp(RetTy, TV)); + } + + // Corresponds to the logic of one branch of a VariantDeclarationMember + Type* FieldsTy = new Type(TNil()); + for (auto Field: Decl->Fields) { + FieldsTy = new Type( + TField( + getCanonicalText(Field->Name), + new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), + FieldsTy + ) + ); + } + Decl->Ctx->Parent->Env.add( + Name, + new Forall( + Decl->Ctx->TVs, + Decl->Ctx->Constraints, + new Type(TArrow(FieldsTy, RetTy)) + ), + SymKind::Var + ); + + popContext(); + + break; + } + + default: + ZEN_UNREACHABLE + + } + +} + +void Checker::initialize(Node* N) { + + struct Init : public CSTVisitor { + + Checker& C; + + std::stack Contexts; + + InferContext* createDerivedContext() { + return C.createInferContext(Contexts.top()); + } + + void visitVariantDeclaration(VariantDeclaration* Decl) { + Decl->Ctx = createDerivedContext(); + } + + void visitRecordDeclaration(RecordDeclaration* Decl) { + Decl->Ctx = createDerivedContext(); + } + + void visitMatchCase(MatchCase* C) { + C->Ctx = createDerivedContext(); + Contexts.push(C->Ctx); + visitEachChild(C); + Contexts.pop(); + } + + void visitSourceFile(SourceFile* SF) { + SF->Ctx = C.createInferContext(); + Contexts.push(SF->Ctx); + visitEachChild(SF); + Contexts.pop(); + } + + void visitLetDeclaration(LetDeclaration* Let) { + if (Let->isFunction()) { + Let->Ctx = createDerivedContext(); + Contexts.push(Let->Ctx); + visitEachChild(Let); + Contexts.pop(); + } + } + + // void visitVariableDeclaration(VariableDeclaration* Var) { + // Var->Ctx = Contexts.top(); + // visitEachChild(Var); + // } + + }; + + Init I { {}, *this }; + I.visit(N); + +} + +void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { + + if (!Let->isFunction()) { + return; + } + + // std::cerr << "declare " << Let->getNameAsString() << std::endl; + + setContext(Let->Ctx); + + auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { + auto Id = getCanonicalText(Class->Name); + auto Ctx = &getContext(); + std::vector Out; + for (auto TE: Class->TypeVars) { + auto Name = getCanonicalText(TE->Name); + auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); + TV->asVar().Context.emplace(Id); + Ctx->Env.add(Name, new Forall(TV), SymKind::Type); + Out.push_back(TV); + } + return Out; + }; + + // If declaring a let-declaration inside a type class declaration, + // we need to mark that the let-declaration requires this class. + // This marking is set on the rigid type variables of the class, which + // are then added to this local type environment. + if (Let->isClass()) { + addClassVars(static_cast(Let->Parent), true); + } + + // Here we infer the primary type of the let declaration. If there's a + // type assert, that assert should be authoritative so we use that. + // Otherwise, the type is not further specified and we create a new + // unification variable. + Type* Ty; + if (Let->TypeAssert) { + Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); + } else { + Ty = createTypeVar(); + } + Let->setType(Ty); + + // If declaring a let-declaration inside a type instance declaration, + // we need to perform some work to make sure the type asserts of the + // corresponding let-declaration in the type class declaration are + // accounted for. + if (Let->isInstance()) { + + auto Instance = static_cast(Let->Parent); + auto Class = cast(Instance->getScope()->lookup({ {}, getCanonicalText(Instance->Name) }, SymbolKind::Class)); + // TODO check if `Class` is nullptr + auto SigLet = cast(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); + + auto Params = addClassVars(Class, false); + + // The type asserts in the type class declaration might make use of + // the type parameters of the type class declaration, so it is + // important to make them available in the type environment. Moreover, + // we will be unifying them with the actual types declared in the + // instance declaration, so we keep track of them. + // std::vector Params; + // TVSub Sub; + // for (auto TE: Class->TypeVars) { + // auto TV = createTypeVar(); + // Sub.emplace(cast(TE->getType()), TV); + // Params.push_back(TV); + // } + + // Here we do the actual unification of e.g. Eq a with Eq Bool. The + // unification variables we created previously will be unified with + // e.g. Bool, which causes the type assert to also collapse to e.g. + // Bool -> Bool -> Bool. + for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) { + makeEqual(Param, TE->getType(), TE); + } + + // It would be very strange if there was no type assert in the type + // class let-declaration but we rather not let the compiler crash if that happens. + if (SigLet->TypeAssert) { + // Note that we can't do SigLet->TypeAssert->TypeExpression->getType() + // because we need to re-generate the type within the local context of + // this let-declaration. + // TODO make CEqual accept multiple nodes + makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let); + } + + } + + if (Let->Body) { + switch (Let->Body->getKind()) { + case NodeKind::LetExprBody: + break; + case NodeKind::LetBlockBody: { - auto Nested = static_cast(X); - Ty = inferExpression(Nested->Inner); + auto Block = static_cast(Let->Body); + Let->Ctx->ReturnType = createTypeVar(); + for (auto Element: Block->Elements) { + forwardDeclare(Element); + } break; } - default: ZEN_UNREACHABLE + } + } + + if (!Let->isInstance()) { + Let->Ctx->Parent->Env.add(Let->getNameAsString(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty), SymKind::Var); + } + +} + +void Checker::inferFunctionDeclaration(LetDeclaration* Decl) { + + if (!Decl->isFunction()) { + return; + } + + // std::cerr << "infer " << Decl->getNameAsString() << std::endl; + + auto OldCtx = ActiveContext; + setContext(Decl->Ctx); + + std::vector ParamTypes; + Type* RetType; + + for (auto Param: Decl->Params) { + ParamTypes.push_back(inferPattern(Param->Pattern)); + } + + if (Decl->Body) { + switch (Decl->Body->getKind()) { + case NodeKind::LetExprBody: + { + auto Expr = static_cast(Decl->Body); + RetType = inferExpression(Expr->Expression); + break; + } + case NodeKind::LetBlockBody: + { + auto Block = static_cast(Decl->Body); + RetType = Decl->Ctx->ReturnType; + for (auto Element: Block->Elements) { + infer(Element); + } + break; + } + default: + ZEN_UNREACHABLE + } + } else { + RetType = createTypeVar(); + } + + makeEqual(Decl->getType(), Type::buildArrow(ParamTypes, RetType), Decl); + + setContext(OldCtx); +} + +void Checker::infer(Node* N) { + + switch (N->getKind()) { + + case NodeKind::SourceFile: + { + auto File = static_cast(N); + for (auto Element: File->Elements) { + infer(Element); + } + break; + } + + case NodeKind::ClassDeclaration: + { + auto Decl = static_cast(N); + for (auto Element: Decl->Elements) { + infer(Element); + } + break; + } + + case NodeKind::InstanceDeclaration: + { + auto Decl = static_cast(N); + for (auto Element: Decl->Elements) { + infer(Element); + } + break; + } + + case NodeKind::VariantDeclaration: + case NodeKind::RecordDeclaration: + // Nothing to do for a type-level declaration + break; + + case NodeKind::IfStatement: + { + auto IfStmt = static_cast(N); + for (auto Part: IfStmt->Parts) { + if (Part->Test != nullptr) { + makeEqual(BoolType, inferExpression(Part->Test), Part->Test); + } + for (auto Element: Part->Elements) { + infer(Element); + } + } + break; + } + + case NodeKind::ReturnStatement: + { + auto RetStmt = static_cast(N); + Type* ReturnType; + if (RetStmt->Expression) { + makeEqual(inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression); + } else { + ReturnType = UnitType; + makeEqual(UnitType, getReturnType(), N); + } + break; + } + + case NodeKind::LetDeclaration: + { + // Function declarations are handled separately in inferFunctionDeclaration() + auto Decl = static_cast(N); + if (Decl->Visited) { + break; + } + if (Decl->isFunction()) { + Decl->IsCycleActive = true; + Decl->Visited = true; + inferFunctionDeclaration(Decl); + Decl->IsCycleActive = false; + } else if (Decl->isVariable()) { + auto Ty = Decl->getType(); + if (Decl->Body) { + ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); + auto E = static_cast(Decl->Body); + auto Ty2 = inferExpression(E->Expression); + makeEqual(Ty, Ty2, Decl); + } + auto Ty3 = inferPattern(Decl->Pattern); + makeEqual(Ty, Ty3, Decl); + } + break; + } + + case NodeKind::ExpressionStatement: + { + auto ExprStmt = static_cast(N); + inferExpression(ExprStmt->Expression); + break; + } + + default: + ZEN_UNREACHABLE + + } + +} + +Type* Checker::createConType(ByteString Name) { + return new Type(TCon(NextConTypeId++, Name)); +} + +Type* Checker::createRigidVar(ByteString Name) { + auto TV = new Type(TVar(VarKind::Rigid, NextTypeVarId++, {}, Name, {{}})); + getContext().TVs->emplace(TV); + return TV; +} + +Type* Checker::createTypeVar() { + auto TV = new Type(TVar(VarKind::Unification, NextTypeVarId++, {})); + getContext().TVs->emplace(TV); + return TV; +} + +InferContext* Checker::createInferContext(InferContext* Parent, TVSet* TVs, ConstraintSet* Constraints) { + auto Ctx = new InferContext; + Ctx->Parent = Parent; + Ctx->TVs = new TVSet; + Ctx->Constraints = new ConstraintSet; + return Ctx; +} + +Type* Checker::instantiate(Scheme* Scm, Node* Source) { + + switch (Scm->getKind()) { + + case SchemeKind::Forall: + { + auto F = static_cast(Scm); + + TVSub Sub; + for (auto TV: *F->TVs) { + auto Fresh = createTypeVar(); + // std::cerr << describe(TV) << " => " << describe(Fresh) << std::endl; + Fresh->asVar().Context = TV->asVar().Context; + Sub[TV] = Fresh; + } + + for (auto Constraint: *F->Constraints) { + + // FIXME improve this + if (Constraint->getKind() == ConstraintKind::Equal) { + auto Eq = static_cast(Constraint); + Eq->Left = solveType(Eq->Left); + Eq->Right = solveType(Eq->Right); + } + + auto NewConstraint = Constraint->substitute(Sub); + + // This makes error messages prettier by relating the typing failure + // to the call site rather than the definition. + if (NewConstraint->getKind() == ConstraintKind::Equal) { + auto Eq = static_cast(Constraint); + Eq->Source = Source; + } + + addConstraint(NewConstraint); + } + + // This call to solve happens because constraints may have already + // been solved, with some unification variables being erased. To make + // sure we instantiate unification variables that are still in use + // we solve before substituting. + return solveType(F->Type)->substitute(Sub); + } + + } + + ZEN_UNREACHABLE +} + +void Checker::inferConstraintExpression(ConstraintExpression* C) { + switch (C->getKind()) { + case NodeKind::TypeclassConstraintExpression: + { + auto D = static_cast(C); + std::vector Types; + for (auto TE: D->TEs) { + auto Ty = inferTypeExpression(TE); + Ty->asVar().Provided->emplace(getCanonicalText(D->Name)); + Types.push_back(Ty); + } + break; + } + case NodeKind::EqualityConstraintExpression: + { + auto D = static_cast(C); + makeEqual(inferTypeExpression(D->Left), inferTypeExpression(D->Right), C); + break; + } + default: + ZEN_UNREACHABLE + } +} + +Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { + + switch (N->getKind()) { + + case NodeKind::ReferenceTypeExpression: + { + auto RefTE = static_cast(N); + auto Scm = lookup(getCanonicalText(RefTE->Name), SymKind::Type); + Type* Ty; + if (Scm == nullptr) { + DE.add(getCanonicalText(RefTE->Name), RefTE->Name); + Ty = createTypeVar(); + } else { + Ty = instantiate(Scm, RefTE); + } + N->setType(Ty); + return Ty; + } + + case NodeKind::AppTypeExpression: + { + auto AppTE = static_cast(N); + Type* Ty = inferTypeExpression(AppTE->Op, AutoVars); + for (auto Arg: AppTE->Args) { + Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars))); + } + N->setType(Ty); + return Ty; + } + + case NodeKind::VarTypeExpression: + { + auto VarTE = static_cast(N); + auto Ty = lookupMono(getCanonicalText(VarTE->Name), SymKind::Type); + if (Ty == nullptr) { + if (!AutoVars || Config.typeVarsRequireForall()) { + DE.add(getCanonicalText(VarTE->Name), VarTE->Name); + } + Ty = createRigidVar(getCanonicalText(VarTE->Name)); + addBinding(getCanonicalText(VarTE->Name), new Forall(Ty), SymKind::Type); + } + ZEN_ASSERT(Ty->isVar()); + N->setType(Ty); + return Ty; + } + + case NodeKind::RecordTypeExpression: + { + auto RecTE = static_cast(N); + auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); + for (auto [Field, Comma]: RecTE->Fields) { + Ty = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); + } + N->setType(Ty); + return Ty; + } + + case NodeKind::TupleTypeExpression: + { + auto TupleTE = static_cast(N); + std::vector ElementTypes; + for (auto [TE, Comma]: TupleTE->Elements) { + ElementTypes.push_back(inferTypeExpression(TE, AutoVars)); + } + auto Ty = new Type(TTuple(ElementTypes)); + N->setType(Ty); + return Ty; + } + + case NodeKind::NestedTypeExpression: + { + auto NestedTE = static_cast(N); + auto Ty = inferTypeExpression(NestedTE->TE, AutoVars); + N->setType(Ty); + return Ty; + } + + case NodeKind::ArrowTypeExpression: + { + auto ArrowTE = static_cast(N); + std::vector ParamTypes; + for (auto ParamType: ArrowTE->ParamTypes) { + ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars)); + } + auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars); + auto Ty = Type::buildArrow(ParamTypes, ReturnType); + N->setType(Ty); + return Ty; + } + + case NodeKind::QualifiedTypeExpression: + { + auto QTE = static_cast(N); + for (auto [C, Comma]: QTE->Constraints) { + inferConstraintExpression(C); + } + auto Ty = inferTypeExpression(QTE->TE, AutoVars); + N->setType(Ty); + return Ty; + } + + default: + ZEN_UNREACHABLE + + } +} + +Type* sortRow(Type* Ty) { + std::map Fields; + while (Ty->isField()) { + auto& Field = Ty->asField(); + Fields.emplace(Field.Name, Ty); + Ty = Field.RestTy; + } + for (auto [Name, Field]: Fields) { + Ty = new Type(TField(Name, Field->asField().Ty, Ty)); + } + return Ty; +} + +Type* Checker::inferExpression(Expression* X) { + + Type* Ty; + + for (auto A: X->Annotations) { + if (A->getKind() == NodeKind::TypeAssertAnnotation) { + inferTypeExpression(static_cast(A)->TE); + } + } + + switch (X->getKind()) { + + case NodeKind::MatchExpression: + { + auto Match = static_cast(X); + Type* ValTy; + if (Match->Value) { + ValTy = inferExpression(Match->Value); + } else { + ValTy = createTypeVar(); + } + Ty = createTypeVar(); + for (auto Case: Match->Cases) { + auto OldCtx = &getContext(); + setContext(Case->Ctx); + auto PattTy = inferPattern(Case->Pattern); + makeEqual(PattTy, ValTy, Case); + auto ExprTy = inferExpression(Case->Expression); + makeEqual(ExprTy, Ty, Case->Expression); + setContext(OldCtx); + } + if (!Match->Value) { + Ty = new Type(TArrow(ValTy, Ty)); + } + break; + } + + case NodeKind::RecordExpression: + { + auto Record = static_cast(X); + Ty = new Type(TNil()); + for (auto [Field, Comma]: Record->Fields) { + Ty = new Type(TField( + getCanonicalText(Field->Name), + new Type(TPresent(inferExpression(Field->getExpression()))), + Ty + )); + } + Ty = sortRow(Ty); + break; + } + + case NodeKind::LiteralExpression: + { + auto Const = static_cast(X); + Ty = inferLiteral(Const->Token); + break; + } + + case NodeKind::ReferenceExpression: + { + auto Ref = static_cast(X); + ZEN_ASSERT(Ref->ModulePath.empty()); + if (Ref->Name->is()) { + auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); + if (!Scm) { + DE.add(getCanonicalText(Ref->Name), Ref->Name); + Ty = createTypeVar(); + break; + } + Ty = instantiate(Scm, X); + break; + } + auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); + if (!Target) { + DE.add(getCanonicalText(Ref->Name), Ref->Name); + Ty = createTypeVar(); + break; + } + if (Target->getKind() == NodeKind::LetDeclaration) { + auto Let = static_cast(Target); + if (Let->IsCycleActive) { + Ty = Let->getType(); + break; + } + if (!Let->Visited) { + infer(Let); + } + } + auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); + ZEN_ASSERT(Scm); + Ty = instantiate(Scm, X); + break; + } + + case NodeKind::CallExpression: + { + auto Call = static_cast(X); + auto OpTy = inferExpression(Call->Function); + Ty = createTypeVar(); + std::vector ArgTypes; + for (auto Arg: Call->Args) { + ArgTypes.push_back(inferExpression(Arg)); + } + makeEqual(OpTy, Type::buildArrow(ArgTypes, Ty), X); + break; + } + + case NodeKind::InfixExpression: + { + auto Infix = static_cast(X); + auto Scm = lookup(Infix->Operator->getText(), SymKind::Var); + if (Scm == nullptr) { + DE.add(Infix->Operator->getText(), Infix->Operator); + Ty = createTypeVar(); + break; + } + auto OpTy = instantiate(Scm, Infix->Operator); + Ty = createTypeVar(); + std::vector ArgTys; + ArgTys.push_back(inferExpression(Infix->Left)); + ArgTys.push_back(inferExpression(Infix->Right)); + makeEqual(Type::buildArrow(ArgTys, Ty), OpTy, X); + break; + } + + case NodeKind::TupleExpression: + { + auto Tuple = static_cast(X); + std::vector Types; + for (auto [E, Comma]: Tuple->Elements) { + Types.push_back(inferExpression(E)); + } + Ty = new Type(TTuple(Types)); + break; + } + + case NodeKind::MemberExpression: + { + auto Member = static_cast(X); + auto ExprTy = inferExpression(Member->E); + switch (Member->Name->getKind()) { + case NodeKind::IntegerLiteral: + { + auto I = static_cast(Member->Name); + Ty = createTypeVar(); + addConstraint(new CField(ExprTy, I->asInt(), Ty, Member)); + break; + } + case NodeKind::Identifier: + { + auto K = static_cast(Member->Name); + Ty = createTypeVar(); + auto RestTy = createTypeVar(); + makeEqual(new Type(TField(getCanonicalText(K), Ty, RestTy)), ExprTy, Member); + break; + } + default: + ZEN_UNREACHABLE + } + break; + } + + case NodeKind::NestedExpression: + { + auto Nested = static_cast(X); + Ty = inferExpression(Nested->Inner); + break; + } + + default: + ZEN_UNREACHABLE + + } + + // Ty = find(Ty); + X->setType(Ty); + return Ty; +} + +RecordPatternField* getRestField(std::vector> Fields) { + for (auto [Field, Comma]: Fields) { + if (Field->DotDot) { + return Field; + } + } + return nullptr; +} + +Type* Checker::inferPattern( + Pattern* Pattern, + ConstraintSet* Constraints, + TVSet* TVs +) { + + switch (Pattern->getKind()) { + + case NodeKind::BindPattern: + { + auto P = static_cast(Pattern); + auto Ty = createTypeVar(); + addBinding(getCanonicalText(P->Name), new Forall(TVs, Constraints, Ty), SymKind::Var); + return Ty; + } + + case NodeKind::NamedTuplePattern: + { + auto P = static_cast(Pattern); + auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); + std::vector ElementTypes; + for (auto P2: P->Patterns) { + ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); + } + if (!Scm) { + DE.add(getCanonicalText(P->Name), P->Name); + return createTypeVar(); + } + auto Ty = instantiate(Scm, P); + auto RetTy = createTypeVar(); + makeEqual(Ty, Type::buildArrow(ElementTypes, RetTy), P); + return RetTy; + } + + case NodeKind::RecordPattern: + { + auto P = static_cast(Pattern); + auto RestField = getRestField(P->Fields); + Type* RecordTy; + if (RestField == nullptr) { + RecordTy = new Type(TNil()); + } else if (RestField->Pattern) { + RecordTy = inferPattern(RestField->Pattern); + } else { + RecordTy = createTypeVar(); + } + for (auto [Field, Comma]: P->Fields) { + if (Field->DotDot) { + continue; + } + Type* FieldTy; + if (Field->Pattern) { + FieldTy = inferPattern(Field->Pattern, Constraints, TVs); + } else { + FieldTy = createTypeVar(); + addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); + } + RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); + } + return RecordTy; + } + + case NodeKind::NamedRecordPattern: + { + auto P = static_cast(Pattern); + auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); + if (Scm == nullptr) { + DE.add(getCanonicalText(P->Name), P->Name); + return createTypeVar(); + } + auto RestField = getRestField(P->Fields); + Type* RecordTy; + if (RestField == nullptr) { + RecordTy = new Type(TNil()); + } else if (RestField->Pattern) { + RecordTy = inferPattern(RestField->Pattern); + } else { + RecordTy = createTypeVar(); + } + for (auto [Field, Comma]: P->Fields) { + if (Field->DotDot) { + continue; + } + Type* FieldTy; + if (Field->Pattern) { + FieldTy = inferPattern(Field->Pattern, Constraints, TVs); + } else { + FieldTy = createTypeVar(); + addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); + } + RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); + } + auto Ty = instantiate(Scm, P); + auto RetTy = createTypeVar(); + makeEqual(Ty, new Type(TArrow(RecordTy, RetTy)), P); + return RetTy; + } + + case NodeKind::TuplePattern: + { + auto P = static_cast(Pattern); + std::vector ElementTypes; + for (auto [Element, Comma]: P->Elements) { + ElementTypes.push_back(inferPattern(Element)); + } + return new Type(TTuple(ElementTypes)); + } + + case NodeKind::ListPattern: + { + auto P = static_cast(Pattern); + auto ElementType = createTypeVar(); + for (auto [Element, Separator]: P->Elements) { + makeEqual(ElementType, inferPattern(Element), P); + } + return new Type(TApp(ListType, ElementType)); + } + + case NodeKind::NestedPattern: + { + auto P = static_cast(Pattern); + return inferPattern(P->P, Constraints, TVs); + } + + case NodeKind::LiteralPattern: + { + auto P = static_cast(Pattern); + return inferLiteral(P->Literal); + } + + default: + ZEN_UNREACHABLE + + } + +} + +Type* Checker::inferLiteral(Literal* L) { + Type* Ty; + switch (L->getKind()) { + case NodeKind::IntegerLiteral: + Ty = lookupMono("Int", SymKind::Type); + break; + case NodeKind::StringLiteral: + Ty = lookupMono("String", SymKind::Type); + break; + default: + ZEN_UNREACHABLE + } + ZEN_ASSERT(Ty != nullptr); + return Ty; +} + +void Checker::populate(SourceFile* SF) { + + struct Visitor : public CSTVisitor { + + Graph& RefGraph; + + std::stack Stack; + + void visitLetDeclaration(LetDeclaration* N) { + RefGraph.addVertex(N); + Stack.push(N); + visitEachChild(N); + Stack.pop(); + } + + void visitReferenceExpression(ReferenceExpression* N) { + auto Y = static_cast(N); + auto Def = Y->getScope()->lookup(Y->getSymbolPath()); + // Name lookup failures will be reported directly in inferExpression(). + if (Def == nullptr || Def->getKind() != NodeKind::LetDeclaration) { + return; + } + // This case ensures that a deeply nested structure that references a + // parameter of a parent node but is not referenced itself is correctly handled. + // Note that the edge goes from the parent let to the parameter. This is normal. + if (Def->getKind() == NodeKind::Parameter) { + RefGraph.addEdge(Stack.top(), Def->Parent); + return; + } + if (!Stack.empty()) { + RefGraph.addEdge(Def, Stack.top()); + } + } + + }; + + Visitor V { {}, RefGraph }; + V.visit(SF); + +} + +Type* Checker::getType(TypedNode *Node) { + auto Ty = Node->getType(); + if (Node->Flags & NodeFlags_TypeIsSolved) { + return Ty; + } + Ty = solveType(Ty); + Node->setType(Ty); + Node->Flags |= NodeFlags_TypeIsSolved; + return Ty; +} + +void Checker::check(SourceFile *SF) { + initialize(SF); + setContext(SF->Ctx); + addBinding("String", new Forall(StringType), SymKind::Type); + addBinding("Int", new Forall(IntType), SymKind::Type); + addBinding("Bool", new Forall(BoolType), SymKind::Type); + addBinding("List", new Forall(ListType), SymKind::Type); + addBinding("True", new Forall(BoolType), SymKind::Var); + addBinding("False", new Forall(BoolType), SymKind::Var); + auto A = createTypeVar(); + addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var); + addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); + populate(SF); + forwardDeclare(SF); + auto SCCs = RefGraph.strongconnect(); + for (auto Nodes: SCCs) { + auto TVs = new TVSet; + auto Constraints = new ConstraintSet; + for (auto N: Nodes) { + if (N->getKind() != NodeKind::LetDeclaration) { + continue; + } + auto Decl = static_cast(N); + forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); + } + } + setContext(SF->Ctx); + infer(SF); + + // Important because otherwise some logic for some optimisations will kick in that are no longer active. + ActiveContext = nullptr; + + solve(new CMany(*SF->Ctx->Constraints)); + + class Visitor : public CSTVisitor { + + Checker& C; + + public: + + Visitor(Checker& C): + C(C) {} + + void visitAnnotation(Annotation* A) { + + } + + void visitExpression(Expression* X) { + C.getType(X); + } + + } V(*this); + + V.visit(SF); +} + +void Checker::solve(Constraint* Constraint) { + + Queue.push_back(Constraint); + bool DidJoin = false; + std::deque NextQueue; + + while (true) { + + if (Queue.empty()) { + if (NextQueue.empty() || !DidJoin) { + break; + } + DidJoin = false; + std::swap(Queue, NextQueue); + } + + auto Constraint = Queue.front(); + Queue.pop_front(); + + switch (Constraint->getKind()) { + + case ConstraintKind::Empty: + break; + + case ConstraintKind::Field: + { + auto Field = static_cast(Constraint); + auto MaybeTuple = Field->TupleTy->find(); + if (MaybeTuple->isTuple()) { + auto& Tuple = MaybeTuple->asTuple(); + if (Field->I >= Tuple.ElementTypes.size()) { + DE.add(MaybeTuple, Field->I, Field->Source); + } else { + auto ElementTy = Tuple.ElementTypes[Field->I]; + unify(ElementTy, Field->FieldTy, Field->Source); + } + } else if (MaybeTuple->isVar()) { + NextQueue.push_back(Constraint); + } else { + DE.add(MaybeTuple, Field->Source); + } + break; + } + + case ConstraintKind::Many: + { + auto Many = static_cast(Constraint); + for (auto Constraint: Many->Elements) { + Queue.push_back(Constraint); + } + break; + } + + case ConstraintKind::Equal: + { + auto Equal = static_cast(Constraint); + if (unify(Equal->Left, Equal->Right, Equal->Source)) { + DidJoin = true; + } + break; + } } - // Ty = find(Ty); - X->setType(Ty); - return Ty; } - RecordPatternField* getRestField(std::vector> Fields) { - for (auto [Field, Comma]: Fields) { - if (Field->DotDot) { - return Field; +} + +bool assignableTo(Type* A, Type* B) { + if (A->isCon() && B->isCon()) { + auto& Con1 = A->asCon(); + auto& Con2 = B->asCon(); + if (Con1.Id != Con2.Id) { + return false; + } + return true; + } + // TODO must handle a TApp + ZEN_UNREACHABLE +} + +class ArrowCursor { + + /// Types on this stack are guaranteed to be arrow types. + std::stack> Stack; + + TypePath& Path; + std::size_t I; + +public: + + ArrowCursor(Type* Arr, TypePath& Path): + Path(Path) { + Stack.push({ Arr, true }); + Path.push_back(Arr->getStartIndex()); + } + + Type* next() { + while (!Stack.empty()) { + auto& [Arrow, First] = Stack.top(); + auto& Index = Path.back(); + if (!First) { + Index.advance(Arrow); + } else { + First = false; + } + Type* Ty; + if (Index == Arrow->getEndIndex()) { + Path.pop_back(); + Stack.pop(); + continue; + } + Ty = Arrow->resolve(Index); + if (Ty->isArrow()) { + auto NewIndex = Arrow->getStartIndex(); + Stack.push({ Ty, true }); + Path.push_back(NewIndex); + } else { + return Ty; } } return nullptr; } - Type* Checker::inferPattern( - Pattern* Pattern, - ConstraintSet* Constraints, - TVSet* TVs - ) { +}; - switch (Pattern->getKind()) { +struct Unifier { - case NodeKind::BindPattern: - { - auto P = static_cast(Pattern); - auto Ty = createTypeVar(); - addBinding(getCanonicalText(P->Name), new Forall(TVs, Constraints, Ty), SymKind::Var); - return Ty; - } + Checker& C; + // CEqual* Constraint; + Type* Left; + Type* Right; + Node* Source; - case NodeKind::NamedTuplePattern: - { - auto P = static_cast(Pattern); - auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); - std::vector ElementTypes; - for (auto P2: P->Patterns) { - ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); - } - if (!Scm) { - DE.add(getCanonicalText(P->Name), P->Name); - return createTypeVar(); - } - auto Ty = instantiate(Scm, P); - auto RetTy = createTypeVar(); - makeEqual(Ty, Type::buildArrow(ElementTypes, RetTy), P); - return RetTy; - } - - case NodeKind::RecordPattern: - { - auto P = static_cast(Pattern); - auto RestField = getRestField(P->Fields); - Type* RecordTy; - if (RestField == nullptr) { - RecordTy = new Type(TNil()); - } else if (RestField->Pattern) { - RecordTy = inferPattern(RestField->Pattern); - } else { - RecordTy = createTypeVar(); - } - for (auto [Field, Comma]: P->Fields) { - if (Field->DotDot) { - continue; - } - Type* FieldTy; - if (Field->Pattern) { - FieldTy = inferPattern(Field->Pattern, Constraints, TVs); - } else { - FieldTy = createTypeVar(); - addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); - } - RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); - } - return RecordTy; - } - - case NodeKind::NamedRecordPattern: - { - auto P = static_cast(Pattern); - auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); - if (Scm == nullptr) { - DE.add(getCanonicalText(P->Name), P->Name); - return createTypeVar(); - } - auto RestField = getRestField(P->Fields); - Type* RecordTy; - if (RestField == nullptr) { - RecordTy = new Type(TNil()); - } else if (RestField->Pattern) { - RecordTy = inferPattern(RestField->Pattern); - } else { - RecordTy = createTypeVar(); - } - for (auto [Field, Comma]: P->Fields) { - if (Field->DotDot) { - continue; - } - Type* FieldTy; - if (Field->Pattern) { - FieldTy = inferPattern(Field->Pattern, Constraints, TVs); - } else { - FieldTy = createTypeVar(); - addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); - } - RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); - } - auto Ty = instantiate(Scm, P); - auto RetTy = createTypeVar(); - makeEqual(Ty, new Type(TArrow(RecordTy, RetTy)), P); - return RetTy; - } - - case NodeKind::TuplePattern: - { - auto P = static_cast(Pattern); - std::vector ElementTypes; - for (auto [Element, Comma]: P->Elements) { - ElementTypes.push_back(inferPattern(Element)); - } - return new Type(TTuple(ElementTypes)); - } - - case NodeKind::ListPattern: - { - auto P = static_cast(Pattern); - auto ElementType = createTypeVar(); - for (auto [Element, Separator]: P->Elements) { - makeEqual(ElementType, inferPattern(Element), P); - } - return new Type(TApp(ListType, ElementType)); - } - - case NodeKind::NestedPattern: - { - auto P = static_cast(Pattern); - return inferPattern(P->P, Constraints, TVs); - } - - case NodeKind::LiteralPattern: - { - auto P = static_cast(Pattern); - return inferLiteral(P->Literal); - } - - default: - ZEN_UNREACHABLE - - } + // Internal state used by the unifier + ByteString CurrentFieldName; + TypePath LeftPath; + TypePath RightPath; + bool DidJoin = false; + Type* getLeft() const { + return Left; } - Type* Checker::inferLiteral(Literal* L) { - Type* Ty; - switch (L->getKind()) { - case NodeKind::IntegerLiteral: - Ty = lookupMono("Int", SymKind::Type); - break; - case NodeKind::StringLiteral: - Ty = lookupMono("String", SymKind::Type); - break; - default: - ZEN_UNREACHABLE - } - ZEN_ASSERT(Ty != nullptr); - return Ty; + Type* getRight() const { + return Right; } - void Checker::populate(SourceFile* SF) { - - struct Visitor : public CSTVisitor { - - Graph& RefGraph; - - std::stack Stack; - - void visitLetDeclaration(LetDeclaration* N) { - RefGraph.addVertex(N); - Stack.push(N); - visitEachChild(N); - Stack.pop(); - } - - void visitReferenceExpression(ReferenceExpression* N) { - auto Y = static_cast(N); - auto Def = Y->getScope()->lookup(Y->getSymbolPath()); - // Name lookup failures will be reported directly in inferExpression(). - if (Def == nullptr || Def->getKind() != NodeKind::LetDeclaration) { - return; - } - // This case ensures that a deeply nested structure that references a - // parameter of a parent node but is not referenced itself is correctly handled. - // Note that the edge goes from the parent let to the parameter. This is normal. - if (Def->getKind() == NodeKind::Parameter) { - RefGraph.addEdge(Stack.top(), Def->Parent); - return; - } - if (!Stack.empty()) { - RefGraph.addEdge(Def, Stack.top()); - } - } - - }; - - Visitor V { {}, RefGraph }; - V.visit(SF); - + Node* getSource() const { + return Source; } - Type* Checker::getType(TypedNode *Node) { - auto Ty = Node->getType(); - if (Node->Flags & NodeFlags_TypeIsSolved) { - return Ty; - } - Ty = solveType(Ty); - Node->setType(Ty); - Node->Flags |= NodeFlags_TypeIsSolved; - return Ty; + bool unifyField(Type* A, Type* B, bool DidSwap); + + bool unify(Type* A, Type* B, bool DidSwap); + + bool unify() { + return unify(Left, Right, false); } - void Checker::check(SourceFile *SF) { - initialize(SF); - setContext(SF->Ctx); - addBinding("String", new Forall(StringType), SymKind::Type); - addBinding("Int", new Forall(IntType), SymKind::Type); - addBinding("Bool", new Forall(BoolType), SymKind::Type); - addBinding("List", new Forall(ListType), SymKind::Type); - addBinding("True", new Forall(BoolType), SymKind::Var); - addBinding("False", new Forall(BoolType), SymKind::Var); - auto A = createTypeVar(); - addBinding("==", new Forall(new TVSet { A }, new ConstraintSet, Type::buildArrow({ A, A }, BoolType)), SymKind::Var); - addBinding("+", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - addBinding("-", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - addBinding("*", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - addBinding("/", new Forall(Type::buildArrow({ IntType, IntType }, IntType)), SymKind::Var); - populate(SF); - forwardDeclare(SF); - auto SCCs = RefGraph.strongconnect(); - for (auto Nodes: SCCs) { - auto TVs = new TVSet; - auto Constraints = new ConstraintSet; - for (auto N: Nodes) { - if (N->getKind() != NodeKind::LetDeclaration) { - continue; - } - auto Decl = static_cast(N); - forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); - } - } - setContext(SF->Ctx); - infer(SF); - - // Important because otherwise some logic for some optimisations will kick in that are no longer active. - ActiveContext = nullptr; - - solve(new CMany(*SF->Ctx->Constraints)); - - class Visitor : public CSTVisitor { - - Checker& C; - - public: - - Visitor(Checker& C): - C(C) {} - - void visitAnnotation(Annotation* A) { - - } - - void visitExpression(Expression* X) { - C.getType(X); - } - - } V(*this); - - V.visit(SF); - } - - void Checker::solve(Constraint* Constraint) { - - Queue.push_back(Constraint); - bool DidJoin = false; - std::deque NextQueue; - - while (true) { - - if (Queue.empty()) { - if (NextQueue.empty() || !DidJoin) { - break; - } - DidJoin = false; - std::swap(Queue, NextQueue); - } - - auto Constraint = Queue.front(); - Queue.pop_front(); - - switch (Constraint->getKind()) { - - case ConstraintKind::Empty: - break; - - case ConstraintKind::Field: - { - auto Field = static_cast(Constraint); - auto MaybeTuple = Field->TupleTy->find(); - if (MaybeTuple->isTuple()) { - auto& Tuple = MaybeTuple->asTuple(); - if (Field->I >= Tuple.ElementTypes.size()) { - DE.add(MaybeTuple, Field->I, Field->Source); - } else { - auto ElementTy = Tuple.ElementTypes[Field->I]; - unify(ElementTy, Field->FieldTy, Field->Source); - } - } else if (MaybeTuple->isVar()) { - NextQueue.push_back(Constraint); - } else { - DE.add(MaybeTuple, Field->Source); + std::vector findInstanceContext(const TypeSig& Ty, TypeclassId& Class) { + auto Match = C.InstanceMap.find(Class); + std::vector S; + if (Match != C.InstanceMap.end()) { + for (auto Instance: Match->second) { + if (assignableTo(Ty.Orig, Instance->TypeExps[0]->getType())) { + std::vector S; + for (auto Arg: Ty.Args) { + TypeclassContext Classes; + // TODO + S.push_back(Classes); } - break; + return S; } - - case ConstraintKind::Many: - { - auto Many = static_cast(Constraint); - for (auto Constraint: Many->Elements) { - Queue.push_back(Constraint); - } - break; - } - - case ConstraintKind::Equal: - { - auto Equal = static_cast(Constraint); - if (unify(Equal->Left, Equal->Right, Equal->Source)) { - DidJoin = true; - } - break; - } - } - } - + C.DE.add(Class, Ty.Orig, getSource()); + for (auto Arg: Ty.Args) { + S.push_back({}); + } + return S; } - bool assignableTo(Type* A, Type* B) { - if (A->isCon() && B->isCon()) { - auto& Con1 = A->asCon(); - auto& Con2 = B->asCon(); - if (Con1.Id != Con2.Id) { - return false; - } - return true; - } - // TODO must handle a TApp - ZEN_UNREACHABLE - } - - class ArrowCursor { - - /// Types on this stack are guaranteed to be arrow types. - std::stack> Stack; - - TypePath& Path; - std::size_t I; - - public: - - ArrowCursor(Type* Arr, TypePath& Path): - Path(Path) { - Stack.push({ Arr, true }); - Path.push_back(Arr->getStartIndex()); - } - - Type* next() { - while (!Stack.empty()) { - auto& [Arrow, First] = Stack.top(); - auto& Index = Path.back(); - if (!First) { - Index.advance(Arrow); - } else { - First = false; - } - Type* Ty; - if (Index == Arrow->getEndIndex()) { - Path.pop_back(); - Stack.pop(); - continue; - } - Ty = Arrow->resolve(Index); - if (Ty->isArrow()) { - auto NewIndex = Arrow->getStartIndex(); - Stack.push({ Ty, true }); - Path.push_back(NewIndex); - } else { - return Ty; - } - } - return nullptr; - } - - }; - - struct Unifier { - - Checker& C; - // CEqual* Constraint; - Type* Left; - Type* Right; - Node* Source; - - // Internal state used by the unifier - ByteString CurrentFieldName; - TypePath LeftPath; - TypePath RightPath; - bool DidJoin = false; - - Type* getLeft() const { - return Left; - } - - Type* getRight() const { - return Right; - } - - Node* getSource() const { - return Source; - } - - bool unifyField(Type* A, Type* B, bool DidSwap); - - bool unify(Type* A, Type* B, bool DidSwap); - - bool unify() { - return unify(Left, Right, false); - } - - std::vector findInstanceContext(const TypeSig& Ty, TypeclassId& Class) { - auto Match = C.InstanceMap.find(Class); - std::vector S; - if (Match != C.InstanceMap.end()) { - for (auto Instance: Match->second) { - if (assignableTo(Ty.Orig, Instance->TypeExps[0]->getType())) { - std::vector S; - for (auto Arg: Ty.Args) { - TypeclassContext Classes; - // TODO - S.push_back(Classes); - } - return S; - } - } - } - C.DE.add(Class, Ty.Orig, getSource()); - for (auto Arg: Ty.Args) { - S.push_back({}); - } - return S; - } - - TypeSig getTypeSig(Type* Ty) { - Type* Op = nullptr; - std::vector Args; - std::function Visit = [&](Type* Ty) { - if (Ty->isApp()) { - Visit(Ty->asApp().Op); - Visit(Ty->asApp().Arg); - } else if (!Op) { - Op = Ty; - } else { - Args.push_back(Ty); - } - }; - Visit(Ty); - return TypeSig { Ty, Op, Args }; - } - - void propagateClasses(std::unordered_set& Classes, Type* Ty) { - if (Ty->isVar()) { - auto TV = Ty->asVar(); - for (auto Class: Classes) { - TV.Context.emplace(Class); - } - if (TV.isRigid()) { - for (auto Id: TV.Context) { - if (!TV.Provided->count(Id)) { - C.DE.add(TypeclassSignature { Id, { Ty } }, getSource()); - } - } - } - } else if (Ty->isCon() || Ty->isApp()) { - auto Sig = getTypeSig(Ty); - for (auto Class: Classes) { - propagateClassTycon(Class, Sig); - } - } else if (!Classes.empty()) { - C.DE.add(Ty, std::vector(Classes.begin(), Classes.end()), getSource()); - } - }; - - void propagateClassTycon(TypeclassId& Class, const TypeSig& Sig) { - auto S = findInstanceContext(Sig, Class); - for (auto [Classes, Arg]: zen::zip(S, Sig.Args)) { - propagateClasses(Classes, Arg); - } - }; - - /** - * Assign a type to a unification variable. - * - * If there are class constraints, those are propagated. - * - * If this type variable is solved during inference, it will be removed from - * the inference context. - * - * Other side effects may occur. - */ - void join(Type* TV, Type* Ty) { - - // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; - - TV->set(Ty); - - DidJoin = true; - - propagateClasses(TV->asVar().Context, Ty); - - // This is a very specific adjustment that is critical to the - // well-functioning of the infer/unify algorithm. When addConstraint() is - // called, it may decide to solve the constraint immediately during - // inference. If this happens, a type variable might get assigned a concrete - // type such as Int. We therefore never want the variable to be polymorphic - // and be instantiated with a fresh variable, as that would allow Bool to - // collide with Int. - // - // Should it get assigned another unification variable, that's OK too - // because then that variable is what matters and it will become the new - // (possibly polymorphic) variable. - if (C.ActiveContext) { - // std::cerr << "erase " << describe(TV) << std::endl; - auto TVs = C.ActiveContext->TVs; - TVs->erase(TV); - } - - } - - }; - - bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { - if (A->isAbsent() && B->isAbsent()) { - return true; - } - if (B->isAbsent()) { - std::swap(A, B); - DidSwap = !DidSwap; - } - if (A->isAbsent()) { - auto& Present = B->asPresent(); - C.DE.add(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource()); - return false; - } - auto& Present1 = A->asPresent(); - auto& Present2 = B->asPresent(); - return unify(Present1.Ty, Present2.Ty, DidSwap); - }; - - bool Unifier::unify(Type* A, Type* B, bool DidSwap) { - - A = A->find(); - B = B->find(); - - auto unifyError = [&]() { - C.DE.add( - Left, - Right, - LeftPath, - RightPath, - Source - ); - }; - - auto pushLeft = [&](TypeIndex I) { - if (DidSwap) { - RightPath.push_back(I); + TypeSig getTypeSig(Type* Ty) { + Type* Op = nullptr; + std::vector Args; + std::function Visit = [&](Type* Ty) { + if (Ty->isApp()) { + Visit(Ty->asApp().Op); + Visit(Ty->asApp().Arg); + } else if (!Op) { + Op = Ty; } else { - LeftPath.push_back(I); + Args.push_back(Ty); } }; + Visit(Ty); + return TypeSig { Ty, Op, Args }; + } - auto popLeft = [&]() { - if (DidSwap) { - RightPath.pop_back(); - } else { - LeftPath.pop_back(); + void propagateClasses(std::unordered_set& Classes, Type* Ty) { + if (Ty->isVar()) { + auto TV = Ty->asVar(); + for (auto Class: Classes) { + TV.Context.emplace(Class); } - }; - - auto pushRight = [&](TypeIndex I) { - if (DidSwap) { - LeftPath.push_back(I); - } else { - RightPath.push_back(I); - } - }; - - auto popRight = [&]() { - if (DidSwap) { - LeftPath.pop_back(); - } else { - RightPath.pop_back(); - } - }; - - auto swap = [&]() { - std::swap(A, B); - DidSwap = !DidSwap; - }; - - if (A->isVar() && B->isVar()) { - auto& Var1 = A->asVar(); - auto& Var2 = B->asVar(); - if (Var1.isRigid() && Var2.isRigid()) { - if (Var1.Id != Var2.Id) { - unifyError(); - return false; - } - return true; - } - Type* To; - Type* From; - if (Var1.isRigid() && Var2.isUni()) { - To = A; - From = B; - } else { - // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification - // Either way, Var1, being Unification, is a good candidate for being unified away - To = B; - From = A; - } - if (From->asVar().Id != To->asVar().Id) { - join(From, To); - } - return true; - } - - if (B->isVar()) { - swap(); - } - - if (A->isVar()) { - - auto& TV = A->asVar(); - - // Rigid type variables can never unify with antything else than what we - // have already handled in the previous if-statement, so issue an error. if (TV.isRigid()) { - unifyError(); - return false; - } - - // Occurs check - if (B->hasTypeVar(A)) { - // NOTE Just like GHC, we just display an error message indicating that - // A cannot match B, e.g. a cannot match [a]. It looks much better - // than obsure references to an occurs check - unifyError(); - return false; - } - - join(A, B); - - return true; - } - - if (A->isArrow() && B->isArrow()) { - auto& Arrow1 = A->asArrow(); - auto& Arrow2 = B->asArrow(); - bool Success = true; - LeftPath.push_back(TypeIndex::forArrowParamType()); - RightPath.push_back(TypeIndex::forArrowParamType()); - if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - LeftPath.push_back(TypeIndex::forArrowReturnType()); - RightPath.push_back(TypeIndex::forArrowReturnType()); - if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - return Success; - } - - if (A->isApp() && B->isApp()) { - auto& App1 = A->asApp(); - auto& App2 = B->asApp(); - bool Success = true; - LeftPath.push_back(TypeIndex::forAppOpType()); - RightPath.push_back(TypeIndex::forAppOpType()); - if (!unify(App1.Op, App2.Op, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - LeftPath.push_back(TypeIndex::forAppArgType()); - RightPath.push_back(TypeIndex::forAppArgType()); - if (!unify(App1.Arg, App2.Arg, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - return Success; - } - - if (A->isTuple() && B->isTuple()) { - auto& Tuple1 = A->asTuple(); - auto& Tuple2 = B->asTuple(); - if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) { - unifyError(); - return false; - } - auto Count = Tuple1.ElementTypes.size(); - bool Success = true; - for (size_t I = 0; I < Count; I++) { - LeftPath.push_back(TypeIndex::forTupleElement(I)); - RightPath.push_back(TypeIndex::forTupleElement(I)); - if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) { - Success = false; + for (auto Id: TV.Context) { + if (!TV.Provided->count(Id)) { + C.DE.add(TypeclassSignature { Id, { Ty } }, getSource()); + } } - LeftPath.pop_back(); - RightPath.pop_back(); } - return Success; + } else if (Ty->isCon() || Ty->isApp()) { + auto Sig = getTypeSig(Ty); + for (auto Class: Classes) { + propagateClassTycon(Class, Sig); + } + } else if (!Classes.empty()) { + C.DE.add(Ty, std::vector(Classes.begin(), Classes.end()), getSource()); + } + }; + + void propagateClassTycon(TypeclassId& Class, const TypeSig& Sig) { + auto S = findInstanceContext(Sig, Class); + for (auto [Classes, Arg]: zen::zip(S, Sig.Args)) { + propagateClasses(Classes, Arg); + } + }; + + /** + * Assign a type to a unification variable. + * + * If there are class constraints, those are propagated. + * + * If this type variable is solved during inference, it will be removed from + * the inference context. + * + * Other side effects may occur. + */ + void join(Type* TV, Type* Ty) { + + // std::cerr << describe(TV) << " => " << describe(Ty) << std::endl; + + TV->set(Ty); + + DidJoin = true; + + propagateClasses(TV->asVar().Context, Ty); + + // This is a very specific adjustment that is critical to the + // well-functioning of the infer/unify algorithm. When addConstraint() is + // called, it may decide to solve the constraint immediately during + // inference. If this happens, a type variable might get assigned a concrete + // type such as Int. We therefore never want the variable to be polymorphic + // and be instantiated with a fresh variable, as that would allow Bool to + // collide with Int. + // + // Should it get assigned another unification variable, that's OK too + // because then that variable is what matters and it will become the new + // (possibly polymorphic) variable. + if (C.ActiveContext) { + // std::cerr << "erase " << describe(TV) << std::endl; + auto TVs = C.ActiveContext->TVs; + TVs->erase(TV); } - // if (A->isTupleIndex() || B->isTupleIndex()) { - // // Type(s) could not be simplified at the beginning of this function, - // // so we have to re-visit the constraint when there is more information. - // C.Queue.push_back(Constraint); - // return true; - // } + } - // This does not work because it ignores the indices - // if (A->isTupleIndex() && B->isTupleIndex()) { - // auto Index1 = static_cast(A); - // auto Index2 = static_cast(B); - // return unify(Index1->Ty, Index2->Ty, Source); - // } +}; - if (A->isCon() && B->isCon()) { - auto& Con1 = A->asCon(); - auto& Con2 = B->asCon(); - if (Con1.Id != Con2.Id) { - unifyError(); - return false; - } - return true; - } - - if (A->isNil() && B->isNil()) { - return true; - } - - if (A->isField() && B->isField()) { - auto& Field1 = A->asField(); - auto& Field2 = B->asField(); - bool Success = true; - if (Field1.Name == Field2.Name) { - LeftPath.push_back(TypeIndex::forFieldType()); - RightPath.push_back(TypeIndex::forFieldType()); - CurrentFieldName = Field1.Name; - if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - LeftPath.push_back(TypeIndex::forFieldRest()); - RightPath.push_back(TypeIndex::forFieldRest()); - if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) { - Success = false; - } - LeftPath.pop_back(); - RightPath.pop_back(); - return Success; - } - auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++)); - pushLeft(TypeIndex::forFieldRest()); - if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) { - Success = false; - } - popLeft(); - pushRight(TypeIndex::forFieldRest()); - if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) { - Success = false; - } - popRight(); - return Success; - } - - if (A->isNil() && B->isField()) { - swap(); - } - - if (A->isField() && B->isNil()) { - auto& Field = A->asField(); - bool Success = true; - pushLeft(TypeIndex::forFieldType()); - CurrentFieldName = Field.Name; - if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) { - Success = false; - } - popLeft(); - pushLeft(TypeIndex::forFieldRest()); - if (!unify(Field.RestTy, B, DidSwap)) { - Success = false; - } - popLeft(); - return Success; - } - - unifyError(); +bool Unifier::unifyField(Type* A, Type* B, bool DidSwap) { + if (A->isAbsent() && B->isAbsent()) { + return true; + } + if (B->isAbsent()) { + std::swap(A, B); + DidSwap = !DidSwap; + } + if (A->isAbsent()) { + auto& Present = B->asPresent(); + C.DE.add(CurrentFieldName, C.solveType(getLeft()), LeftPath, getSource()); return false; } + auto& Present1 = A->asPresent(); + auto& Present2 = B->asPresent(); + return unify(Present1.Ty, Present2.Ty, DidSwap); +}; - bool Checker::unify(Type* Left, Type* Right, Node* Source) { - // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; - Unifier A { *this, Left, Right, Source }; - A.unify(); - return A.DidJoin; +bool Unifier::unify(Type* A, Type* B, bool DidSwap) { + + A = A->find(); + B = B->find(); + + auto unifyError = [&]() { + C.DE.add( + Left, + Right, + LeftPath, + RightPath, + Source + ); + }; + + auto pushLeft = [&](TypeIndex I) { + if (DidSwap) { + RightPath.push_back(I); + } else { + LeftPath.push_back(I); + } + }; + + auto popLeft = [&]() { + if (DidSwap) { + RightPath.pop_back(); + } else { + LeftPath.pop_back(); + } + }; + + auto pushRight = [&](TypeIndex I) { + if (DidSwap) { + LeftPath.push_back(I); + } else { + RightPath.push_back(I); + } + }; + + auto popRight = [&]() { + if (DidSwap) { + LeftPath.pop_back(); + } else { + RightPath.pop_back(); + } + }; + + auto swap = [&]() { + std::swap(A, B); + DidSwap = !DidSwap; + }; + + if (A->isVar() && B->isVar()) { + auto& Var1 = A->asVar(); + auto& Var2 = B->asVar(); + if (Var1.isRigid() && Var2.isRigid()) { + if (Var1.Id != Var2.Id) { + unifyError(); + return false; + } + return true; + } + Type* To; + Type* From; + if (Var1.isRigid() && Var2.isUni()) { + To = A; + From = B; + } else { + // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification + // Either way, Var1, being Unification, is a good candidate for being unified away + To = B; + From = A; + } + if (From->asVar().Id != To->asVar().Id) { + join(From, To); + } + return true; } + if (B->isVar()) { + swap(); + } + + if (A->isVar()) { + + auto& TV = A->asVar(); + + // Rigid type variables can never unify with antything else than what we + // have already handled in the previous if-statement, so issue an error. + if (TV.isRigid()) { + unifyError(); + return false; + } + + // Occurs check + if (B->hasTypeVar(A)) { + // NOTE Just like GHC, we just display an error message indicating that + // A cannot match B, e.g. a cannot match [a]. It looks much better + // than obsure references to an occurs check + unifyError(); + return false; + } + + join(A, B); + + return true; + } + + if (A->isArrow() && B->isArrow()) { + auto& Arrow1 = A->asArrow(); + auto& Arrow2 = B->asArrow(); + bool Success = true; + LeftPath.push_back(TypeIndex::forArrowParamType()); + RightPath.push_back(TypeIndex::forArrowParamType()); + if (!unify(Arrow1.ParamType, Arrow2.ParamType, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + LeftPath.push_back(TypeIndex::forArrowReturnType()); + RightPath.push_back(TypeIndex::forArrowReturnType()); + if (!unify(Arrow1.ReturnType, Arrow2.ReturnType, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + return Success; + } + + if (A->isApp() && B->isApp()) { + auto& App1 = A->asApp(); + auto& App2 = B->asApp(); + bool Success = true; + LeftPath.push_back(TypeIndex::forAppOpType()); + RightPath.push_back(TypeIndex::forAppOpType()); + if (!unify(App1.Op, App2.Op, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + LeftPath.push_back(TypeIndex::forAppArgType()); + RightPath.push_back(TypeIndex::forAppArgType()); + if (!unify(App1.Arg, App2.Arg, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + return Success; + } + + if (A->isTuple() && B->isTuple()) { + auto& Tuple1 = A->asTuple(); + auto& Tuple2 = B->asTuple(); + if (Tuple1.ElementTypes.size() != Tuple2.ElementTypes.size()) { + unifyError(); + return false; + } + auto Count = Tuple1.ElementTypes.size(); + bool Success = true; + for (size_t I = 0; I < Count; I++) { + LeftPath.push_back(TypeIndex::forTupleElement(I)); + RightPath.push_back(TypeIndex::forTupleElement(I)); + if (!unify(Tuple1.ElementTypes[I], Tuple2.ElementTypes[I], DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + } + return Success; + } + + // if (A->isTupleIndex() || B->isTupleIndex()) { + // // Type(s) could not be simplified at the beginning of this function, + // // so we have to re-visit the constraint when there is more information. + // C.Queue.push_back(Constraint); + // return true; + // } + + // This does not work because it ignores the indices + // if (A->isTupleIndex() && B->isTupleIndex()) { + // auto Index1 = static_cast(A); + // auto Index2 = static_cast(B); + // return unify(Index1->Ty, Index2->Ty, Source); + // } + + if (A->isCon() && B->isCon()) { + auto& Con1 = A->asCon(); + auto& Con2 = B->asCon(); + if (Con1.Id != Con2.Id) { + unifyError(); + return false; + } + return true; + } + + if (A->isNil() && B->isNil()) { + return true; + } + + if (A->isField() && B->isField()) { + auto& Field1 = A->asField(); + auto& Field2 = B->asField(); + bool Success = true; + if (Field1.Name == Field2.Name) { + LeftPath.push_back(TypeIndex::forFieldType()); + RightPath.push_back(TypeIndex::forFieldType()); + CurrentFieldName = Field1.Name; + if (!unifyField(Field1.Ty, Field2.Ty, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + LeftPath.push_back(TypeIndex::forFieldRest()); + RightPath.push_back(TypeIndex::forFieldRest()); + if (!unify(Field1.RestTy, Field2.RestTy, DidSwap)) { + Success = false; + } + LeftPath.pop_back(); + RightPath.pop_back(); + return Success; + } + auto NewRestTy = new Type(TVar(VarKind::Unification, C.NextTypeVarId++)); + pushLeft(TypeIndex::forFieldRest()); + if (!unify(Field1.RestTy, new Type(TField(Field2.Name, Field2.Ty, NewRestTy)), DidSwap)) { + Success = false; + } + popLeft(); + pushRight(TypeIndex::forFieldRest()); + if (!unify(new Type(TField(Field1.Name, Field1.Ty, NewRestTy)), Field2.RestTy, DidSwap)) { + Success = false; + } + popRight(); + return Success; + } + + if (A->isNil() && B->isField()) { + swap(); + } + + if (A->isField() && B->isNil()) { + auto& Field = A->asField(); + bool Success = true; + pushLeft(TypeIndex::forFieldType()); + CurrentFieldName = Field.Name; + if (!unifyField(Field.Ty, new Type(TAbsent()), DidSwap)) { + Success = false; + } + popLeft(); + pushLeft(TypeIndex::forFieldRest()); + if (!unify(Field.RestTy, B, DidSwap)) { + Success = false; + } + popLeft(); + return Success; + } + + unifyError(); + return false; +} + +bool Checker::unify(Type* Left, Type* Right, Node* Source) { + // std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; + Unifier A { *this, Left, Right, Source }; + A.unify(); + return A.DidJoin; +} + } diff --git a/bootstrap/cxx/src/ConsolePrinter.cc b/bootstrap/cxx/src/ConsolePrinter.cc index 7fd0ebcb6..919b80d02 100644 --- a/bootstrap/cxx/src/ConsolePrinter.cc +++ b/bootstrap/cxx/src/ConsolePrinter.cc @@ -37,899 +37,899 @@ namespace bolt { - template - T countDigits(T number) { - if (number == 0) { - return 1; - } - return std::ceil(std::log10(number+1)); +template +T countDigits(T number) { + if (number == 0) { + return 1; } + return std::ceil(std::log10(number+1)); +} - static std::string describe(NodeKind Type) { - switch (Type) { - case NodeKind::Identifier: - return "an identifier starting with a lowercase letter"; - case NodeKind::IdentifierAlt: - return "an identifier starting with a capital letter"; - case NodeKind::CustomOperator: - return "an operator"; - case NodeKind::IntegerLiteral: - return "an integer literal"; - case NodeKind::EndOfFile: - return "end-of-file"; - case NodeKind::BlockStart: - return "the start of a new indented block"; - case NodeKind::BlockEnd: - return "the end of the current indented block"; - case NodeKind::LineFoldEnd: - return "the end of the current line-fold"; - case NodeKind::Assignment: - return "an assignment such as := or +="; - case NodeKind::ExpressionAnnotation: - return "a user-defined annotation"; - case NodeKind::TypeAssertAnnotation: - return "a built-in annotation for a type assertion"; - case NodeKind::TypeclassConstraintExpression: - return "a type class constraint"; - case NodeKind::EqualityConstraintExpression: - return "an equality constraint"; - case NodeKind::QualifiedTypeExpression: - return "a type expression with some constraints"; - case NodeKind::ReferenceTypeExpression: - return "a reference to another type"; - case NodeKind::ArrowTypeExpression: - return "a function type signature"; - case NodeKind::AppTypeExpression: - return "an application of one type to another"; - case NodeKind::VarTypeExpression: - return "a rigid variable"; - case NodeKind::NestedTypeExpression: - return "a type expression wrapped in '(' and ')'"; - case NodeKind::TupleTypeExpression: - return "a tuple type expression"; - case NodeKind::BindPattern: - return "a variable binder"; - case NodeKind::NamedTuplePattern: - return "a pattern for a variant member"; - case NodeKind::TuplePattern: - return "a pattern for a tuple"; - case NodeKind::ListPattern: - return "a pattern for a list"; - case NodeKind::LParen: - return "'('"; - case NodeKind::RParen: - return "')'"; - case NodeKind::LBrace: - return "'['"; - case NodeKind::RBrace: - return "']'"; - case NodeKind::LBracket: - return "'{'"; - case NodeKind::RBracket: - return "'}'"; - case NodeKind::Colon: - return "':'"; - case NodeKind::At: - return "'@'"; - case NodeKind::Comma: - return "','"; - case NodeKind::Equals: - return "'='"; - case NodeKind::StringLiteral: - return "a string literal"; - case NodeKind::Dot: - return "'.'"; - case NodeKind::DotDot: - return "'..'"; - case NodeKind::Tilde: - return "'~'"; - case NodeKind::RArrow: - return "'->'"; - case NodeKind::RArrowAlt: - return "'=>'"; - case NodeKind::PubKeyword: - return "'pub'"; - case NodeKind::LetKeyword: - return "'let'"; - case NodeKind::ForeignKeyword: - return "'foreign'"; - case NodeKind::MutKeyword: - return "'mut'"; - case NodeKind::MatchKeyword: - return "'match'"; - case NodeKind::ReturnKeyword: - return "'return'"; - case NodeKind::TypeKeyword: - return "'type'"; - case NodeKind::IfKeyword: - return "'if'"; - case NodeKind::ElifKeyword: - return "'elif'"; - case NodeKind::ElseKeyword: - return "'else'"; - case NodeKind::StructKeyword: - return "'struct'"; - case NodeKind::EnumKeyword: - return "'enum'"; - case NodeKind::ClassKeyword: - return "'class'"; - case NodeKind::InstanceKeyword: - return "'instance'"; - case NodeKind::LetDeclaration: - return "a let-declaration"; - case NodeKind::CallExpression: - return "a call-expression"; - case NodeKind::InfixExpression: - return "an infix-expression"; - case NodeKind::ReferenceExpression: - return "a reference to a function or variable"; - case NodeKind::MatchExpression: - return "a match-expression"; - case NodeKind::LiteralExpression: - return "a literal expression"; - case NodeKind::MemberExpression: - return "an accessor of a member"; - case NodeKind::IfStatement: - return "an if-statement"; - case NodeKind::IfStatementPart: - return "a branch of an if-statement"; - case NodeKind::VariantDeclaration: - return "a variant"; - case NodeKind::MatchCase: - return "a match-arm"; - default: - ZEN_UNREACHABLE - } +static std::string describe(NodeKind Type) { + switch (Type) { + case NodeKind::Identifier: + return "an identifier starting with a lowercase letter"; + case NodeKind::IdentifierAlt: + return "an identifier starting with a capital letter"; + case NodeKind::CustomOperator: + return "an operator"; + case NodeKind::IntegerLiteral: + return "an integer literal"; + case NodeKind::EndOfFile: + return "end-of-file"; + case NodeKind::BlockStart: + return "the start of a new indented block"; + case NodeKind::BlockEnd: + return "the end of the current indented block"; + case NodeKind::LineFoldEnd: + return "the end of the current line-fold"; + case NodeKind::Assignment: + return "an assignment such as := or +="; + case NodeKind::ExpressionAnnotation: + return "a user-defined annotation"; + case NodeKind::TypeAssertAnnotation: + return "a built-in annotation for a type assertion"; + case NodeKind::TypeclassConstraintExpression: + return "a type class constraint"; + case NodeKind::EqualityConstraintExpression: + return "an equality constraint"; + case NodeKind::QualifiedTypeExpression: + return "a type expression with some constraints"; + case NodeKind::ReferenceTypeExpression: + return "a reference to another type"; + case NodeKind::ArrowTypeExpression: + return "a function type signature"; + case NodeKind::AppTypeExpression: + return "an application of one type to another"; + case NodeKind::VarTypeExpression: + return "a rigid variable"; + case NodeKind::NestedTypeExpression: + return "a type expression wrapped in '(' and ')'"; + case NodeKind::TupleTypeExpression: + return "a tuple type expression"; + case NodeKind::BindPattern: + return "a variable binder"; + case NodeKind::NamedTuplePattern: + return "a pattern for a variant member"; + case NodeKind::TuplePattern: + return "a pattern for a tuple"; + case NodeKind::ListPattern: + return "a pattern for a list"; + case NodeKind::LParen: + return "'('"; + case NodeKind::RParen: + return "')'"; + case NodeKind::LBrace: + return "'['"; + case NodeKind::RBrace: + return "']'"; + case NodeKind::LBracket: + return "'{'"; + case NodeKind::RBracket: + return "'}'"; + case NodeKind::Colon: + return "':'"; + case NodeKind::At: + return "'@'"; + case NodeKind::Comma: + return "','"; + case NodeKind::Equals: + return "'='"; + case NodeKind::StringLiteral: + return "a string literal"; + case NodeKind::Dot: + return "'.'"; + case NodeKind::DotDot: + return "'..'"; + case NodeKind::Tilde: + return "'~'"; + case NodeKind::RArrow: + return "'->'"; + case NodeKind::RArrowAlt: + return "'=>'"; + case NodeKind::PubKeyword: + return "'pub'"; + case NodeKind::LetKeyword: + return "'let'"; + case NodeKind::ForeignKeyword: + return "'foreign'"; + case NodeKind::MutKeyword: + return "'mut'"; + case NodeKind::MatchKeyword: + return "'match'"; + case NodeKind::ReturnKeyword: + return "'return'"; + case NodeKind::TypeKeyword: + return "'type'"; + case NodeKind::IfKeyword: + return "'if'"; + case NodeKind::ElifKeyword: + return "'elif'"; + case NodeKind::ElseKeyword: + return "'else'"; + case NodeKind::StructKeyword: + return "'struct'"; + case NodeKind::EnumKeyword: + return "'enum'"; + case NodeKind::ClassKeyword: + return "'class'"; + case NodeKind::InstanceKeyword: + return "'instance'"; + case NodeKind::LetDeclaration: + return "a let-declaration"; + case NodeKind::CallExpression: + return "a call-expression"; + case NodeKind::InfixExpression: + return "an infix-expression"; + case NodeKind::ReferenceExpression: + return "a reference to a function or variable"; + case NodeKind::MatchExpression: + return "a match-expression"; + case NodeKind::LiteralExpression: + return "a literal expression"; + case NodeKind::MemberExpression: + return "an accessor of a member"; + case NodeKind::IfStatement: + return "an if-statement"; + case NodeKind::IfStatementPart: + return "a branch of an if-statement"; + case NodeKind::VariantDeclaration: + return "a variant"; + case NodeKind::MatchCase: + return "a match-arm"; + default: + ZEN_UNREACHABLE } +} - static std::string describe(Token* T) { - switch (T->getKind()) { - case NodeKind::LineFoldEnd: - case NodeKind::BlockStart: - case NodeKind::BlockEnd: - case NodeKind::EndOfFile: - return describe(T->getKind()); - default: - return "'" + T->getText() + "'"; - } +static std::string describe(Token* T) { + switch (T->getKind()) { + case NodeKind::LineFoldEnd: + case NodeKind::BlockStart: + case NodeKind::BlockEnd: + case NodeKind::EndOfFile: + return describe(T->getKind()); + default: + return "'" + T->getText() + "'"; } +} - std::string describe(const Type* Ty) { - Ty = Ty->find(); - switch (Ty->getKind()) { - case TypeKind::Var: - { - auto TV = Ty->asVar(); - if (TV.isRigid()) { - return *TV.Name; +std::string describe(const Type* Ty) { + Ty = Ty->find(); + switch (Ty->getKind()) { + case TypeKind::Var: + { + auto TV = Ty->asVar(); + if (TV.isRigid()) { + return *TV.Name; + } + return "a" + std::to_string(TV.Id); + } + case TypeKind::Arrow: + { + auto Y = Ty->asArrow(); + std::ostringstream Out; + Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType); + return Out.str(); + } + case TypeKind::Con: + { + auto Y = Ty->asCon(); + return Y.DisplayName; + } + case TypeKind::App: + { + auto Y = Ty->asApp(); + return describe(Y.Op) + " " + describe(Y.Arg); + } + case TypeKind::Tuple: + { + std::ostringstream Out; + auto Y = Ty->asTuple(); + Out << "("; + if (Y.ElementTypes.size()) { + auto Iter = Y.ElementTypes.begin(); + Out << describe(*Iter++); + while (Iter != Y.ElementTypes.end()) { + Out << ", " << describe(*Iter++); } - return "a" + std::to_string(TV.Id); } - case TypeKind::Arrow: - { - auto Y = Ty->asArrow(); - std::ostringstream Out; - Out << describe(Y.ParamType) << " -> " << describe(Y.ReturnType); - return Out.str(); - } - case TypeKind::Con: - { - auto Y = Ty->asCon(); - return Y.DisplayName; - } - case TypeKind::App: - { - auto Y = Ty->asApp(); - return describe(Y.Op) + " " + describe(Y.Arg); - } - case TypeKind::Tuple: - { - std::ostringstream Out; - auto Y = Ty->asTuple(); - Out << "("; - if (Y.ElementTypes.size()) { - auto Iter = Y.ElementTypes.begin(); - Out << describe(*Iter++); - while (Iter != Y.ElementTypes.end()) { - Out << ", " << describe(*Iter++); - } - } - Out << ")"; - return Out.str(); - } - case TypeKind::Nil: - return "{}"; - case TypeKind::Absent: - return "Abs"; - case TypeKind::Present: - { - auto Y = Ty->asPresent(); - return describe(Y.Ty); - } - case TypeKind::Field: - { + Out << ")"; + return Out.str(); + } + case TypeKind::Nil: + return "{}"; + case TypeKind::Absent: + return "Abs"; + case TypeKind::Present: + { + auto Y = Ty->asPresent(); + return describe(Y.Ty); + } + case TypeKind::Field: + { + auto Y = Ty->asField(); + std::ostringstream out; + out << "{ " << Y.Name << ": " << describe(Y.Ty); + Ty = Y.RestTy; + while (Ty->getKind() == TypeKind::Field) { auto Y = Ty->asField(); - std::ostringstream out; - out << "{ " << Y.Name << ": " << describe(Y.Ty); + out << "; " + Y.Name + ": " + describe(Y.Ty); Ty = Y.RestTy; - while (Ty->getKind() == TypeKind::Field) { - auto Y = Ty->asField(); - out << "; " + Y.Name + ": " + describe(Y.Ty); - Ty = Y.RestTy; - } - if (Ty->getKind() != TypeKind::Nil) { - out << "; " + describe(Ty); - } - out << " }"; - return out.str(); } - } - ZEN_UNREACHABLE - } - - void writeForegroundANSI(Color C, std::ostream& Out) { - switch (C) { - case Color::None: - break; - case Color::Black: - Out << ANSI_FG_BLACK; - break; - case Color::White: - Out << ANSI_FG_WHITE; - break; - case Color::Red: - Out << ANSI_FG_RED; - break; - case Color::Yellow: - Out << ANSI_FG_YELLOW; - break; - case Color::Green: - Out << ANSI_FG_GREEN; - break; - case Color::Blue: - Out << ANSI_FG_BLUE; - break; - case Color::Cyan: - Out << ANSI_FG_CYAN; - break; - case Color::Magenta: - Out << ANSI_FG_MAGENTA; - break; + if (Ty->getKind() != TypeKind::Nil) { + out << "; " + describe(Ty); + } + out << " }"; + return out.str(); } } + ZEN_UNREACHABLE +} - void writeBackgroundANSI(Color C, std::ostream& Out) { - switch (C) { - case Color::None: - break; - case Color::Black: - Out << ANSI_BG_BLACK; - break; - case Color::White: - Out << ANSI_BG_WHITE; - break; - case Color::Red: - Out << ANSI_BG_RED; - break; - case Color::Yellow: - Out << ANSI_BG_YELLOW; - break; - case Color::Green: - Out << ANSI_BG_GREEN; - break; - case Color::Blue: - Out << ANSI_BG_BLUE; - break; - case Color::Cyan: - Out << ANSI_BG_CYAN; - break; - case Color::Magenta: - Out << ANSI_BG_MAGENTA; - break; - } +void writeForegroundANSI(Color C, std::ostream& Out) { + switch (C) { + case Color::None: + break; + case Color::Black: + Out << ANSI_FG_BLACK; + break; + case Color::White: + Out << ANSI_FG_WHITE; + break; + case Color::Red: + Out << ANSI_FG_RED; + break; + case Color::Yellow: + Out << ANSI_FG_YELLOW; + break; + case Color::Green: + Out << ANSI_FG_GREEN; + break; + case Color::Blue: + Out << ANSI_FG_BLUE; + break; + case Color::Cyan: + Out << ANSI_FG_CYAN; + break; + case Color::Magenta: + Out << ANSI_FG_MAGENTA; + break; } +} - ConsolePrinter::ConsolePrinter(std::ostream& Out): - Out(Out) {} - - void ConsolePrinter::setForegroundColor(Color C) { - ActiveStyle.setForegroundColor(C); - if (!EnableColors) { - return; - } - writeForegroundANSI(C, Out); +void writeBackgroundANSI(Color C, std::ostream& Out) { + switch (C) { + case Color::None: + break; + case Color::Black: + Out << ANSI_BG_BLACK; + break; + case Color::White: + Out << ANSI_BG_WHITE; + break; + case Color::Red: + Out << ANSI_BG_RED; + break; + case Color::Yellow: + Out << ANSI_BG_YELLOW; + break; + case Color::Green: + Out << ANSI_BG_GREEN; + break; + case Color::Blue: + Out << ANSI_BG_BLUE; + break; + case Color::Cyan: + Out << ANSI_BG_CYAN; + break; + case Color::Magenta: + Out << ANSI_BG_MAGENTA; + break; } +} - void ConsolePrinter::setBackgroundColor(Color C) { - ActiveStyle.setBackgroundColor(C); - if (!EnableColors) { - return; - } - if (C == Color::None) { - Out << ANSI_RESET; - applyStyles(); - } - writeBackgroundANSI(C, Out); +ConsolePrinter::ConsolePrinter(std::ostream& Out): + Out(Out) {} + +void ConsolePrinter::setForegroundColor(Color C) { + ActiveStyle.setForegroundColor(C); + if (!EnableColors) { + return; } + writeForegroundANSI(C, Out); +} - void ConsolePrinter::applyStyles() { - if (ActiveStyle.isBold()) { - Out << ANSI_BOLD; - } - if (ActiveStyle.isUnderline()) { - Out << ANSI_UNDERLINE; - } - if (ActiveStyle.isItalic()) { - Out << ANSI_ITALIC; - } - if (ActiveStyle.hasBackgroundColor()) { - setBackgroundColor(ActiveStyle.getBackgroundColor()); - } - if (ActiveStyle.hasForegroundColor()) { - setForegroundColor(ActiveStyle.getForegroundColor()); - } +void ConsolePrinter::setBackgroundColor(Color C) { + ActiveStyle.setBackgroundColor(C); + if (!EnableColors) { + return; } - - void ConsolePrinter::setBold(bool Enable) { - ActiveStyle.setBold(Enable); - if (!EnableColors) { - return; - } - if (Enable) { - Out << ANSI_BOLD; - } else { - Out << ANSI_RESET; - applyStyles(); - } + if (C == Color::None) { + Out << ANSI_RESET; + applyStyles(); } + writeBackgroundANSI(C, Out); +} - void ConsolePrinter::setItalic(bool Enable) { - ActiveStyle.setItalic(Enable); - if (!EnableColors) { - return; - } - if (Enable) { - Out << ANSI_ITALIC; - } else { - Out << ANSI_RESET; - applyStyles(); - } +void ConsolePrinter::applyStyles() { + if (ActiveStyle.isBold()) { + Out << ANSI_BOLD; } - - void ConsolePrinter::setUnderline(bool Enable) { - ActiveStyle.setItalic(Enable); - if (!EnableColors) { - return; - } - if (Enable) { - Out << ANSI_UNDERLINE; - } else { - Out << ANSI_RESET; - applyStyles(); - } + if (ActiveStyle.isUnderline()) { + Out << ANSI_UNDERLINE; } - - void ConsolePrinter::resetStyles() { - ActiveStyle.reset(); - if (EnableColors) { - Out << ANSI_RESET; - } + if (ActiveStyle.isItalic()) { + Out << ANSI_ITALIC; } - - void ConsolePrinter::writeGutter( - std::size_t GutterWidth, - std::string Text - ) { - ZEN_ASSERT(Text.size() <= GutterWidth); - auto LeadingSpaces = GutterWidth - Text.size(); - Out << " "; - setForegroundColor(Color::Black); - setBackgroundColor(Color::White); - for (std::size_t i = 0; i < LeadingSpaces; i++) { - Out << ' '; - } - Out << Text; - resetStyles(); - Out << " "; + if (ActiveStyle.hasBackgroundColor()) { + setBackgroundColor(ActiveStyle.getBackgroundColor()); } + if (ActiveStyle.hasForegroundColor()) { + setForegroundColor(ActiveStyle.getForegroundColor()); + } +} - void ConsolePrinter::writeHighlight( - std::size_t GutterWidth, - TextRange Range, - Color HighlightColor, - std::size_t Line, - std::size_t LineLength - ) { - if (Line < Range.Start.Line || Range.End.Line < Line) { - return; - } - Out << " "; - setBackgroundColor(Color::White); - for (std::size_t i = 0; i < GutterWidth; i++) { - Out << ' '; - } - resetStyles(); +void ConsolePrinter::setBold(bool Enable) { + ActiveStyle.setBold(Enable); + if (!EnableColors) { + return; + } + if (Enable) { + Out << ANSI_BOLD; + } else { + Out << ANSI_RESET; + applyStyles(); + } +} + +void ConsolePrinter::setItalic(bool Enable) { + ActiveStyle.setItalic(Enable); + if (!EnableColors) { + return; + } + if (Enable) { + Out << ANSI_ITALIC; + } else { + Out << ANSI_RESET; + applyStyles(); + } +} + +void ConsolePrinter::setUnderline(bool Enable) { + ActiveStyle.setItalic(Enable); + if (!EnableColors) { + return; + } + if (Enable) { + Out << ANSI_UNDERLINE; + } else { + Out << ANSI_RESET; + applyStyles(); + } +} + +void ConsolePrinter::resetStyles() { + ActiveStyle.reset(); + if (EnableColors) { + Out << ANSI_RESET; + } +} + +void ConsolePrinter::writeGutter( + std::size_t GutterWidth, + std::string Text +) { + ZEN_ASSERT(Text.size() <= GutterWidth); + auto LeadingSpaces = GutterWidth - Text.size(); + Out << " "; + setForegroundColor(Color::Black); + setBackgroundColor(Color::White); + for (std::size_t i = 0; i < LeadingSpaces; i++) { Out << ' '; - std::size_t start_column = Range.Start.Line == Line ? Range.Start.Column : 1; - std::size_t end_column = Range.End.Line == Line ? Range.End.Column : LineLength+1; - for (std::size_t i = 1; i < start_column; i++) { - Out << ' '; + } + Out << Text; + resetStyles(); + Out << " "; +} + +void ConsolePrinter::writeHighlight( + std::size_t GutterWidth, + TextRange Range, + Color HighlightColor, + std::size_t Line, + std::size_t LineLength +) { + if (Line < Range.Start.Line || Range.End.Line < Line) { + return; + } + Out << " "; + setBackgroundColor(Color::White); + for (std::size_t i = 0; i < GutterWidth; i++) { + Out << ' '; + } + resetStyles(); + Out << ' '; + std::size_t start_column = Range.Start.Line == Line ? Range.Start.Column : 1; + std::size_t end_column = Range.End.Line == Line ? Range.End.Column : LineLength+1; + for (std::size_t i = 1; i < start_column; i++) { + Out << ' '; + } + setForegroundColor(HighlightColor); + if (start_column == end_column) { + Out << "↖"; + } else { + for (std::size_t i = start_column; i < end_column; i++) { + Out << '~'; } - setForegroundColor(HighlightColor); - if (start_column == end_column) { - Out << "↖"; + } + resetStyles(); + Out << '\n'; +} + +void ConsolePrinter::writeExcerpt( + const TextFile& File, + TextRange ToPrint, + TextRange ToHighlight, + Color HighlightColor +) { + + auto LineCount = File.getLineCount(); + auto Text = File.getText(); + auto StartPos = ToPrint.Start; + auto EndPos = ToPrint.End; + auto StartLine = StartPos.Line-1 > ExcerptLinesPre ? StartPos.Line - ExcerptLinesPre : 1; + auto StartOffset = File.getStartOffsetOfLine(StartLine); + auto EndLine = std::min(LineCount, EndPos.Line + ExcerptLinesPost); + auto EndOffset = File.getEndOffsetOfLine(EndLine); + auto GutterWidth = std::max(2, countDigits(EndLine+1)); + auto HighlightStart = ToHighlight.Start; + auto HighlightEnd = ToHighlight.End; + auto HighlightRange = TextRange { HighlightStart, HighlightEnd }; + + std::size_t CurrColumn = 1; + std::size_t CurrLine = StartLine; + bool AtBlankLine = true; + for (std::size_t I = StartOffset; I < EndOffset; I++) { + auto C = Text[I]; + if (AtBlankLine) { + writeGutter(GutterWidth, std::to_string(CurrLine)); + } + if (C == '\n') { + Out << C; + writeHighlight(GutterWidth, HighlightRange, HighlightColor, CurrLine, CurrColumn); + CurrLine++; + CurrColumn = 1; + AtBlankLine = true; } else { - for (std::size_t i = start_column; i < end_column; i++) { - Out << '~'; - } + AtBlankLine = false; + Out << C; + CurrColumn++; } - resetStyles(); - Out << '\n'; - } - - void ConsolePrinter::writeExcerpt( - const TextFile& File, - TextRange ToPrint, - TextRange ToHighlight, - Color HighlightColor - ) { - - auto LineCount = File.getLineCount(); - auto Text = File.getText(); - auto StartPos = ToPrint.Start; - auto EndPos = ToPrint.End; - auto StartLine = StartPos.Line-1 > ExcerptLinesPre ? StartPos.Line - ExcerptLinesPre : 1; - auto StartOffset = File.getStartOffsetOfLine(StartLine); - auto EndLine = std::min(LineCount, EndPos.Line + ExcerptLinesPost); - auto EndOffset = File.getEndOffsetOfLine(EndLine); - auto GutterWidth = std::max(2, countDigits(EndLine+1)); - auto HighlightStart = ToHighlight.Start; - auto HighlightEnd = ToHighlight.End; - auto HighlightRange = TextRange { HighlightStart, HighlightEnd }; - - std::size_t CurrColumn = 1; - std::size_t CurrLine = StartLine; - bool AtBlankLine = true; - for (std::size_t I = StartOffset; I < EndOffset; I++) { - auto C = Text[I]; - if (AtBlankLine) { - writeGutter(GutterWidth, std::to_string(CurrLine)); - } - if (C == '\n') { - Out << C; - writeHighlight(GutterWidth, HighlightRange, HighlightColor, CurrLine, CurrColumn); - CurrLine++; - CurrColumn = 1; - AtBlankLine = true; - } else { - AtBlankLine = false; - Out << C; - CurrColumn++; - } - } - - } - - void ConsolePrinter::write(const std::string_view& S) { - Out << S; - } - - void ConsolePrinter::write(char C) { - Out << C; - } - - void ConsolePrinter::write(std::size_t I) { - Out << I; - } - - void ConsolePrinter::writeBinding(const ByteString& Name) { - write("'"); - write(Name); - write("'"); - } - - void ConsolePrinter::writeType(const Type* Ty) { - TypePath Path; - writeType(Ty, Path); - } - - void ConsolePrinter::writeType(const Type* Ty, const TypePath& Underline) { - - setForegroundColor(Color::Green); - - class TypePrinter : public ConstTypeVisitor { - - TypePath Path; - ConsolePrinter& W; - const TypePath& Underline; - - public: - - TypePrinter(ConsolePrinter& W, const TypePath& Underline): - W(W), Underline(Underline) {} - - bool shouldUnderline() const { - return !Underline.empty() && Path == Underline; - } - - void enterType(const Type* Ty) override { - if (shouldUnderline()) { - W.setUnderline(true); - } - } - - void exitType(const Type* Ty) override { - if (shouldUnderline()) { - W.setUnderline(false); // FIXME Should set to old value - } - } - - void visitAppType(const TApp& Ty) override { - Path.push_back(TypeIndex::forAppOpType()); - visit(Ty.Op); - Path.pop_back(); - W.write(" "); - Path.push_back(TypeIndex::forAppArgType()); - visit(Ty.Arg); - Path.pop_back(); - } - - void visitVarType(const TVar& Ty) override { - if (Ty.isRigid()) { - W.write(*Ty.Name); - return; - } - W.write("a"); - W.write(Ty.Id); - } - - void visitConType(const TCon& Ty) override { - W.write(Ty.DisplayName); - } - - void visitArrowType(const TArrow& Ty) override { - Path.push_back(TypeIndex::forArrowParamType()); - visit(Ty.ParamType); - Path.pop_back(); - W.write(" -> "); - Path.push_back(TypeIndex::forArrowReturnType()); - visit(Ty.ReturnType); - Path.pop_back(); - } - - void visitTupleType(const TTuple& Ty) override { - W.write("("); - if (Ty.ElementTypes.size()) { - auto Iter = Ty.ElementTypes.begin(); - Path.push_back(TypeIndex::forTupleElement(0)); - visit(*Iter++); - Path.pop_back(); - std::size_t I = 1; - while (Iter != Ty.ElementTypes.end()) { - W.write(", "); - Path.push_back(TypeIndex::forTupleElement(I++)); - visit(*Iter++); - Path.pop_back(); - } - } - W.write(")"); - } - - void visitNilType(const TNil& Ty) override { - W.write("{}"); - } - - void visitAbsentType(const TAbsent& Ty) override { - W.write("Abs"); - } - - void visitPresentType(const TPresent& Ty) override { - Path.push_back(TypeIndex::forPresentType()); - visit(Ty.Ty); - Path.pop_back(); - } - - void visitFieldType(const TField& Ty) override { - W.write("{ "); - W.write(Ty.Name); - W.write(": "); - Path.push_back(TypeIndex::forFieldType()); - visit(Ty.Ty); - Path.pop_back(); - auto Ty2 = Ty.RestTy; - Path.push_back(TypeIndex::forFieldRest()); - std::size_t I = 1; - while (Ty2->isField()) { - auto Y = Ty2->asField(); - W.write("; "); - W.write(Y.Name); - W.write(": "); - Path.push_back(TypeIndex::forFieldType()); - visit(Y.Ty); - Path.pop_back(); - Ty2 = Y.RestTy; - Path.push_back(TypeIndex::forFieldRest()); - ++I; - } - if (Ty2->getKind() != TypeKind::Nil) { - W.write("; "); - visit(Ty2); - } - W.write(" }"); - for (auto K = 0; K < I; K++) { - Path.pop_back(); - } - } - - }; - - TypePrinter P { *this, Underline }; - P.visit(Ty); - - resetStyles(); - } - - void ConsolePrinter::writeType(std::size_t I) { - setForegroundColor(Color::Green); - write(I); - resetStyles(); - } - - void ConsolePrinter::writeNode(const Node* N) { - auto Range = N->getRange(); - writeExcerpt(N->getSourceFile()->getTextFile(), Range, Range, Color::Red); - } - - void ConsolePrinter::writeLoc(const TextFile& File, const TextLoc& Loc) { - setForegroundColor(Color::Yellow); - write(File.getPath()); - write(":"); - write(Loc.Line); - write(":"); - write(Loc.Column); - write(":"); - resetStyles(); - } - - void ConsolePrinter::writePrefix(const Diagnostic& D) { - setForegroundColor(Color::Red); - setBold(true); - write("error: "); - resetStyles(); - } - - void ConsolePrinter::writeTypeclassName(const ByteString& Name) { - setForegroundColor(Color::Magenta); - write(Name); - resetStyles(); - } - - void ConsolePrinter::writeTypeclassSignature(const TypeclassSignature& Sig) { - setForegroundColor(Color::Magenta); - write(Sig.Id); - for (auto TV: Sig.Params) { - write(" "); - write(describe(TV)); - } - resetStyles(); - } - - void ConsolePrinter::writeDiagnostic(const Diagnostic& D) { - - switch (D.getKind()) { - - case DiagnosticKind::BindingNotFound: - { - auto& E = static_cast(D); - writePrefix(E); - write("binding "); - writeBinding(E.Name); - write(" was not found\n\n"); - if (E.Initiator != nullptr) { - auto Range = E.Initiator->getRange(); - //std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n"; - writeExcerpt(E.Initiator->getSourceFile()->getTextFile(), Range, Range, Color::Red); - Out << "\n"; - } - return; - } - - case DiagnosticKind::UnexpectedToken: - { - auto& E = static_cast(D); - writePrefix(E); - writeLoc(E.File, E.Actual->getStartLoc()); - write(" expected "); - switch (E.Expected.size()) { - case 0: - write("nothing"); - break; - case 1: - write(describe(E.Expected[0])); - break; - default: - auto Iter = E.Expected.begin(); - Out << describe(*Iter++); - NodeKind Prev = *Iter++; - while (Iter != E.Expected.end()) { - write(", "); - write(describe(Prev)); - Prev = *Iter++; - } - write(" or "); - write(describe(Prev)); - break; - } - write(" but instead got "); - write(describe(E.Actual)); - write("\n\n"); - writeExcerpt(E.File, E.Actual->getRange(), E.Actual->getRange(), Color::Red); - write("\n"); - return; - } - - case DiagnosticKind::UnexpectedString: - { - auto& E = static_cast(D); - writePrefix(E); - writeLoc(E.File, E.Location); - write(" unexpected '"); - for (auto Chr: E.Actual) { - switch (Chr) { - case '\\': - write("\\\\"); - break; - case '\'': - write("\\'"); - break; - default: - write(Chr); - break; - } - } - write("'\n\n"); - TextRange Range { E.Location, E.Location + E.Actual }; - writeExcerpt(E.File, Range, Range, Color::Red); - write("\n"); - return; - } - - case DiagnosticKind::UnificationError: - { - auto& E = static_cast(D); - auto Left = E.OrigLeft->resolve(E.LeftPath); - auto Right = E.OrigRight->resolve(E.RightPath); - writePrefix(E); - write("the types "); - writeType(Left); - write(" and "); - writeType(Right); - write(" failed to match\n\n"); - setForegroundColor(Color::Yellow); - setBold(true); - write(" info: "); - resetStyles(); - write("due to an equality constraint on "); - write(describe(E.Source->getKind())); - write(":\n\n"); - // write(" - left type "); - // writeType(E.OrigLeft, E.LeftPath); - // write("\n"); - // write(" - right type "); - // writeType(E.OrigRight, E.RightPath); - // write("\n\n"); - writeNode(E.Source); - write("\n"); - // if (E.Left != E.OrigLeft) { - // setForegroundColor(Color::Yellow); - // setBold(true); - // write(" info: "); - // resetStyles(); - // write("the type "); - // writeType(E.Left); - // write(" occurs in the full type "); - // writeType(E.OrigLeft); - // write("\n\n"); - // } - // if (E.Right != E.OrigRight) { - // setForegroundColor(Color::Yellow); - // setBold(true); - // write(" info: "); - // resetStyles(); - // write("the type "); - // writeType(E.Right); - // write(" occurs in the full type "); - // writeType(E.OrigRight); - // write("\n\n"); - // } - return; - } - - case DiagnosticKind::TypeclassMissing: - { - auto& E = static_cast(D); - writePrefix(E); - write("the type class "); - writeTypeclassSignature(E.Sig); - write(" is missing from the declaration's type signature\n\n"); - writeNode(E.Decl); - write("\n\n"); - return; - } - - case DiagnosticKind::InstanceNotFound: - { - auto& E = static_cast(D); - writePrefix(E); - write("a type class instance "); - writeTypeclassName(E.TypeclassName); - write(" "); - writeType(E.Ty); - write(" was not found.\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - case DiagnosticKind::TupleIndexOutOfRange: - { - auto& E = static_cast(D); - writePrefix(E); - write("the index "); - writeType(E.I); - write(" is out of range for tuple "); - writeType(E.Tuple); - write("\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - case DiagnosticKind::InvalidTypeToTypeclass: - { - auto& E = static_cast(D); - writePrefix(E); - write("the type "); - writeType(E.Actual); - write(" was applied to type class names "); - bool First = true; - for (auto Class: E.Classes) { - if (First) First = false; - else write(", "); - writeTypeclassName(Class); - } - write(" but this is invalid\n\n"); - return; - } - - case DiagnosticKind::FieldNotFound: - { - auto& E = static_cast(D); - writePrefix(E); - write("the field '"); - write(E.Name); - write("' was required in one type but not found in another\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - case DiagnosticKind::NotATuple: - { - auto& E = static_cast(D); - writePrefix(E); - write("the type "); - writeType(E.Ty); - write(" is not a tuple.\n\n"); - writeNode(E.Source); - write("\n"); - return; - } - - } - - ZEN_UNREACHABLE - } } + +void ConsolePrinter::write(const std::string_view& S) { + Out << S; +} + +void ConsolePrinter::write(char C) { + Out << C; +} + +void ConsolePrinter::write(std::size_t I) { + Out << I; +} + +void ConsolePrinter::writeBinding(const ByteString& Name) { + write("'"); + write(Name); + write("'"); +} + +void ConsolePrinter::writeType(const Type* Ty) { + TypePath Path; + writeType(Ty, Path); +} + +void ConsolePrinter::writeType(const Type* Ty, const TypePath& Underline) { + + setForegroundColor(Color::Green); + + class TypePrinter : public ConstTypeVisitor { + + TypePath Path; + ConsolePrinter& W; + const TypePath& Underline; + + public: + + TypePrinter(ConsolePrinter& W, const TypePath& Underline): + W(W), Underline(Underline) {} + + bool shouldUnderline() const { + return !Underline.empty() && Path == Underline; + } + + void enterType(const Type* Ty) override { + if (shouldUnderline()) { + W.setUnderline(true); + } + } + + void exitType(const Type* Ty) override { + if (shouldUnderline()) { + W.setUnderline(false); // FIXME Should set to old value + } + } + + void visitAppType(const TApp& Ty) override { + Path.push_back(TypeIndex::forAppOpType()); + visit(Ty.Op); + Path.pop_back(); + W.write(" "); + Path.push_back(TypeIndex::forAppArgType()); + visit(Ty.Arg); + Path.pop_back(); + } + + void visitVarType(const TVar& Ty) override { + if (Ty.isRigid()) { + W.write(*Ty.Name); + return; + } + W.write("a"); + W.write(Ty.Id); + } + + void visitConType(const TCon& Ty) override { + W.write(Ty.DisplayName); + } + + void visitArrowType(const TArrow& Ty) override { + Path.push_back(TypeIndex::forArrowParamType()); + visit(Ty.ParamType); + Path.pop_back(); + W.write(" -> "); + Path.push_back(TypeIndex::forArrowReturnType()); + visit(Ty.ReturnType); + Path.pop_back(); + } + + void visitTupleType(const TTuple& Ty) override { + W.write("("); + if (Ty.ElementTypes.size()) { + auto Iter = Ty.ElementTypes.begin(); + Path.push_back(TypeIndex::forTupleElement(0)); + visit(*Iter++); + Path.pop_back(); + std::size_t I = 1; + while (Iter != Ty.ElementTypes.end()) { + W.write(", "); + Path.push_back(TypeIndex::forTupleElement(I++)); + visit(*Iter++); + Path.pop_back(); + } + } + W.write(")"); + } + + void visitNilType(const TNil& Ty) override { + W.write("{}"); + } + + void visitAbsentType(const TAbsent& Ty) override { + W.write("Abs"); + } + + void visitPresentType(const TPresent& Ty) override { + Path.push_back(TypeIndex::forPresentType()); + visit(Ty.Ty); + Path.pop_back(); + } + + void visitFieldType(const TField& Ty) override { + W.write("{ "); + W.write(Ty.Name); + W.write(": "); + Path.push_back(TypeIndex::forFieldType()); + visit(Ty.Ty); + Path.pop_back(); + auto Ty2 = Ty.RestTy; + Path.push_back(TypeIndex::forFieldRest()); + std::size_t I = 1; + while (Ty2->isField()) { + auto Y = Ty2->asField(); + W.write("; "); + W.write(Y.Name); + W.write(": "); + Path.push_back(TypeIndex::forFieldType()); + visit(Y.Ty); + Path.pop_back(); + Ty2 = Y.RestTy; + Path.push_back(TypeIndex::forFieldRest()); + ++I; + } + if (Ty2->getKind() != TypeKind::Nil) { + W.write("; "); + visit(Ty2); + } + W.write(" }"); + for (auto K = 0; K < I; K++) { + Path.pop_back(); + } + } + + }; + + TypePrinter P { *this, Underline }; + P.visit(Ty); + + resetStyles(); +} + +void ConsolePrinter::writeType(std::size_t I) { + setForegroundColor(Color::Green); + write(I); + resetStyles(); +} + +void ConsolePrinter::writeNode(const Node* N) { + auto Range = N->getRange(); + writeExcerpt(N->getSourceFile()->getTextFile(), Range, Range, Color::Red); +} + +void ConsolePrinter::writeLoc(const TextFile& File, const TextLoc& Loc) { + setForegroundColor(Color::Yellow); + write(File.getPath()); + write(":"); + write(Loc.Line); + write(":"); + write(Loc.Column); + write(":"); + resetStyles(); +} + +void ConsolePrinter::writePrefix(const Diagnostic& D) { + setForegroundColor(Color::Red); + setBold(true); + write("error: "); + resetStyles(); +} + +void ConsolePrinter::writeTypeclassName(const ByteString& Name) { + setForegroundColor(Color::Magenta); + write(Name); + resetStyles(); +} + +void ConsolePrinter::writeTypeclassSignature(const TypeclassSignature& Sig) { + setForegroundColor(Color::Magenta); + write(Sig.Id); + for (auto TV: Sig.Params) { + write(" "); + write(describe(TV)); + } + resetStyles(); +} + +void ConsolePrinter::writeDiagnostic(const Diagnostic& D) { + + switch (D.getKind()) { + + case DiagnosticKind::BindingNotFound: + { + auto& E = static_cast(D); + writePrefix(E); + write("binding "); + writeBinding(E.Name); + write(" was not found\n\n"); + if (E.Initiator != nullptr) { + auto Range = E.Initiator->getRange(); + //std::cerr << Range.Start.Line << ":" << Range.Start.Column << "-" << Range.End.Line << ":" << Range.End.Column << "\n"; + writeExcerpt(E.Initiator->getSourceFile()->getTextFile(), Range, Range, Color::Red); + Out << "\n"; + } + return; + } + + case DiagnosticKind::UnexpectedToken: + { + auto& E = static_cast(D); + writePrefix(E); + writeLoc(E.File, E.Actual->getStartLoc()); + write(" expected "); + switch (E.Expected.size()) { + case 0: + write("nothing"); + break; + case 1: + write(describe(E.Expected[0])); + break; + default: + auto Iter = E.Expected.begin(); + Out << describe(*Iter++); + NodeKind Prev = *Iter++; + while (Iter != E.Expected.end()) { + write(", "); + write(describe(Prev)); + Prev = *Iter++; + } + write(" or "); + write(describe(Prev)); + break; + } + write(" but instead got "); + write(describe(E.Actual)); + write("\n\n"); + writeExcerpt(E.File, E.Actual->getRange(), E.Actual->getRange(), Color::Red); + write("\n"); + return; + } + + case DiagnosticKind::UnexpectedString: + { + auto& E = static_cast(D); + writePrefix(E); + writeLoc(E.File, E.Location); + write(" unexpected '"); + for (auto Chr: E.Actual) { + switch (Chr) { + case '\\': + write("\\\\"); + break; + case '\'': + write("\\'"); + break; + default: + write(Chr); + break; + } + } + write("'\n\n"); + TextRange Range { E.Location, E.Location + E.Actual }; + writeExcerpt(E.File, Range, Range, Color::Red); + write("\n"); + return; + } + + case DiagnosticKind::UnificationError: + { + auto& E = static_cast(D); + auto Left = E.OrigLeft->resolve(E.LeftPath); + auto Right = E.OrigRight->resolve(E.RightPath); + writePrefix(E); + write("the types "); + writeType(Left); + write(" and "); + writeType(Right); + write(" failed to match\n\n"); + setForegroundColor(Color::Yellow); + setBold(true); + write(" info: "); + resetStyles(); + write("due to an equality constraint on "); + write(describe(E.Source->getKind())); + write(":\n\n"); + // write(" - left type "); + // writeType(E.OrigLeft, E.LeftPath); + // write("\n"); + // write(" - right type "); + // writeType(E.OrigRight, E.RightPath); + // write("\n\n"); + writeNode(E.Source); + write("\n"); + // if (E.Left != E.OrigLeft) { + // setForegroundColor(Color::Yellow); + // setBold(true); + // write(" info: "); + // resetStyles(); + // write("the type "); + // writeType(E.Left); + // write(" occurs in the full type "); + // writeType(E.OrigLeft); + // write("\n\n"); + // } + // if (E.Right != E.OrigRight) { + // setForegroundColor(Color::Yellow); + // setBold(true); + // write(" info: "); + // resetStyles(); + // write("the type "); + // writeType(E.Right); + // write(" occurs in the full type "); + // writeType(E.OrigRight); + // write("\n\n"); + // } + return; + } + + case DiagnosticKind::TypeclassMissing: + { + auto& E = static_cast(D); + writePrefix(E); + write("the type class "); + writeTypeclassSignature(E.Sig); + write(" is missing from the declaration's type signature\n\n"); + writeNode(E.Decl); + write("\n\n"); + return; + } + + case DiagnosticKind::InstanceNotFound: + { + auto& E = static_cast(D); + writePrefix(E); + write("a type class instance "); + writeTypeclassName(E.TypeclassName); + write(" "); + writeType(E.Ty); + write(" was not found.\n\n"); + writeNode(E.Source); + write("\n"); + return; + } + + case DiagnosticKind::TupleIndexOutOfRange: + { + auto& E = static_cast(D); + writePrefix(E); + write("the index "); + writeType(E.I); + write(" is out of range for tuple "); + writeType(E.Tuple); + write("\n\n"); + writeNode(E.Source); + write("\n"); + return; + } + + case DiagnosticKind::InvalidTypeToTypeclass: + { + auto& E = static_cast(D); + writePrefix(E); + write("the type "); + writeType(E.Actual); + write(" was applied to type class names "); + bool First = true; + for (auto Class: E.Classes) { + if (First) First = false; + else write(", "); + writeTypeclassName(Class); + } + write(" but this is invalid\n\n"); + return; + } + + case DiagnosticKind::FieldNotFound: + { + auto& E = static_cast(D); + writePrefix(E); + write("the field '"); + write(E.Name); + write("' was required in one type but not found in another\n\n"); + writeNode(E.Source); + write("\n"); + return; + } + + case DiagnosticKind::NotATuple: + { + auto& E = static_cast(D); + writePrefix(E); + write("the type "); + writeType(E.Ty); + write(" is not a tuple.\n\n"); + writeNode(E.Source); + write("\n"); + return; + } + + } + + ZEN_UNREACHABLE + +} + +} diff --git a/bootstrap/cxx/src/Diagnostics.cc b/bootstrap/cxx/src/Diagnostics.cc index ff9aeda35..e0ec04e0b 100644 --- a/bootstrap/cxx/src/Diagnostics.cc +++ b/bootstrap/cxx/src/Diagnostics.cc @@ -38,44 +38,44 @@ namespace bolt { - Diagnostic::Diagnostic(DiagnosticKind Kind): - Kind(Kind) {} +Diagnostic::Diagnostic(DiagnosticKind Kind): + Kind(Kind) {} - bool sourceLocLessThan(const Diagnostic* L, const Diagnostic* R) { - auto N1 = L->getNode(); - auto N2 = R->getNode(); - if (N1 == nullptr && N2 == nullptr) { - return false; - } - if (N1 == nullptr) { - return true; - } - if (N2 == nullptr) { - return false; - } - return N1->getStartLine() < N2->getStartLine() || N1->getStartColumn() < N2->getStartColumn(); - }; - - void DiagnosticStore::sort() { - std::sort(Diagnostics.begin(), Diagnostics.end(), sourceLocLessThan); +bool sourceLocLessThan(const Diagnostic* L, const Diagnostic* R) { + auto N1 = L->getNode(); + auto N2 = R->getNode(); + if (N1 == nullptr && N2 == nullptr) { + return false; } - - DiagnosticStore::~DiagnosticStore() { - for (auto D: Diagnostics) { - delete D; - } + if (N1 == nullptr) { + return true; } + if (N2 == nullptr) { + return false; + } + return N1->getStartLine() < N2->getStartLine() || N1->getStartColumn() < N2->getStartColumn(); +}; - ConsoleDiagnostics::ConsoleDiagnostics(ConsolePrinter& P): - ThePrinter(P) {} +void DiagnosticStore::sort() { + std::sort(Diagnostics.begin(), Diagnostics.end(), sourceLocLessThan); +} - void ConsoleDiagnostics::addDiagnostic(Diagnostic* D) { - - ThePrinter.writeDiagnostic(*D); - - // Since this DiagnosticEngine is expected to own the diagnostic, we simply - // destroy the processed diagnostic so that there are no memory leaks. +DiagnosticStore::~DiagnosticStore() { + for (auto D: Diagnostics) { delete D; } +} + +ConsoleDiagnostics::ConsoleDiagnostics(ConsolePrinter& P): + ThePrinter(P) {} + +void ConsoleDiagnostics::addDiagnostic(Diagnostic* D) { + + ThePrinter.writeDiagnostic(*D); + + // Since this DiagnosticEngine is expected to own the diagnostic, we simply + // destroy the processed diagnostic so that there are no memory leaks. + delete D; +} } diff --git a/bootstrap/cxx/src/Evaluator.cc b/bootstrap/cxx/src/Evaluator.cc index a36156bbe..13817435e 100644 --- a/bootstrap/cxx/src/Evaluator.cc +++ b/bootstrap/cxx/src/Evaluator.cc @@ -6,122 +6,122 @@ namespace bolt { - Value Evaluator::evaluateExpression(Expression* X, Env& Env) { - switch (X->getKind()) { - case NodeKind::ReferenceExpression: - { - auto RE = static_cast(X); - return Env.lookup(getCanonicalText(RE->Name)); - // auto Decl = RE->getScope()->lookup(RE->getSymbolPath()); - // ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration); - // return static_cast(Decl); - } - case NodeKind::LiteralExpression: - { - auto CE = static_cast(X); - switch (CE->Token->getKind()) { - case NodeKind::IntegerLiteral: - return static_cast(CE->Token)->V; - case NodeKind::StringLiteral: - return static_cast(CE->Token)->Text; - default: - ZEN_UNREACHABLE - } - } - case NodeKind::CallExpression: - { - auto CE = static_cast(X); - auto Op = evaluateExpression(CE->Function, Env); - std::vector Args; - for (auto Arg: CE->Args) { - Args.push_back(evaluateExpression(Arg, Env)); - } - return apply(Op, Args); - } - default: - ZEN_UNREACHABLE +Value Evaluator::evaluateExpression(Expression* X, Env& Env) { + switch (X->getKind()) { + case NodeKind::ReferenceExpression: + { + auto RE = static_cast(X); + return Env.lookup(getCanonicalText(RE->Name)); + // auto Decl = RE->getScope()->lookup(RE->getSymbolPath()); + // ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration); + // return static_cast(Decl); } - } - - void Evaluator::assignPattern(Pattern* P, Value& V, Env& E) { - switch (P->getKind()) { - case NodeKind::BindPattern: - { - auto BP = static_cast(P); - E.add(getCanonicalText(BP->Name), V); - break; + case NodeKind::LiteralExpression: + { + auto CE = static_cast(X); + switch (CE->Token->getKind()) { + case NodeKind::IntegerLiteral: + return static_cast(CE->Token)->V; + case NodeKind::StringLiteral: + return static_cast(CE->Token)->Text; + default: + ZEN_UNREACHABLE } - default: - ZEN_UNREACHABLE } - } - - Value Evaluator::apply(Value Op, std::vector Args) { - switch (Op.getKind()) { - case ValueKind::SourceFunction: - { - auto Fn = Op.getDeclaration(); - Env NewEnv; - for (auto [Param, Arg]: zen::zip(Fn->Params, Args)) { - assignPattern(Param->Pattern, Arg, NewEnv); - } - switch (Fn->Body->getKind()) { - case NodeKind::LetExprBody: - return evaluateExpression(static_cast(Fn->Body)->Expression, NewEnv); - default: - ZEN_UNREACHABLE - } + case NodeKind::CallExpression: + { + auto CE = static_cast(X); + auto Op = evaluateExpression(CE->Function, Env); + std::vector Args; + for (auto Arg: CE->Args) { + Args.push_back(evaluateExpression(Arg, Env)); } - case ValueKind::NativeFunction: - { - auto Fn = Op.getBinding(); - return Fn(Args); - } - default: - ZEN_UNREACHABLE + return apply(Op, Args); } + default: + ZEN_UNREACHABLE } +} - void Evaluator::evaluate(Node* N, Env& E) { - switch (N->getKind()) { - case NodeKind::SourceFile: - { - auto SF = static_cast(N); - for (auto Element: SF->Elements) { - evaluate(Element, E); - } - break; +void Evaluator::assignPattern(Pattern* P, Value& V, Env& E) { + switch (P->getKind()) { + case NodeKind::BindPattern: + { + auto BP = static_cast(P); + E.add(getCanonicalText(BP->Name), V); + break; + } + default: + ZEN_UNREACHABLE + } +} + +Value Evaluator::apply(Value Op, std::vector Args) { + switch (Op.getKind()) { + case ValueKind::SourceFunction: + { + auto Fn = Op.getDeclaration(); + Env NewEnv; + for (auto [Param, Arg]: zen::zip(Fn->Params, Args)) { + assignPattern(Param->Pattern, Arg, NewEnv); } - case NodeKind::ExpressionStatement: - { - auto ES = static_cast(N); - evaluateExpression(ES->Expression, E); - break; + switch (Fn->Body->getKind()) { + case NodeKind::LetExprBody: + return evaluateExpression(static_cast(Fn->Body)->Expression, NewEnv); + default: + ZEN_UNREACHABLE } - case NodeKind::LetDeclaration: - { - auto Decl = static_cast(N); - if (Decl->isFunction()) { - E.add(Decl->getNameAsString(), Decl); - } else { - Value V; - if (Decl->Body) { - switch (Decl->Body->getKind()) { - case NodeKind::LetExprBody: - { - auto Body = static_cast(Decl->Body); - V = evaluateExpression(Body->Expression, E); - } - default: - ZEN_UNREACHABLE + } + case ValueKind::NativeFunction: + { + auto Fn = Op.getBinding(); + return Fn(Args); + } + default: + ZEN_UNREACHABLE + } +} + +void Evaluator::evaluate(Node* N, Env& E) { + switch (N->getKind()) { + case NodeKind::SourceFile: + { + auto SF = static_cast(N); + for (auto Element: SF->Elements) { + evaluate(Element, E); + } + break; + } + case NodeKind::ExpressionStatement: + { + auto ES = static_cast(N); + evaluateExpression(ES->Expression, E); + break; + } + case NodeKind::LetDeclaration: + { + auto Decl = static_cast(N); + if (Decl->isFunction()) { + E.add(Decl->getNameAsString(), Decl); + } else { + Value V; + if (Decl->Body) { + switch (Decl->Body->getKind()) { + case NodeKind::LetExprBody: + { + auto Body = static_cast(Decl->Body); + V = evaluateExpression(Body->Expression, E); } + default: + ZEN_UNREACHABLE } } - break; } - default: - ZEN_UNREACHABLE + break; } + default: + ZEN_UNREACHABLE } +} } diff --git a/bootstrap/cxx/src/Parser.cc b/bootstrap/cxx/src/Parser.cc index 52f180a69..3d4b77429 100644 --- a/bootstrap/cxx/src/Parser.cc +++ b/bootstrap/cxx/src/Parser.cc @@ -32,199 +32,219 @@ namespace bolt { - bool isOperator(Token* T) { - switch (T->getKind()) { - case NodeKind::VBar: - case NodeKind::CustomOperator: - return true; - default: - return false; - } +bool isOperator(Token* T) { + switch (T->getKind()) { + case NodeKind::VBar: + case NodeKind::CustomOperator: + return true; + default: + return false; + } +} + +std::optional OperatorTable::getInfix(Token* T) { + auto Match = Mapping.find(T->getText()); + if (Match == Mapping.end() || !Match->second.isInfix()) { + return {}; + } + return Match->second; +} + +bool OperatorTable::isInfix(Token* T) { + auto Match = Mapping.find(T->getText()); + return Match != Mapping.end() && Match->second.isInfix(); +} + +bool OperatorTable::isPrefix(Token* T) { + auto Match = Mapping.find(T->getText()); + return Match != Mapping.end() && Match->second.isPrefix(); +} + +bool OperatorTable::isSuffix(Token* T) { + auto Match = Mapping.find(T->getText()); + return Match != Mapping.end() && Match->second.isSuffix(); +} + +void OperatorTable::add(std::string Name, unsigned Flags, int Precedence) { + Mapping.emplace(Name, OperatorInfo { Precedence, Flags }); +} + +Parser::Parser(TextFile& File, Stream& S, DiagnosticEngine& DE): + File(File), Tokens(S), DE(DE) { + ExprOperators.add("**", OperatorFlags_InfixR, 10); + ExprOperators.add("*", OperatorFlags_InfixL, 5); + ExprOperators.add("/", OperatorFlags_InfixL, 5); + ExprOperators.add("+", OperatorFlags_InfixL, 4); + ExprOperators.add("-", OperatorFlags_InfixL, 4); + ExprOperators.add("<", OperatorFlags_InfixL, 3); + ExprOperators.add(">", OperatorFlags_InfixL, 3); + ExprOperators.add("<=", OperatorFlags_InfixL, 3); + ExprOperators.add(">=", OperatorFlags_InfixL, 3); + ExprOperators.add("==", OperatorFlags_InfixL, 3); + ExprOperators.add("!=", OperatorFlags_InfixL, 3); + ExprOperators.add(":", OperatorFlags_InfixL, 2); + ExprOperators.add("<|>", OperatorFlags_InfixL, 1); + ExprOperators.add("$", OperatorFlags_InfixR, 0); } - std::optional OperatorTable::getInfix(Token* T) { - auto Match = Mapping.find(T->getText()); - if (Match == Mapping.end() || !Match->second.isInfix()) { - return {}; - } - return Match->second; - } - - bool OperatorTable::isInfix(Token* T) { - auto Match = Mapping.find(T->getText()); - return Match != Mapping.end() && Match->second.isInfix(); - } - - bool OperatorTable::isPrefix(Token* T) { - auto Match = Mapping.find(T->getText()); - return Match != Mapping.end() && Match->second.isPrefix(); - } - - bool OperatorTable::isSuffix(Token* T) { - auto Match = Mapping.find(T->getText()); - return Match != Mapping.end() && Match->second.isSuffix(); - } - - void OperatorTable::add(std::string Name, unsigned Flags, int Precedence) { - Mapping.emplace(Name, OperatorInfo { Precedence, Flags }); - } - - Parser::Parser(TextFile& File, Stream& S, DiagnosticEngine& DE): - File(File), Tokens(S), DE(DE) { - ExprOperators.add("**", OperatorFlags_InfixR, 10); - ExprOperators.add("*", OperatorFlags_InfixL, 5); - ExprOperators.add("/", OperatorFlags_InfixL, 5); - ExprOperators.add("+", OperatorFlags_InfixL, 4); - ExprOperators.add("-", OperatorFlags_InfixL, 4); - ExprOperators.add("<", OperatorFlags_InfixL, 3); - ExprOperators.add(">", OperatorFlags_InfixL, 3); - ExprOperators.add("<=", OperatorFlags_InfixL, 3); - ExprOperators.add(">=", OperatorFlags_InfixL, 3); - ExprOperators.add("==", OperatorFlags_InfixL, 3); - ExprOperators.add("!=", OperatorFlags_InfixL, 3); - ExprOperators.add(":", OperatorFlags_InfixL, 2); - ExprOperators.add("<|>", OperatorFlags_InfixL, 1); - ExprOperators.add("$", OperatorFlags_InfixR, 0); - } - - Token* Parser::peekFirstTokenAfterAnnotationsAndModifiers() { - std::size_t I = 0; - for (;;) { - auto T0 = Tokens.peek(I++); - switch (T0->getKind()) { - case NodeKind::PubKeyword: - case NodeKind::MutKeyword: - continue; - case NodeKind::At: - for (;;) { - auto T1 = Tokens.peek(I++); - if (T1->getKind() == NodeKind::LineFoldEnd) { - break; - } +Token* Parser::peekFirstTokenAfterAnnotationsAndModifiers() { + std::size_t I = 0; + for (;;) { + auto T0 = Tokens.peek(I++); + switch (T0->getKind()) { + case NodeKind::PubKeyword: + case NodeKind::MutKeyword: + continue; + case NodeKind::At: + for (;;) { + auto T1 = Tokens.peek(I++); + if (T1->getKind() == NodeKind::LineFoldEnd) { + break; } - continue; - default: - return T0; - } + } + continue; + default: + return T0; } } +} - Token* Parser::expectToken(NodeKind Kind) { - auto T = Tokens.peek(); - if (T->getKind() != Kind) { - DE.add(File, T, std::vector { Kind }); - return nullptr; - } +Token* Parser::expectToken(NodeKind Kind) { + auto T = Tokens.peek(); + if (T->getKind() != Kind) { + DE.add(File, T, std::vector { Kind }); + return nullptr; + } + Tokens.get(); + return T; +} + +ListPattern* Parser::parseListPattern() { + auto LBracket = expectToken(); + if (!LBracket) { + return nullptr; + } + std::vector> Elements; + RBracket* RBracket; + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::RBracket) { Tokens.get(); - return T; + RBracket = static_cast(T0); + goto finish; } - - ListPattern* Parser::parseListPattern() { - auto LBracket = expectToken(); - if (!LBracket) { + for (;;) { + auto P = parseWidePattern(); + if (!P) { + LBracket->unref(); + for (auto [Element, Separator]: Elements) { + Element->unref(); + Separator->unref(); + } return nullptr; } - std::vector> Elements; - RBracket* RBracket; - auto T0 = Tokens.peek(); - if (T0->getKind() == NodeKind::RBracket) { - Tokens.get(); - RBracket = static_cast(T0); - goto finish; + auto T1 = Tokens.peek(); + switch (T1->getKind()) { + case NodeKind::Comma: + Tokens.get(); + Elements.push_back(std::make_tuple(P, static_cast(T1))); + break; + case NodeKind::RBracket: + Tokens.get(); + Elements.push_back(std::make_tuple(P, nullptr)); + RBracket = static_cast(T1); + goto finish; + default: + DE.add(File, T1, std::vector { NodeKind::Comma, NodeKind::RBracket }); } - for (;;) { - auto P = parseWidePattern(); - if (!P) { - LBracket->unref(); - for (auto [Element, Separator]: Elements) { - Element->unref(); - Separator->unref(); - } - return nullptr; - } - auto T1 = Tokens.peek(); - switch (T1->getKind()) { - case NodeKind::Comma: - Tokens.get(); - Elements.push_back(std::make_tuple(P, static_cast(T1))); - break; - case NodeKind::RBracket: - Tokens.get(); - Elements.push_back(std::make_tuple(P, nullptr)); - RBracket = static_cast(T1); - goto finish; - default: - DE.add(File, T1, std::vector { NodeKind::Comma, NodeKind::RBracket }); - } - } -finish: - return new ListPattern { LBracket, Elements, RBracket }; } +finish: + return new ListPattern { LBracket, Elements, RBracket }; +} - std::optional>> Parser::parseRecordPatternFields() { - std::vector> Fields; - for (;;) { - auto T0 = Tokens.peek(); - if (T0->getKind() == NodeKind::RBrace) { - break; - } - if (T0->getKind() == NodeKind::DotDot) { - Tokens.get(); - auto DotDot = static_cast(T0); - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::RBrace) { - Fields.push_back(std::make_tuple(new RecordPatternField(DotDot), nullptr)); - break; - } - auto P = parseWidePattern(); - auto T2 = Tokens.peek(); - if (T2->getKind() != NodeKind::RBrace) { - DE.add(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma }); - return {}; - } - Fields.push_back(std::make_tuple(new RecordPatternField(DotDot, P), nullptr)); - break; - } - auto Name = expectToken(); - Equals* Equals = nullptr; - Pattern* Pattern = nullptr; +std::optional>> Parser::parseRecordPatternFields() { + std::vector> Fields; + for (;;) { + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::RBrace) { + break; + } + if (T0->getKind() == NodeKind::DotDot) { + Tokens.get(); + auto DotDot = static_cast(T0); auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::Equals) { - Tokens.get(); - Equals = static_cast(T1); - Pattern = parseWidePattern(); - } - auto Field = new RecordPatternField(Name, Equals, Pattern); - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::RBrace) { - Fields.push_back(std::make_tuple(Field, nullptr)); + if (T1->getKind() == NodeKind::RBrace) { + Fields.push_back(std::make_tuple(new RecordPatternField(DotDot), nullptr)); break; } - if (T2->getKind() != NodeKind::Comma) { + auto P = parseWidePattern(); + auto T2 = Tokens.peek(); + if (T2->getKind() != NodeKind::RBrace) { DE.add(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma }); return {}; } - Tokens.get(); - auto Comma = static_cast(T2); - Fields.push_back(std::make_tuple(Field, Comma)); + Fields.push_back(std::make_tuple(new RecordPatternField(DotDot, P), nullptr)); + break; } - return Fields; + auto Name = expectToken(); + Equals* Equals = nullptr; + Pattern* Pattern = nullptr; + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::Equals) { + Tokens.get(); + Equals = static_cast(T1); + Pattern = parseWidePattern(); + } + auto Field = new RecordPatternField(Name, Equals, Pattern); + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::RBrace) { + Fields.push_back(std::make_tuple(Field, nullptr)); + break; + } + if (T2->getKind() != NodeKind::Comma) { + DE.add(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma }); + return {}; + } + Tokens.get(); + auto Comma = static_cast(T2); + Fields.push_back(std::make_tuple(Field, Comma)); } + return Fields; +} - Pattern* Parser::parsePrimitivePattern(bool IsNarrow) { - auto T0 = Tokens.peek(); - switch (T0->getKind()) { - case NodeKind::StringLiteral: - case NodeKind::IntegerLiteral: +Pattern* Parser::parsePrimitivePattern(bool IsNarrow) { + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::StringLiteral: + case NodeKind::IntegerLiteral: + Tokens.get(); + return new LiteralPattern(static_cast(T0)); + case NodeKind::Identifier: + Tokens.get(); + return new BindPattern(static_cast(T0)); + case NodeKind::LBrace: + { + Tokens.get(); + auto LBrace = static_cast(T0); + auto Fields = parseRecordPatternFields(); + if (!Fields) { + LBrace->unref(); + skipToRBrace(); + return nullptr; + } + auto RBrace = static_cast(Tokens.get()); + return new RecordPattern(LBrace, *Fields, RBrace); + } + case NodeKind::IdentifierAlt: + { + Tokens.get(); + auto Name = static_cast(T0); + if (IsNarrow) { + return new NamedTuplePattern(Name, {}); + } + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::LBrace) { + auto LBrace = static_cast(T1); Tokens.get(); - return new LiteralPattern(static_cast(T0)); - case NodeKind::Identifier: - Tokens.get(); - return new BindPattern(static_cast(T0)); - case NodeKind::LBrace: - { - Tokens.get(); - auto LBrace = static_cast(T0); auto Fields = parseRecordPatternFields(); if (!Fields) { LBrace->unref(); @@ -232,1224 +252,1192 @@ finish: return nullptr; } auto RBrace = static_cast(Tokens.get()); - return new RecordPattern(LBrace, *Fields, RBrace); + return new NamedRecordPattern({}, Name, LBrace, *Fields, RBrace); } - case NodeKind::IdentifierAlt: - { - Tokens.get(); - auto Name = static_cast(T0); - if (IsNarrow) { - return new NamedTuplePattern(Name, {}); + std::vector Patterns; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::RParen + || T2->getKind() == NodeKind::RBracket + || T2->getKind() == NodeKind::RBrace + || T2->getKind() == NodeKind::Comma + || T2->getKind() == NodeKind::Colon + || T2->getKind() == NodeKind::Equals + || T2->getKind() == NodeKind::BlockStart + || T2->getKind() == NodeKind::RArrowAlt) { + break; + } + auto P = parseNarrowPattern(); + if (!P) { + Name->unref(); + for (auto P: Patterns) { + P->unref(); + } + return nullptr; + } + Patterns.push_back(P); + } + return new NamedTuplePattern { Name, Patterns }; + } + case NodeKind::LBracket: + return parseListPattern(); + case NodeKind::LParen: + { + Tokens.get(); + auto LParen = static_cast(T0); + std::vector> Elements; + RParen* RParen; + for (;;) { + auto P = parseWidePattern(); + if (!P) { + LParen->unref(); + for (auto [P, Comma]: Elements) { + P->unref(); + Comma->unref(); + } + // TODO maybe skip to next comma? + return nullptr; } auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::LBrace) { - auto LBrace = static_cast(T1); + if (T1->getKind() == NodeKind::Comma) { Tokens.get(); - auto Fields = parseRecordPatternFields(); - if (!Fields) { - LBrace->unref(); - skipToRBrace(); - return nullptr; + Elements.push_back(std::make_tuple(P, static_cast(T1))); + } else if (T1->getKind() == NodeKind::RParen) { + Tokens.get(); + RParen = static_cast(T1); + Elements.push_back(std::make_tuple(P, nullptr)); + break; + } else { + DE.add(File, T1, std::vector { NodeKind::Comma, NodeKind::RParen }); + LParen->unref(); + for (auto [P, Comma]: Elements) { + P->unref(); + Comma->unref(); } - auto RBrace = static_cast(Tokens.get()); - return new NamedRecordPattern({}, Name, LBrace, *Fields, RBrace); - } - std::vector Patterns; - for (;;) { - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::RParen - || T2->getKind() == NodeKind::RBracket - || T2->getKind() == NodeKind::RBrace - || T2->getKind() == NodeKind::Comma - || T2->getKind() == NodeKind::Colon - || T2->getKind() == NodeKind::Equals - || T2->getKind() == NodeKind::BlockStart - || T2->getKind() == NodeKind::RArrowAlt) { - break; - } - auto P = parseNarrowPattern(); - if (!P) { - Name->unref(); - for (auto P: Patterns) { - P->unref(); - } - return nullptr; - } - Patterns.push_back(P); - } - return new NamedTuplePattern { Name, Patterns }; - } - case NodeKind::LBracket: - return parseListPattern(); - case NodeKind::LParen: - { - Tokens.get(); - auto LParen = static_cast(T0); - std::vector> Elements; - RParen* RParen; - for (;;) { - auto P = parseWidePattern(); - if (!P) { - LParen->unref(); - for (auto [P, Comma]: Elements) { - P->unref(); - Comma->unref(); - } - // TODO maybe skip to next comma? - return nullptr; - } - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::Comma) { - Tokens.get(); - Elements.push_back(std::make_tuple(P, static_cast(T1))); - } else if (T1->getKind() == NodeKind::RParen) { - Tokens.get(); - RParen = static_cast(T1); - Elements.push_back(std::make_tuple(P, nullptr)); - break; - } else { - DE.add(File, T1, std::vector { NodeKind::Comma, NodeKind::RParen }); - LParen->unref(); - for (auto [P, Comma]: Elements) { - P->unref(); - Comma->unref(); - } - // TODO maybe skip to next comma? - return nullptr; + // TODO maybe skip to next comma? + return nullptr; - } } - if (Elements.size() == 1) { - return new NestedPattern { LParen, std::get<0>(Elements.front()), RParen }; - } - return new TuplePattern(LParen, Elements, RParen); } + if (Elements.size() == 1) { + return new NestedPattern { LParen, std::get<0>(Elements.front()), RParen }; + } + return new TuplePattern(LParen, Elements, RParen); + } + default: + // Tokens.get(); + DE.add(File, T0, std::vector { + NodeKind::Identifier, + NodeKind::IdentifierAlt, + NodeKind::StringLiteral, + NodeKind::IntegerLiteral, + NodeKind::LParen, + NodeKind::LBracket + }); + return nullptr; + } +} + +Pattern* Parser::parseWidePattern() { + return parsePrimitivePattern(false); +} + +Pattern* Parser::parseNarrowPattern() { + return parsePrimitivePattern(true); +} + +TypeExpression* Parser::parseTypeExpression() { + return parseQualifiedTypeExpression(); +} + +TypeExpression* Parser::parseQualifiedTypeExpression() { + bool HasConstraints = false; + auto T0 = Tokens.peek(); + if (isa(T0)) { + std::size_t I = 1; + for (;;) { + auto T0 = Tokens.peek(I++); + switch (T0->getKind()) { + case NodeKind::RArrowAlt: + HasConstraints = true; + goto after_lookahead; + case NodeKind::Equals: + case NodeKind::BlockStart: + case NodeKind::LineFoldEnd: + case NodeKind::EndOfFile: + goto after_lookahead; + default: + break; + } + } + } +after_lookahead: + if (!HasConstraints) { + return parseArrowTypeExpression(); + } + Tokens.get(); + LParen* LParen = static_cast(T0); + std::vector> Constraints; + RParen* RParen; + RArrowAlt* RArrowAlt; + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::RParen) { + Tokens.get(); + RParen = static_cast(T1); + goto after_constraints; + } + for (;;) { + auto C = parseConstraintExpression(); + Comma* Comma = nullptr; + auto T2 = Tokens.get(); + switch (T2->getKind()) { + case NodeKind::Comma: + { + auto Comma = static_cast(T2); + if (C) { + Constraints.push_back(std::make_tuple(C, Comma)); + } else { + Comma->unref(); + } + continue; + } + case NodeKind::RParen: + RParen = static_cast(T2); + if (C) { + Constraints.push_back(std::make_tuple(C, nullptr)); + } + goto after_constraints; default: - // Tokens.get(); - DE.add(File, T0, std::vector { - NodeKind::Identifier, - NodeKind::IdentifierAlt, - NodeKind::StringLiteral, - NodeKind::IntegerLiteral, - NodeKind::LParen, - NodeKind::LBracket - }); + DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RArrowAlt }); return nullptr; } } - - Pattern* Parser::parseWidePattern() { - return parsePrimitivePattern(false); - } - - Pattern* Parser::parseNarrowPattern() { - return parsePrimitivePattern(true); - } - - TypeExpression* Parser::parseTypeExpression() { - return parseQualifiedTypeExpression(); - } - - TypeExpression* Parser::parseQualifiedTypeExpression() { - bool HasConstraints = false; - auto T0 = Tokens.peek(); - if (isa(T0)) { - std::size_t I = 1; - for (;;) { - auto T0 = Tokens.peek(I++); - switch (T0->getKind()) { - case NodeKind::RArrowAlt: - HasConstraints = true; - goto after_lookahead; - case NodeKind::Equals: - case NodeKind::BlockStart: - case NodeKind::LineFoldEnd: - case NodeKind::EndOfFile: - goto after_lookahead; - default: - break; - } - } - } -after_lookahead: - if (!HasConstraints) { - return parseArrowTypeExpression(); - } - Tokens.get(); - LParen* LParen = static_cast(T0); - std::vector> Constraints; - RParen* RParen; - RArrowAlt* RArrowAlt; - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::RParen) { - Tokens.get(); - RParen = static_cast(T1); - goto after_constraints; - } - for (;;) { - auto C = parseConstraintExpression(); - Comma* Comma = nullptr; - auto T2 = Tokens.get(); - switch (T2->getKind()) { - case NodeKind::Comma: - { - auto Comma = static_cast(T2); - if (C) { - Constraints.push_back(std::make_tuple(C, Comma)); - } else { - Comma->unref(); - } - continue; - } - case NodeKind::RParen: - RParen = static_cast(T2); - if (C) { - Constraints.push_back(std::make_tuple(C, nullptr)); - } - goto after_constraints; - default: - DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RArrowAlt }); - return nullptr; - } - } after_constraints: - RArrowAlt = expectToken(); - if (!RArrowAlt) { - LParen->unref(); - for (auto [CE, Comma]: Constraints) { - CE->unref(); - } - RParen->unref(); - return nullptr; + RArrowAlt = expectToken(); + if (!RArrowAlt) { + LParen->unref(); + for (auto [CE, Comma]: Constraints) { + CE->unref(); } - auto TE = parseArrowTypeExpression(); - if (!TE) { - LParen->unref(); - for (auto [CE, Comma]: Constraints) { - CE->unref(); - if (Comma) { - Comma->unref(); - } - } - RParen->unref(); - RArrowAlt->unref(); - return nullptr; - } - return new QualifiedTypeExpression(Constraints, RArrowAlt, TE); + RParen->unref(); + return nullptr; } + auto TE = parseArrowTypeExpression(); + if (!TE) { + LParen->unref(); + for (auto [CE, Comma]: Constraints) { + CE->unref(); + if (Comma) { + Comma->unref(); + } + } + RParen->unref(); + RArrowAlt->unref(); + return nullptr; + } + return new QualifiedTypeExpression(Constraints, RArrowAlt, TE); +} - TypeExpression* Parser::parsePrimitiveTypeExpression() { - auto T0 = Tokens.peek(); - switch (T0->getKind()) { - case NodeKind::Identifier: - return parseVarTypeExpression(); - case NodeKind::LBrace: - { - Tokens.get(); - auto LBrace = static_cast(T0); - std::vector> Fields; - VBar* VBar = nullptr; - TypeExpression* Rest = nullptr; - for (;;) { - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::RBrace) { - break; - } - auto Name = expectToken(); - if (Name == nullptr) { - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - return nullptr; - } - auto Colon = expectToken(); - if (Colon == nullptr) { - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - Name->unref(); - return nullptr; - } - auto TE = parseTypeExpression(); - if (TE == nullptr) { - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - Name->unref(); - Colon->unref(); - return nullptr; - } - auto Field = new RecordTypeExpressionField(Name, Colon, TE); - auto T3 = Tokens.peek(); - if (T3->getKind() == NodeKind::RBrace) { - Fields.push_back(std::make_tuple(Field, nullptr)); - break; - } - if (T3->getKind() == NodeKind::VBar) { - Tokens.get(); - Fields.push_back(std::make_tuple(Field, nullptr)); - VBar = static_cast(T3); - Rest = parseTypeExpression(); - if (!Rest) { - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - Field->unref(); - return nullptr; - } - auto T4 = Tokens.peek(); - if (T4->getKind() != NodeKind::RBrace) { - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - Field->unref(); - Rest->unref(); - DE.add(File, T4, std::vector { NodeKind::RBrace }); - return nullptr; - } - break; - } - if (T3->getKind() == NodeKind::Comma) { - Tokens.get(); - auto Comma = static_cast(T3); - Fields.push_back(std::make_tuple(Field, Comma)); - continue; - } - DE.add(File, T3, std::vector { NodeKind::RBrace, NodeKind::Comma, NodeKind::VBar }); +TypeExpression* Parser::parsePrimitiveTypeExpression() { + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::Identifier: + return parseVarTypeExpression(); + case NodeKind::LBrace: + { + Tokens.get(); + auto LBrace = static_cast(T0); + std::vector> Fields; + VBar* VBar = nullptr; + TypeExpression* Rest = nullptr; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::RBrace) { + break; + } + auto Name = expectToken(); + if (Name == nullptr) { for (auto [Field, Comma]: Fields) { Field->unref(); Comma->unref(); } - Field->unref(); return nullptr; } - auto RBrace = static_cast(Tokens.get()); - return new RecordTypeExpression(LBrace, Fields, VBar, Rest, RBrace); - } - case NodeKind::LParen: - { - Tokens.get(); - auto LParen = static_cast(T0); - std::vector> Elements; - RParen* RParen; - for (;;) { - auto T1 = Tokens.peek(); - if (isa(T1)) { - Tokens.get(); - RParen = static_cast(T1); - break; + auto Colon = expectToken(); + if (Colon == nullptr) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); } - auto TE = parseTypeExpression(); - if (!TE) { + Name->unref(); + return nullptr; + } + auto TE = parseTypeExpression(); + if (TE == nullptr) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + Colon->unref(); + return nullptr; + } + auto Field = new RecordTypeExpressionField(Name, Colon, TE); + auto T3 = Tokens.peek(); + if (T3->getKind() == NodeKind::RBrace) { + Fields.push_back(std::make_tuple(Field, nullptr)); + break; + } + if (T3->getKind() == NodeKind::VBar) { + Tokens.get(); + Fields.push_back(std::make_tuple(Field, nullptr)); + VBar = static_cast(T3); + Rest = parseTypeExpression(); + if (!Rest) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Field->unref(); + return nullptr; + } + auto T4 = Tokens.peek(); + if (T4->getKind() != NodeKind::RBrace) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Field->unref(); + Rest->unref(); + DE.add(File, T4, std::vector { NodeKind::RBrace }); + return nullptr; + } + break; + } + if (T3->getKind() == NodeKind::Comma) { + Tokens.get(); + auto Comma = static_cast(T3); + Fields.push_back(std::make_tuple(Field, Comma)); + continue; + } + DE.add(File, T3, std::vector { NodeKind::RBrace, NodeKind::Comma, NodeKind::VBar }); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Field->unref(); + return nullptr; + } + auto RBrace = static_cast(Tokens.get()); + return new RecordTypeExpression(LBrace, Fields, VBar, Rest, RBrace); + } + case NodeKind::LParen: + { + Tokens.get(); + auto LParen = static_cast(T0); + std::vector> Elements; + RParen* RParen; + for (;;) { + auto T1 = Tokens.peek(); + if (isa(T1)) { + Tokens.get(); + RParen = static_cast(T1); + break; + } + auto TE = parseTypeExpression(); + if (!TE) { + LParen->unref(); + for (auto [TE, Comma]: Elements) { + TE->unref(); + Comma->unref(); + } + return nullptr; + } + auto T2 = Tokens.get(); + switch (T2->getKind()) { + case NodeKind::RParen: + RParen = static_cast(T1); + Elements.push_back({ TE, nullptr }); + goto after_tuple_element; + case NodeKind::Comma: + Elements.push_back({ TE, static_cast(T2) }); + continue; + default: + DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RParen }); LParen->unref(); for (auto [TE, Comma]: Elements) { TE->unref(); Comma->unref(); } return nullptr; - } - auto T2 = Tokens.get(); - switch (T2->getKind()) { - case NodeKind::RParen: - RParen = static_cast(T1); - Elements.push_back({ TE, nullptr }); - goto after_tuple_element; - case NodeKind::Comma: - Elements.push_back({ TE, static_cast(T2) }); - continue; - default: - DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RParen }); - LParen->unref(); - for (auto [TE, Comma]: Elements) { - TE->unref(); - Comma->unref(); - } - return nullptr; - } } -after_tuple_element: - if (Elements.size() == 1) { - return new NestedTypeExpression { LParen, std::get<0>(Elements.front()), RParen }; - } - return new TupleTypeExpression { LParen, Elements, RParen }; } - case NodeKind::IdentifierAlt: - return parseReferenceTypeExpression(); - default: - // Tokens.get(); - DE.add(File, T0, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen }); - return nullptr; +after_tuple_element: + if (Elements.size() == 1) { + return new NestedTypeExpression { LParen, std::get<0>(Elements.front()), RParen }; + } + return new TupleTypeExpression { LParen, Elements, RParen }; } + case NodeKind::IdentifierAlt: + return parseReferenceTypeExpression(); + default: + // Tokens.get(); + DE.add(File, T0, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::LParen }); + return nullptr; } +} - ReferenceTypeExpression* Parser::parseReferenceTypeExpression() { - std::vector> ModulePath; - auto Name = expectToken(); +ReferenceTypeExpression* Parser::parseReferenceTypeExpression() { + std::vector> ModulePath; + auto Name = expectToken(); + if (!Name) { + return nullptr; + } + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() != NodeKind::Dot) { + break; + } + Tokens.get(); + ModulePath.push_back(std::make_tuple(static_cast(Name), static_cast(T1))); + Name = expectToken(); if (!Name) { + for (auto [Name, Dot]: ModulePath) { + Name->unref(); + Dot->unref(); + } return nullptr; } - for (;;) { - auto T1 = Tokens.peek(); - if (T1->getKind() != NodeKind::Dot) { - break; + } + return new ReferenceTypeExpression(ModulePath, static_cast(Name)); +} + +TypeExpression* Parser::parseAppTypeExpression() { + auto OpTy = parsePrimitiveTypeExpression(); + if (!OpTy) { + return nullptr; + } + std::vector ArgTys; + for (;;) { + auto T1 = Tokens.peek(); + auto Kind = T1->getKind(); + if (Kind == NodeKind::Comma + || Kind == NodeKind::RArrow + || Kind == NodeKind::Equals + || Kind == NodeKind::BlockStart + || Kind == NodeKind::LineFoldEnd + || Kind == NodeKind::EndOfFile + || Kind == NodeKind::RParen + || Kind == NodeKind::RBracket + || Kind == NodeKind::RBrace + || Kind == NodeKind::VBar) { + break; + } + auto TE = parsePrimitiveTypeExpression(); + if (!TE) { + OpTy->unref(); + for (auto Arg: ArgTys) { + Arg->unref(); } - Tokens.get(); - ModulePath.push_back(std::make_tuple(static_cast(Name), static_cast(T1))); - Name = expectToken(); + return nullptr; + } + ArgTys.push_back(TE); + } + if (ArgTys.empty()) { + return OpTy; + } + return new AppTypeExpression { OpTy, ArgTys }; +} + +TypeExpression* Parser::parseArrowTypeExpression() { + auto RetType = parseAppTypeExpression(); + if (RetType == nullptr) { + return nullptr; + } + std::vector ParamTypes; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() != NodeKind::RArrow) { + break; + } + Tokens.get(); + ParamTypes.push_back(RetType); + RetType = parseAppTypeExpression(); + if (!RetType) { + for (auto ParamType: ParamTypes) { + ParamType->unref(); + } + return nullptr; + } + } + if (!ParamTypes.empty()) { + return new ArrowTypeExpression(ParamTypes, RetType); + } + return RetType; +} + +MatchExpression* Parser::parseMatchExpression() { + auto T0 = expectToken(); + if (!T0) { + return nullptr; + } + auto T1 = Tokens.peek(); + Expression* Value; + BlockStart* BlockStart; + if (isa(T1)) { + Value = nullptr; + BlockStart = static_cast(T1); + Tokens.get(); + } else { + Value = parseExpression(); + if (!Value) { + T0->unref(); + return nullptr; + } + BlockStart = expectToken(); + if (!BlockStart) { + T0->unref(); + Value->unref(); + return nullptr; + } + } + std::vector Cases; + for (;;) { + auto T2 = Tokens.peek(); + if (isa(T2)) { + Tokens.get()->unref(); + break; + } + auto Pattern = parseWidePattern(); + if (!Pattern) { + skipPastLineFoldEnd(); + continue; + } + auto RArrowAlt = expectToken(); + if (!RArrowAlt) { + Pattern->unref(); + skipPastLineFoldEnd(); + continue; + } + auto Expression = parseExpression(); + if (!Expression) { + Pattern->unref(); + RArrowAlt->unref(); + skipPastLineFoldEnd(); + continue; + } + checkLineFoldEnd(); + Cases.push_back(new MatchCase { Pattern, RArrowAlt, Expression }); + } + return new MatchExpression(static_cast(T0), Value, BlockStart, Cases); +} + +RecordExpression* Parser::parseRecordExpression() { + auto LBrace = expectToken(); + if (!LBrace) { + return nullptr; + } + RBrace* RBrace; + auto T1 = Tokens.peek(); + std::vector> Fields; + if (T1->getKind() == NodeKind::RBrace) { + Tokens.get(); + RBrace = static_cast(T1); + } else { + for (;;) { + auto Name = expectToken(); if (!Name) { + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + return nullptr; + } + auto Equals = expectToken(); + if (!Equals) { + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + return nullptr; + } + auto E = parseExpression(); + if (!E) { + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + Equals->unref(); + return nullptr; + } + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::Comma) { + Tokens.get(); + Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, static_cast(T2))); + } else if (T2->getKind() == NodeKind::RBrace) { + Tokens.get(); + RBrace = static_cast(T2); + Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, nullptr)); + break; + } else { + DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RBrace }); + LBrace->unref(); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + Equals->unref(); + E->unref(); + return nullptr; + } + } + } + return new RecordExpression { LBrace, Fields, RBrace }; +} + +Expression* Parser::parsePrimitiveExpression() { + auto Annotations = parseAnnotations(); + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::Identifier: + case NodeKind::IdentifierAlt: + { + std::vector> ModulePath; + for (;;) { + auto T1 = Tokens.peek(0); + auto T2 = Tokens.peek(1); + if (!isa(T1) || !isa(T2)) { + break; + } + Tokens.get(); + Tokens.get(); + ModulePath.push_back(std::make_tuple(static_cast(T1), static_cast(T2))); + } + auto T3 = Tokens.get(); + if (!T3->is() && !T3->is()) { for (auto [Name, Dot]: ModulePath) { Name->unref(); Dot->unref(); } + DE.add(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt }); return nullptr; } + return new ReferenceExpression(Annotations, ModulePath, static_cast(T3)); } - return new ReferenceTypeExpression(ModulePath, static_cast(Name)); - } - - TypeExpression* Parser::parseAppTypeExpression() { - auto OpTy = parsePrimitiveTypeExpression(); - if (!OpTy) { - return nullptr; - } - std::vector ArgTys; - for (;;) { + case NodeKind::LParen: + { + Tokens.get(); + std::vector> Elements; + auto LParen = static_cast(T0); + RParen* RParen; auto T1 = Tokens.peek(); - auto Kind = T1->getKind(); - if (Kind == NodeKind::Comma - || Kind == NodeKind::RArrow - || Kind == NodeKind::Equals - || Kind == NodeKind::BlockStart - || Kind == NodeKind::LineFoldEnd - || Kind == NodeKind::EndOfFile - || Kind == NodeKind::RParen - || Kind == NodeKind::RBracket - || Kind == NodeKind::RBrace - || Kind == NodeKind::VBar) { - break; + if (isa(T1)) { + Tokens.get(); + RParen = static_cast(T1); + goto after_tuple_elements; } - auto TE = parsePrimitiveTypeExpression(); - if (!TE) { - OpTy->unref(); - for (auto Arg: ArgTys) { - Arg->unref(); - } - return nullptr; - } - ArgTys.push_back(TE); - } - if (ArgTys.empty()) { - return OpTy; - } - return new AppTypeExpression { OpTy, ArgTys }; - } - - TypeExpression* Parser::parseArrowTypeExpression() { - auto RetType = parseAppTypeExpression(); - if (RetType == nullptr) { - return nullptr; - } - std::vector ParamTypes; - for (;;) { - auto T1 = Tokens.peek(); - if (T1->getKind() != NodeKind::RArrow) { - break; - } - Tokens.get(); - ParamTypes.push_back(RetType); - RetType = parseAppTypeExpression(); - if (!RetType) { - for (auto ParamType: ParamTypes) { - ParamType->unref(); - } - return nullptr; - } - } - if (!ParamTypes.empty()) { - return new ArrowTypeExpression(ParamTypes, RetType); - } - return RetType; - } - - MatchExpression* Parser::parseMatchExpression() { - auto T0 = expectToken(); - if (!T0) { - return nullptr; - } - auto T1 = Tokens.peek(); - Expression* Value; - BlockStart* BlockStart; - if (isa(T1)) { - Value = nullptr; - BlockStart = static_cast(T1); - Tokens.get(); - } else { - Value = parseExpression(); - if (!Value) { - T0->unref(); - return nullptr; - } - BlockStart = expectToken(); - if (!BlockStart) { - T0->unref(); - Value->unref(); - return nullptr; - } - } - std::vector Cases; - for (;;) { - auto T2 = Tokens.peek(); - if (isa(T2)) { - Tokens.get()->unref(); - break; - } - auto Pattern = parseWidePattern(); - if (!Pattern) { - skipPastLineFoldEnd(); - continue; - } - auto RArrowAlt = expectToken(); - if (!RArrowAlt) { - Pattern->unref(); - skipPastLineFoldEnd(); - continue; - } - auto Expression = parseExpression(); - if (!Expression) { - Pattern->unref(); - RArrowAlt->unref(); - skipPastLineFoldEnd(); - continue; - } - checkLineFoldEnd(); - Cases.push_back(new MatchCase { Pattern, RArrowAlt, Expression }); - } - return new MatchExpression(static_cast(T0), Value, BlockStart, Cases); - } - - RecordExpression* Parser::parseRecordExpression() { - auto LBrace = expectToken(); - if (!LBrace) { - return nullptr; - } - RBrace* RBrace; - auto T1 = Tokens.peek(); - std::vector> Fields; - if (T1->getKind() == NodeKind::RBrace) { - Tokens.get(); - RBrace = static_cast(T1); - } else { for (;;) { - auto Name = expectToken(); - if (!Name) { - LBrace->unref(); - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - return nullptr; - } - auto Equals = expectToken(); - if (!Equals) { - LBrace->unref(); - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - Name->unref(); - return nullptr; - } + auto T1 = Tokens.peek(); auto E = parseExpression(); if (!E) { - LBrace->unref(); - for (auto [Field, Comma]: Fields) { - Field->unref(); + LParen->unref(); + for (auto [E, Comma]: Elements) { + E->unref(); Comma->unref(); } - Name->unref(); - Equals->unref(); return nullptr; } - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::Comma) { - Tokens.get(); - Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, static_cast(T2))); - } else if (T2->getKind() == NodeKind::RBrace) { - Tokens.get(); - RBrace = static_cast(T2); - Fields.push_back(std::make_tuple(new RecordExpressionField { Name, Equals, E }, nullptr)); - break; - } else { - DE.add(File, T2, std::vector { NodeKind::Comma, NodeKind::RBrace }); - LBrace->unref(); - for (auto [Field, Comma]: Fields) { - Field->unref(); - Comma->unref(); - } - Name->unref(); - Equals->unref(); - E->unref(); - return nullptr; - } - } - } - return new RecordExpression { LBrace, Fields, RBrace }; - } - - Expression* Parser::parsePrimitiveExpression() { - auto Annotations = parseAnnotations(); - auto T0 = Tokens.peek(); - switch (T0->getKind()) { - case NodeKind::Identifier: - case NodeKind::IdentifierAlt: - { - std::vector> ModulePath; - for (;;) { - auto T1 = Tokens.peek(0); - auto T2 = Tokens.peek(1); - if (!isa(T1) || !isa(T2)) { + auto T2 = Tokens.get(); + switch (T2->getKind()) { + case NodeKind::RParen: + RParen = static_cast(T2); + Elements.push_back({ E, nullptr }); + goto after_tuple_elements; + case NodeKind::Comma: + Elements.push_back({ E, static_cast(T2) }); break; - } - Tokens.get(); - Tokens.get(); - ModulePath.push_back(std::make_tuple(static_cast(T1), static_cast(T2))); - } - auto T3 = Tokens.get(); - if (!T3->is() && !T3->is()) { - for (auto [Name, Dot]: ModulePath) { - Name->unref(); - Dot->unref(); - } - DE.add(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt }); - return nullptr; - } - return new ReferenceExpression(Annotations, ModulePath, static_cast(T3)); - } - case NodeKind::LParen: - { - Tokens.get(); - std::vector> Elements; - auto LParen = static_cast(T0); - RParen* RParen; - auto T1 = Tokens.peek(); - if (isa(T1)) { - Tokens.get(); - RParen = static_cast(T1); - goto after_tuple_elements; - } - for (;;) { - auto T1 = Tokens.peek(); - auto E = parseExpression(); - if (!E) { + default: + DE.add(File, T2, std::vector { NodeKind::RParen, NodeKind::Comma }); LParen->unref(); for (auto [E, Comma]: Elements) { E->unref(); Comma->unref(); } return nullptr; - } - auto T2 = Tokens.get(); - switch (T2->getKind()) { - case NodeKind::RParen: - RParen = static_cast(T2); - Elements.push_back({ E, nullptr }); - goto after_tuple_elements; - case NodeKind::Comma: - Elements.push_back({ E, static_cast(T2) }); - break; - default: - DE.add(File, T2, std::vector { NodeKind::RParen, NodeKind::Comma }); - LParen->unref(); - for (auto [E, Comma]: Elements) { - E->unref(); - Comma->unref(); - } - return nullptr; - case NodeKind::LineFoldEnd: - case NodeKind::BlockStart: - case NodeKind::EndOfFile: - // Can recover from this one - RParen = nullptr; - DE.add(File, T2, std::vector { NodeKind::RParen, NodeKind::Comma }); - goto after_tuple_elements; - } + case NodeKind::LineFoldEnd: + case NodeKind::BlockStart: + case NodeKind::EndOfFile: + // Can recover from this one + RParen = nullptr; + DE.add(File, T2, std::vector { NodeKind::RParen, NodeKind::Comma }); + goto after_tuple_elements; } -after_tuple_elements: - if (Elements.size() == 1 && !std::get<1>(Elements.front())) { - return new NestedExpression(Annotations, LParen, std::get<0>(Elements.front()), RParen); - } - return new TupleExpression { Annotations, LParen, Elements, RParen }; } - case NodeKind::MatchKeyword: - return parseMatchExpression(); - case NodeKind::IntegerLiteral: - case NodeKind::StringLiteral: - Tokens.get(); - return new LiteralExpression(Annotations, static_cast(T0)); - case NodeKind::LBrace: - return parseRecordExpression(); - default: - // Tokens.get(); - DE.add(File, T0, std::vector { - NodeKind::MatchKeyword, - NodeKind::Identifier, - NodeKind::IdentifierAlt, - NodeKind::LParen, - NodeKind::LBrace, - NodeKind::IntegerLiteral, - NodeKind::StringLiteral - }); - return nullptr; +after_tuple_elements: + if (Elements.size() == 1 && !std::get<1>(Elements.front())) { + return new NestedExpression(Annotations, LParen, std::get<0>(Elements.front()), RParen); + } + return new TupleExpression { Annotations, LParen, Elements, RParen }; } - } - - Expression* Parser::parseMemberExpression() { - auto E = parsePrimitiveExpression(); - if (!E) { + case NodeKind::MatchKeyword: + return parseMatchExpression(); + case NodeKind::IntegerLiteral: + case NodeKind::StringLiteral: + Tokens.get(); + return new LiteralExpression(Annotations, static_cast(T0)); + case NodeKind::LBrace: + return parseRecordExpression(); + default: + // Tokens.get(); + DE.add(File, T0, std::vector { + NodeKind::MatchKeyword, + NodeKind::Identifier, + NodeKind::IdentifierAlt, + NodeKind::LParen, + NodeKind::LBrace, + NodeKind::IntegerLiteral, + NodeKind::StringLiteral + }); return nullptr; + } +} + +Expression* Parser::parseMemberExpression() { + auto E = parsePrimitiveExpression(); + if (!E) { + return nullptr; + } + for (;;) { + auto T1 = Tokens.peek(0); + auto T2 = Tokens.peek(1); + if (!isa(T1)) { + break; } - for (;;) { - auto T1 = Tokens.peek(0); - auto T2 = Tokens.peek(1); - if (!isa(T1)) { + switch (T2->getKind()) { + case NodeKind::IntegerLiteral: + case NodeKind::Identifier: + { + Tokens.get(); + Tokens.get(); + auto Annotations = E->Annotations; + E->Annotations = {}; + E = new MemberExpression { Annotations, E, static_cast(T1), T2 }; break; } - switch (T2->getKind()) { - case NodeKind::IntegerLiteral: - case NodeKind::Identifier: - { - Tokens.get(); - Tokens.get(); - auto Annotations = E->Annotations; - E->Annotations = {}; - E = new MemberExpression { Annotations, E, static_cast(T1), T2 }; - break; - } - default: - goto finish; - } + default: + goto finish; } -finish: - return E; } +finish: + return E; +} - Expression* Parser::parseCallExpression() { - auto Operator = parseMemberExpression(); - if (!Operator) { +Expression* Parser::parseCallExpression() { + auto Operator = parseMemberExpression(); + if (!Operator) { + return nullptr; + } + std::vector Args; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::LineFoldEnd + || T1->getKind() == NodeKind::RParen + || T1->getKind() == NodeKind::RBracket + || T1->getKind() == NodeKind::RBrace + || T1->getKind() == NodeKind::BlockStart + || T1->getKind() == NodeKind::Comma + || ExprOperators.isInfix(T1)) { + break; + } + auto Arg = parseMemberExpression(); + if (!Arg) { + Operator->unref(); + for (auto Arg: Args) { + Arg->unref(); + } + return nullptr; + } + Args.push_back(Arg); + } + if (Args.empty()) { + return Operator; + } + auto Annotations = Operator->Annotations; + Operator->Annotations = {}; + return new CallExpression(Annotations, Operator, Args); +} + +Expression* Parser::parseUnaryExpression() { + std::vector Prefix; + for (;;) { + auto T0 = Tokens.peek(); + if (!ExprOperators.isPrefix(T0)) { + break; + } + Tokens.get(); + Prefix.push_back(T0); + } + auto E = parseCallExpression(); + if (!E) { + for (auto Tok: Prefix) { + Tok->unref(); + } + return nullptr; + } + for (auto Iter = Prefix.rbegin(); Iter != Prefix.rend(); Iter++) { + E = new PrefixExpression(*Iter, E); + } + return E; +} + +Expression* Parser::parseInfixOperatorAfterExpression(Expression* Left, int MinPrecedence) { + for (;;) { + auto T0 = Tokens.peek(); + auto Info0 = ExprOperators.getInfix(T0); + if (!Info0 || Info0->Precedence < MinPrecedence) { + break; + } + Tokens.get(); + auto Right = parseUnaryExpression(); + if (!Right) { + Left->unref(); + T0->unref(); return nullptr; } - std::vector Args; for (;;) { auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::LineFoldEnd - || T1->getKind() == NodeKind::RParen - || T1->getKind() == NodeKind::RBracket - || T1->getKind() == NodeKind::RBrace - || T1->getKind() == NodeKind::BlockStart - || T1->getKind() == NodeKind::Comma - || ExprOperators.isInfix(T1)) { + auto Info1 = ExprOperators.getInfix(T1); + if (!Info1 || Info1->Precedence < Info0->Precedence && (Info1->Precedence > Info0->Precedence || Info1->isRightAssoc())) { break; } - auto Arg = parseMemberExpression(); - if (!Arg) { - Operator->unref(); - for (auto Arg: Args) { - Arg->unref(); - } - return nullptr; - } - Args.push_back(Arg); - } - if (Args.empty()) { - return Operator; - } - auto Annotations = Operator->Annotations; - Operator->Annotations = {}; - return new CallExpression(Annotations, Operator, Args); - } - - Expression* Parser::parseUnaryExpression() { - std::vector Prefix; - for (;;) { - auto T0 = Tokens.peek(); - if (!ExprOperators.isPrefix(T0)) { - break; - } - Tokens.get(); - Prefix.push_back(T0); - } - auto E = parseCallExpression(); - if (!E) { - for (auto Tok: Prefix) { - Tok->unref(); - } - return nullptr; - } - for (auto Iter = Prefix.rbegin(); Iter != Prefix.rend(); Iter++) { - E = new PrefixExpression(*Iter, E); - } - return E; - } - - Expression* Parser::parseInfixOperatorAfterExpression(Expression* Left, int MinPrecedence) { - for (;;) { - auto T0 = Tokens.peek(); - auto Info0 = ExprOperators.getInfix(T0); - if (!Info0 || Info0->Precedence < MinPrecedence) { - break; - } - Tokens.get(); - auto Right = parseUnaryExpression(); - if (!Right) { + auto NewRight = parseInfixOperatorAfterExpression(Right, Info1->Precedence); + if (!NewRight) { Left->unref(); T0->unref(); + Right->unref(); return nullptr; } - for (;;) { - auto T1 = Tokens.peek(); - auto Info1 = ExprOperators.getInfix(T1); - if (!Info1 || Info1->Precedence < Info0->Precedence && (Info1->Precedence > Info0->Precedence || Info1->isRightAssoc())) { - break; - } - auto NewRight = parseInfixOperatorAfterExpression(Right, Info1->Precedence); - if (!NewRight) { - Left->unref(); - T0->unref(); - Right->unref(); - return nullptr; - } - Right = NewRight; - } - Left = new InfixExpression(Left, T0, Right); + Right = NewRight; } - return Left; + Left = new InfixExpression(Left, T0, Right); } + return Left; +} - Expression* Parser::parseExpression() { - auto Left = parseUnaryExpression(); - if (!Left) { - return nullptr; - } - return parseInfixOperatorAfterExpression(Left, 0); +Expression* Parser::parseExpression() { + auto Left = parseUnaryExpression(); + if (!Left) { + return nullptr; } + return parseInfixOperatorAfterExpression(Left, 0); +} - ExpressionStatement* Parser::parseExpressionStatement() { - auto E = parseExpression(); - if (!E) { +ExpressionStatement* Parser::parseExpressionStatement() { + auto E = parseExpression(); + if (!E) { + skipPastLineFoldEnd(); + return nullptr; + } + checkLineFoldEnd(); + return new ExpressionStatement(E); +} + +ReturnStatement* Parser::parseReturnStatement() { + auto Annotations = parseAnnotations(); + auto ReturnKeyword = expectToken(); + if (!ReturnKeyword) { + return nullptr; + } + Expression* Expression; + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::LineFoldEnd) { + Tokens.get()->unref(); + Expression = nullptr; + } else { + Expression = parseExpression(); + if (!Expression) { + ReturnKeyword->unref(); skipPastLineFoldEnd(); return nullptr; } checkLineFoldEnd(); - return new ExpressionStatement(E); } + return new ReturnStatement(Annotations, ReturnKeyword, Expression); +} - ReturnStatement* Parser::parseReturnStatement() { - auto Annotations = parseAnnotations(); - auto ReturnKeyword = expectToken(); - if (!ReturnKeyword) { - return nullptr; - } - Expression* Expression; - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::LineFoldEnd) { +IfStatement* Parser::parseIfStatement() { + std::vector Parts; + auto Annotations = parseAnnotations(); + auto IfKeyword = expectToken(); + auto Test = parseExpression(); + if (!Test) { + IfKeyword->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + auto T1 = expectToken(); + if (!T1) { + IfKeyword->unref(); + Test->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + std::vector Then; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::BlockEnd) { Tokens.get()->unref(); - Expression = nullptr; - } else { - Expression = parseExpression(); - if (!Expression) { - ReturnKeyword->unref(); - skipPastLineFoldEnd(); - return nullptr; - } - checkLineFoldEnd(); + break; + } + auto Element = parseLetBodyElement(); + if (Element) { + Then.push_back(Element); } - return new ReturnStatement(Annotations, ReturnKeyword, Expression); } - - IfStatement* Parser::parseIfStatement() { - std::vector Parts; + Tokens.get()->unref(); // Always a LineFoldEnd + Parts.push_back(new IfStatementPart(Annotations, IfKeyword, Test, T1, Then)); + for (;;) { + auto T3 = peekFirstTokenAfterAnnotationsAndModifiers(); + if (T3->getKind() != NodeKind::ElseKeyword && T3->getKind() != NodeKind::ElifKeyword) { + break; + } auto Annotations = parseAnnotations(); - auto IfKeyword = expectToken(); - auto Test = parseExpression(); - if (!Test) { - IfKeyword->unref(); - skipPastLineFoldEnd(); + Tokens.get(); + Expression* Test = nullptr; + if (T3->getKind() == NodeKind::ElifKeyword) { + Test = parseExpression(); + } + auto T4 = expectToken(); + if (!T4) { + for (auto Part: Parts) { + Part->unref(); + } return nullptr; } - auto T1 = expectToken(); - if (!T1) { - IfKeyword->unref(); - Test->unref(); - skipPastLineFoldEnd(); - return nullptr; - } - std::vector Then; + std::vector Alt; for (;;) { - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::BlockEnd) { + auto T5 = Tokens.peek(); + if (T5->getKind() == NodeKind::BlockEnd) { Tokens.get()->unref(); break; } auto Element = parseLetBodyElement(); if (Element) { - Then.push_back(Element); + Alt.push_back(Element); } } Tokens.get()->unref(); // Always a LineFoldEnd - Parts.push_back(new IfStatementPart(Annotations, IfKeyword, Test, T1, Then)); - for (;;) { - auto T3 = peekFirstTokenAfterAnnotationsAndModifiers(); - if (T3->getKind() != NodeKind::ElseKeyword && T3->getKind() != NodeKind::ElifKeyword) { - break; - } - auto Annotations = parseAnnotations(); - Tokens.get(); - Expression* Test = nullptr; - if (T3->getKind() == NodeKind::ElifKeyword) { - Test = parseExpression(); - } - auto T4 = expectToken(); - if (!T4) { - for (auto Part: Parts) { - Part->unref(); - } - return nullptr; - } - std::vector Alt; - for (;;) { - auto T5 = Tokens.peek(); - if (T5->getKind() == NodeKind::BlockEnd) { - Tokens.get()->unref(); - break; - } - auto Element = parseLetBodyElement(); - if (Element) { - Alt.push_back(Element); - } - } - Tokens.get()->unref(); // Always a LineFoldEnd - Parts.push_back(new IfStatementPart(Annotations, T3, Test, T4, Alt)); - if (T3->getKind() == NodeKind::ElseKeyword) { - break; - } + Parts.push_back(new IfStatementPart(Annotations, T3, Test, T4, Alt)); + if (T3->getKind() == NodeKind::ElseKeyword) { + break; } - return new IfStatement(Parts); + } + return new IfStatement(Parts); +} + +LetDeclaration* Parser::parseLetDeclaration() { + + auto Annotations = parseAnnotations(); + PubKeyword* Pub = nullptr; + ForeignKeyword* Foreign = nullptr; + LetKeyword* Let; + MutKeyword* Mut = nullptr; + Pattern* Name; + std::vector Params; + TypeAssert* TA = nullptr; + LetBody* Body = nullptr; + + auto T0 = Tokens.get(); + if (T0->getKind() == NodeKind::PubKeyword) { + Pub = static_cast(T0); + T0 = Tokens.get(); + } + if (T0->getKind() == NodeKind::ForeignKeyword) { + Foreign = static_cast(T0); + T0 = Tokens.get(); + } + if (T0->getKind() != NodeKind::LetKeyword) { + DE.add(File, T0, std::vector { NodeKind::LetKeyword }); + if (Pub) { + Pub->unref(); + } + if (Foreign) { + Foreign->unref(); + } + skipPastLineFoldEnd(); + return nullptr; + } + Let = static_cast(T0); + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::MutKeyword) { + Tokens.get(); + Mut = static_cast(T1); } - LetDeclaration* Parser::parseLetDeclaration() { - - auto Annotations = parseAnnotations(); - PubKeyword* Pub = nullptr; - ForeignKeyword* Foreign = nullptr; - LetKeyword* Let; - MutKeyword* Mut = nullptr; - Pattern* Name; - std::vector Params; - TypeAssert* TA = nullptr; - LetBody* Body = nullptr; - - auto T0 = Tokens.get(); - if (T0->getKind() == NodeKind::PubKeyword) { - Pub = static_cast(T0); - T0 = Tokens.get(); - } - if (T0->getKind() == NodeKind::ForeignKeyword) { - Foreign = static_cast(T0); - T0 = Tokens.get(); - } - if (T0->getKind() != NodeKind::LetKeyword) { - DE.add(File, T0, std::vector { NodeKind::LetKeyword }); + auto T2 = Tokens.peek(0); + auto T3 = Tokens.peek(1); + auto T4 = Tokens.peek(2); + if (isOperator(T2)) { + Tokens.get(); + auto P1 = parseNarrowPattern(); + Params.push_back(new Parameter(P1, nullptr)); + Name = new BindPattern(T2); + goto after_params; + } else if (isOperator(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) { + auto P1 = parseNarrowPattern(); + Params.push_back(new Parameter(P1, nullptr)); + Tokens.get(); + Name = new BindPattern(T3); + goto after_params; + } else if (T2->getKind() == NodeKind::LParen && isOperator(T3) && T4->getKind() == NodeKind::RParen) { + Tokens.get(); + Tokens.get(); + Tokens.get(); + Name = new BindPattern( + new WrappedOperator( + static_cast(T2), + T3, + static_cast(T3) + ) + ); + } else if (isOperator(T3)) { + auto P1 = parseNarrowPattern(); + Params.push_back(new Parameter(P1, nullptr)); + Tokens.get(); + auto P2 = parseNarrowPattern(); + Params.push_back(new Parameter(P2, nullptr)); + Name = new BindPattern(T3); + goto after_params; + } else { + Name = parseNarrowPattern(); + if (!Name) { if (Pub) { Pub->unref(); } if (Foreign) { Foreign->unref(); } + Let->unref(); + if (Mut) { + Mut->unref(); + } skipPastLineFoldEnd(); return nullptr; } - Let = static_cast(T0); - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::MutKeyword) { - Tokens.get(); - Mut = static_cast(T1); - } + } - auto T2 = Tokens.peek(0); - auto T3 = Tokens.peek(1); - auto T4 = Tokens.peek(2); - if (isOperator(T2)) { - Tokens.get(); - auto P1 = parseNarrowPattern(); - Params.push_back(new Parameter(P1, nullptr)); - Name = new BindPattern(T2); - goto after_params; - } else if (isOperator(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) { - auto P1 = parseNarrowPattern(); - Params.push_back(new Parameter(P1, nullptr)); - Tokens.get(); - Name = new BindPattern(T3); - goto after_params; - } else if (T2->getKind() == NodeKind::LParen && isOperator(T3) && T4->getKind() == NodeKind::RParen) { - Tokens.get(); - Tokens.get(); - Tokens.get(); - Name = new BindPattern( - new WrappedOperator( - static_cast(T2), - T3, - static_cast(T3) - ) - ); - } else if (isOperator(T3)) { - auto P1 = parseNarrowPattern(); - Params.push_back(new Parameter(P1, nullptr)); - Tokens.get(); - auto P2 = parseNarrowPattern(); - Params.push_back(new Parameter(P2, nullptr)); - Name = new BindPattern(T3); - goto after_params; - } else { - Name = parseNarrowPattern(); - if (!Name) { - if (Pub) { - Pub->unref(); + for (;;) { + auto T5 = Tokens.peek(); + switch (T5->getKind()) { + case NodeKind::LineFoldEnd: + case NodeKind::BlockStart: + case NodeKind::Equals: + case NodeKind::Colon: + goto after_params; + default: + auto P = parseNarrowPattern(); + if (!P) { + Tokens.get(); + P = new BindPattern(new Identifier("_")); } - if (Foreign) { - Foreign->unref(); - } - Let->unref(); - if (Mut) { - Mut->unref(); - } - skipPastLineFoldEnd(); - return nullptr; - } - } - - for (;;) { - auto T5 = Tokens.peek(); - switch (T5->getKind()) { - case NodeKind::LineFoldEnd: - case NodeKind::BlockStart: - case NodeKind::Equals: - case NodeKind::Colon: - goto after_params; - default: - auto P = parseNarrowPattern(); - if (!P) { - Tokens.get(); - P = new BindPattern(new Identifier("_")); - } - Params.push_back(new Parameter(P, nullptr)); - } + Params.push_back(new Parameter(P, nullptr)); } + } after_params: - auto T5 = Tokens.peek(); + auto T5 = Tokens.peek(); - if (T5->getKind() == NodeKind::Colon) { + if (T5->getKind() == NodeKind::Colon) { + Tokens.get(); + auto TE = parseTypeExpression(); + if (TE) { + TA = new TypeAssert(static_cast(T5), TE); + } else { + skipPastLineFoldEnd(); + goto finish; + } + T5 = Tokens.peek(); + } + + switch (T5->getKind()) { + case NodeKind::BlockStart: + { Tokens.get(); - auto TE = parseTypeExpression(); - if (TE) { - TA = new TypeAssert(static_cast(T5), TE); - } else { + std::vector Elements; + for (;;) { + auto T6 = Tokens.peek(); + if (T6->getKind() == NodeKind::BlockEnd) { + break; + } + auto Element = parseLetBodyElement(); + if (Element) { + Elements.push_back(Element); + } + } + Tokens.get()->unref(); // Always a BlockEnd + Body = new LetBlockBody(static_cast(T5), Elements); + break; + } + case NodeKind::Equals: + { + Tokens.get(); + auto E = parseExpression(); + if (!E) { skipPastLineFoldEnd(); goto finish; } - T5 = Tokens.peek(); + Body = new LetExprBody(static_cast(T5), E); + break; } - - switch (T5->getKind()) { - case NodeKind::BlockStart: - { - Tokens.get(); - std::vector Elements; - for (;;) { - auto T6 = Tokens.peek(); - if (T6->getKind() == NodeKind::BlockEnd) { - break; - } - auto Element = parseLetBodyElement(); - if (Element) { - Elements.push_back(Element); - } - } - Tokens.get()->unref(); // Always a BlockEnd - Body = new LetBlockBody(static_cast(T5), Elements); - break; + case NodeKind::LineFoldEnd: + break; + default: + std::vector Expected { NodeKind::BlockStart, NodeKind::LineFoldEnd, NodeKind::Equals }; + if (TA == nullptr) { + // First tokens of TypeAssert + Expected.push_back(NodeKind::Colon); + // First tokens of Pattern + Expected.push_back(NodeKind::Identifier); } - case NodeKind::Equals: - { - Tokens.get(); - auto E = parseExpression(); - if (!E) { - skipPastLineFoldEnd(); - goto finish; - } - Body = new LetExprBody(static_cast(T5), E); - break; - } - case NodeKind::LineFoldEnd: - break; - default: - std::vector Expected { NodeKind::BlockStart, NodeKind::LineFoldEnd, NodeKind::Equals }; - if (TA == nullptr) { - // First tokens of TypeAssert - Expected.push_back(NodeKind::Colon); - // First tokens of Pattern - Expected.push_back(NodeKind::Identifier); - } - DE.add(File, T5, Expected); - } + DE.add(File, T5, Expected); + } - checkLineFoldEnd(); + checkLineFoldEnd(); finish: - return new LetDeclaration( - Annotations, - Pub, - Foreign, - Let, - Mut, - Name, - Params, - TA, - Body - ); - } + return new LetDeclaration( + Annotations, + Pub, + Foreign, + Let, + Mut, + Name, + Params, + TA, + Body + ); +} - Node* Parser::parseLetBodyElement() { - auto T0 = peekFirstTokenAfterAnnotationsAndModifiers(); - switch (T0->getKind()) { - case NodeKind::LetKeyword: - return parseLetDeclaration(); - case NodeKind::ReturnKeyword: - return parseReturnStatement(); - case NodeKind::IfKeyword: - return parseIfStatement(); +Node* Parser::parseLetBodyElement() { + auto T0 = peekFirstTokenAfterAnnotationsAndModifiers(); + switch (T0->getKind()) { + case NodeKind::LetKeyword: + return parseLetDeclaration(); + case NodeKind::ReturnKeyword: + return parseReturnStatement(); + case NodeKind::IfKeyword: + return parseIfStatement(); + default: + return parseExpressionStatement(); + } +} + +ConstraintExpression* Parser::parseConstraintExpression() { + bool HasTilde = false; + for (std::size_t I = 0; ; I++) { + auto Tok = Tokens.peek(I); + switch (Tok->getKind()) { + case NodeKind::Tilde: + HasTilde = true; + goto after_lookahead; + case NodeKind::RParen: + case NodeKind::Comma: + case NodeKind::RArrowAlt: + case NodeKind::EndOfFile: + goto after_lookahead; default: - return parseExpressionStatement(); + continue; } } - - ConstraintExpression* Parser::parseConstraintExpression() { - bool HasTilde = false; - for (std::size_t I = 0; ; I++) { - auto Tok = Tokens.peek(I); - switch (Tok->getKind()) { - case NodeKind::Tilde: - HasTilde = true; - goto after_lookahead; - case NodeKind::RParen: - case NodeKind::Comma: - case NodeKind::RArrowAlt: - case NodeKind::EndOfFile: - goto after_lookahead; - default: - continue; - } - } after_lookahead: - if (HasTilde) { - auto Left = parseArrowTypeExpression(); - if (!Left) { - return nullptr; - } - auto Tilde = expectToken(); - if (!Tilde) { - Left->unref(); - return nullptr; - } - auto Right = parseArrowTypeExpression(); - if (!Right) { - Left->unref(); - Tilde->unref(); - return nullptr; - } - return new EqualityConstraintExpression { Left, Tilde, Right }; - } - auto Name = expectToken(); - if (!Name) { + if (HasTilde) { + auto Left = parseArrowTypeExpression(); + if (!Left) { return nullptr; } - std::vector TEs; - for (;;) { - auto T1 = Tokens.peek(); - switch (T1->getKind()) { - case NodeKind::RParen: - case NodeKind::RArrowAlt: - case NodeKind::Comma: - goto after_vars; - case NodeKind::Identifier: - Tokens.get(); - TEs.push_back(new VarTypeExpression { static_cast(T1) }); - break; - default: - DE.add(File, T1, std::vector { NodeKind::RParen, NodeKind::RArrowAlt, NodeKind::Comma, NodeKind::Identifier }); - Name->unref(); - return nullptr; - } + auto Tilde = expectToken(); + if (!Tilde) { + Left->unref(); + return nullptr; } -after_vars: - return new TypeclassConstraintExpression { Name, TEs }; + auto Right = parseArrowTypeExpression(); + if (!Right) { + Left->unref(); + Tilde->unref(); + return nullptr; + } + return new EqualityConstraintExpression { Left, Tilde, Right }; } - - VarTypeExpression* Parser::parseVarTypeExpression() { - auto Name = expectToken(); - if (!Name) { - return nullptr; - } - for (auto Ch: Name->Text) { - if (!std::islower(Ch)) { - // TODO - // DE.add(Name); - Name->unref(); - return nullptr; - } - } - return new VarTypeExpression { Name }; + auto Name = expectToken(); + if (!Name) { + return nullptr; } - - InstanceDeclaration* Parser::parseInstanceDeclaration() { - auto InstanceKeyword = expectToken(); - if (!InstanceKeyword) { - skipPastLineFoldEnd(); - return nullptr; - } - auto Name = expectToken(); - if (!Name) { - InstanceKeyword->unref(); - skipPastLineFoldEnd(); - return nullptr; - } - std::vector TypeExps; - for (;;) { - auto T1 = Tokens.peek(); - if (T1->is()) { + std::vector TEs; + for (;;) { + auto T1 = Tokens.peek(); + switch (T1->getKind()) { + case NodeKind::RParen: + case NodeKind::RArrowAlt: + case NodeKind::Comma: + goto after_vars; + case NodeKind::Identifier: + Tokens.get(); + TEs.push_back(new VarTypeExpression { static_cast(T1) }); break; - } - auto TE = parseTypeExpression(); - if (!TE) { - InstanceKeyword->unref(); + default: + DE.add(File, T1, std::vector { NodeKind::RParen, NodeKind::RArrowAlt, NodeKind::Comma, NodeKind::Identifier }); Name->unref(); - for (auto TE: TypeExps) { - TE->unref(); - } - skipPastLineFoldEnd(); return nullptr; - } - TypeExps.push_back(TE); } - auto BlockStart = expectToken(); - if (!BlockStart) { + } +after_vars: + return new TypeclassConstraintExpression { Name, TEs }; +} + +VarTypeExpression* Parser::parseVarTypeExpression() { + auto Name = expectToken(); + if (!Name) { + return nullptr; + } + for (auto Ch: Name->Text) { + if (!std::islower(Ch)) { + // TODO + // DE.add(Name); + Name->unref(); + return nullptr; + } + } + return new VarTypeExpression { Name }; +} + +InstanceDeclaration* Parser::parseInstanceDeclaration() { + auto InstanceKeyword = expectToken(); + if (!InstanceKeyword) { + skipPastLineFoldEnd(); + return nullptr; + } + auto Name = expectToken(); + if (!Name) { + InstanceKeyword->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + std::vector TypeExps; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->is()) { + break; + } + auto TE = parseTypeExpression(); + if (!TE) { InstanceKeyword->unref(); Name->unref(); for (auto TE: TypeExps) { @@ -1458,74 +1446,72 @@ after_vars: skipPastLineFoldEnd(); return nullptr; } - std::vector Elements; - for (;;) { - auto T2 = Tokens.peek(); - if (T2->is()) { - Tokens.get()->unref(); - break; - } - auto Element = parseClassElement(); - if (Element) { - Elements.push_back(Element); - } - } - checkLineFoldEnd(); - return new InstanceDeclaration( - InstanceKeyword, - Name, - TypeExps, - BlockStart, - Elements - ); + TypeExps.push_back(TE); } + auto BlockStart = expectToken(); + if (!BlockStart) { + InstanceKeyword->unref(); + Name->unref(); + for (auto TE: TypeExps) { + TE->unref(); + } + skipPastLineFoldEnd(); + return nullptr; + } + std::vector Elements; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->is()) { + Tokens.get()->unref(); + break; + } + auto Element = parseClassElement(); + if (Element) { + Elements.push_back(Element); + } + } + checkLineFoldEnd(); + return new InstanceDeclaration( + InstanceKeyword, + Name, + TypeExps, + BlockStart, + Elements + ); +} - ClassDeclaration* Parser::parseClassDeclaration() { - PubKeyword* PubKeyword = nullptr; - auto T0 = Tokens.peek(); - if (T0->getKind() == NodeKind::PubKeyword) { - Tokens.get(); - PubKeyword = static_cast(T0); +ClassDeclaration* Parser::parseClassDeclaration() { + PubKeyword* PubKeyword = nullptr; + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::PubKeyword) { + Tokens.get(); + PubKeyword = static_cast(T0); + } + auto ClassKeyword = expectToken(); + if (!ClassKeyword) { + if (PubKeyword) { + PubKeyword->unref(); } - auto ClassKeyword = expectToken(); - if (!ClassKeyword) { - if (PubKeyword) { - PubKeyword->unref(); - } - skipPastLineFoldEnd(); - return nullptr; + skipPastLineFoldEnd(); + return nullptr; + } + auto Name = expectToken(); + if (!Name) { + if (PubKeyword) { + PubKeyword->unref(); } - auto Name = expectToken(); - if (!Name) { - if (PubKeyword) { - PubKeyword->unref(); - } - ClassKeyword->unref(); - skipPastLineFoldEnd(); - return nullptr; + ClassKeyword->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + std::vector TypeVars; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::BlockStart) { + break; } - std::vector TypeVars; - for (;;) { - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::BlockStart) { - break; - } - auto TE = parseVarTypeExpression(); - if (!TE) { - if (PubKeyword) { - PubKeyword->unref(); - } - ClassKeyword->unref(); - for (auto TV: TypeVars) { - TV->unref(); - } - skipPastLineFoldEnd(); - return nullptr; - } - TypeVars.push_back(TE); - } - auto BlockStart = expectToken(); - if (!BlockStart) { + auto TE = parseVarTypeExpression(); + if (!TE) { if (PubKeyword) { PubKeyword->unref(); } @@ -1536,420 +1522,434 @@ after_vars: skipPastLineFoldEnd(); return nullptr; } - std::vector Elements; - for (;;) { - auto T2 = Tokens.peek(); - if (T2->is()) { - Tokens.get()->unref(); - break; - } - auto Element = parseClassElement(); - if (Element) { - Elements.push_back(Element); - } - } - Tokens.get()->unref(); // Always a LineFoldEnd - return new ClassDeclaration( - PubKeyword, - ClassKeyword, - Name, - TypeVars, - BlockStart, - Elements - ); + TypeVars.push_back(TE); } - - std::vector Parser::parseRecordDeclarationFields() { - std::vector Fields; - for (;;) { - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::BlockEnd) { - Tokens.get()->unref(); - break; - } - auto Name = expectToken(); - if (!Name) { - skipPastLineFoldEnd(); - continue; - } - auto Colon = expectToken(); - if (!Colon) { - Name->unref(); - skipPastLineFoldEnd(); - continue; - } - auto TE = parseTypeExpression(); - if (!TE) { - Name->unref(); - Colon->unref(); - skipPastLineFoldEnd(); - continue; - } - checkLineFoldEnd(); - Fields.push_back(new RecordDeclarationField { Name, Colon, TE }); + auto BlockStart = expectToken(); + if (!BlockStart) { + if (PubKeyword) { + PubKeyword->unref(); } - return Fields; + ClassKeyword->unref(); + for (auto TV: TypeVars) { + TV->unref(); + } + skipPastLineFoldEnd(); + return nullptr; } + std::vector Elements; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->is()) { + Tokens.get()->unref(); + break; + } + auto Element = parseClassElement(); + if (Element) { + Elements.push_back(Element); + } + } + Tokens.get()->unref(); // Always a LineFoldEnd + return new ClassDeclaration( + PubKeyword, + ClassKeyword, + Name, + TypeVars, + BlockStart, + Elements + ); +} - RecordDeclaration* Parser::parseRecordDeclaration() { - auto T0 = Tokens.peek(); - PubKeyword* Pub = nullptr; - if (T0->getKind() == NodeKind::MutKeyword) { - Tokens.get(); - Pub = static_cast(T0); +std::vector Parser::parseRecordDeclarationFields() { + std::vector Fields; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::BlockEnd) { + Tokens.get()->unref(); + break; } - auto Struct = expectToken(); - if (!Struct) { - if (Pub) { - Pub->unref(); - } - skipPastLineFoldEnd(); - return nullptr; - } - auto Name = expectToken(); + auto Name = expectToken(); if (!Name) { - if (Pub) { - Pub->unref(); - } - Struct->unref(); skipPastLineFoldEnd(); - return nullptr; + continue; } - std::vector Vars; - for (;;) { - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::BlockStart) { - break; - } - auto Var = parseVarTypeExpression(); - if (Var) { - Vars.push_back(Var); - } - } - auto BS = expectToken(); - if (!BS) { - if (Pub) { - Pub->unref(); - } - Struct->unref(); + auto Colon = expectToken(); + if (!Colon) { Name->unref(); skipPastLineFoldEnd(); - return nullptr; + continue; } - auto Fields = parseRecordDeclarationFields(); - Tokens.get()->unref(); // Always a LineFoldEnd - return new RecordDeclaration { Pub, Struct, Name, Vars, BS, Fields }; - } - - VariantDeclaration* Parser::parseVariantDeclaration() { - auto T0 = Tokens.peek(); - PubKeyword* Pub = nullptr; - if (T0->getKind() == NodeKind::MutKeyword) { - Tokens.get(); - Pub = static_cast(T0); - } - auto Enum = expectToken(); - if (!Enum) { - if (Pub) { - Pub->unref(); - } - skipPastLineFoldEnd(); - return nullptr; - } - auto Name = expectToken(); - if (!Name) { - if (Pub) { - Pub->unref(); - } - Enum->unref(); - skipPastLineFoldEnd(); - return nullptr; - } - std::vector TVs; - for (;;) { - auto T0 = Tokens.peek(); - if (T0->getKind() == NodeKind::BlockStart) { - break; - } - auto Var = parseVarTypeExpression(); - if (Var) { - TVs.push_back(Var); - } - } - auto BS = expectToken(); - if (!BS) { - if (Pub) { - Pub->unref(); - } - Enum->unref(); + auto TE = parseTypeExpression(); + if (!TE) { Name->unref(); + Colon->unref(); skipPastLineFoldEnd(); - return nullptr; - } - std::vector Members; - for (;;) { -next_member: - auto T0 = Tokens.peek(); - if (T0->getKind() == NodeKind::BlockEnd) { - Tokens.get()->unref(); - break; - } - auto Name = expectToken(); - if (!Name) { - skipPastLineFoldEnd(); - continue; - } - auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::BlockStart) { - Tokens.get(); - auto BS = static_cast(T1); - auto Fields = parseRecordDeclarationFields(); - // TODO continue; on error in Fields - Members.push_back(new RecordVariantDeclarationMember { Name, BS, Fields }); - } else { - std::vector Elements; - for (;;) { - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::LineFoldEnd) { - Tokens.get()->unref(); - break; - } - auto TE = parsePrimitiveTypeExpression(); - if (!TE) { - Name->unref(); - for (auto El: Elements) { - El->unref(); - } - goto next_member; - } - Elements.push_back(TE); - } - Members.push_back(new TupleVariantDeclarationMember { Name, Elements }); - } + continue; } checkLineFoldEnd(); - return new VariantDeclaration { Pub, Enum, Name, TVs, BS, Members }; + Fields.push_back(new RecordDeclarationField { Name, Colon, TE }); } + return Fields; +} - Node* Parser::parseClassElement() { +RecordDeclaration* Parser::parseRecordDeclaration() { + auto T0 = Tokens.peek(); + PubKeyword* Pub = nullptr; + if (T0->getKind() == NodeKind::MutKeyword) { + Tokens.get(); + Pub = static_cast(T0); + } + auto Struct = expectToken(); + if (!Struct) { + if (Pub) { + Pub->unref(); + } + skipPastLineFoldEnd(); + return nullptr; + } + auto Name = expectToken(); + if (!Name) { + if (Pub) { + Pub->unref(); + } + Struct->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + std::vector Vars; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::BlockStart) { + break; + } + auto Var = parseVarTypeExpression(); + if (Var) { + Vars.push_back(Var); + } + } + auto BS = expectToken(); + if (!BS) { + if (Pub) { + Pub->unref(); + } + Struct->unref(); + Name->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + auto Fields = parseRecordDeclarationFields(); + Tokens.get()->unref(); // Always a LineFoldEnd + return new RecordDeclaration { Pub, Struct, Name, Vars, BS, Fields }; +} + +VariantDeclaration* Parser::parseVariantDeclaration() { + auto T0 = Tokens.peek(); + PubKeyword* Pub = nullptr; + if (T0->getKind() == NodeKind::MutKeyword) { + Tokens.get(); + Pub = static_cast(T0); + } + auto Enum = expectToken(); + if (!Enum) { + if (Pub) { + Pub->unref(); + } + skipPastLineFoldEnd(); + return nullptr; + } + auto Name = expectToken(); + if (!Name) { + if (Pub) { + Pub->unref(); + } + Enum->unref(); + skipPastLineFoldEnd(); + return nullptr; + } + std::vector TVs; + for (;;) { auto T0 = Tokens.peek(); - switch (T0->getKind()) { - case NodeKind::LetKeyword: - return parseLetDeclaration(); - case NodeKind::TypeKeyword: - // TODO - default: - DE.add(File, T0, std::vector { NodeKind::LetKeyword, NodeKind::TypeKeyword }); - skipPastLineFoldEnd(); - return nullptr; + if (T0->getKind() == NodeKind::BlockStart) { + break; + } + auto Var = parseVarTypeExpression(); + if (Var) { + TVs.push_back(Var); } } - - Node* Parser::parseSourceElement() { - auto T0 = peekFirstTokenAfterAnnotationsAndModifiers(); - switch (T0->getKind()) { - case NodeKind::LetKeyword: - return parseLetDeclaration(); - case NodeKind::IfKeyword: - return parseIfStatement(); - case NodeKind::ClassKeyword: - return parseClassDeclaration(); - case NodeKind::InstanceKeyword: - return parseInstanceDeclaration(); - case NodeKind::StructKeyword: - return parseRecordDeclaration(); - case NodeKind::EnumKeyword: - return parseVariantDeclaration(); - default: - return parseExpressionStatement(); + auto BS = expectToken(); + if (!BS) { + if (Pub) { + Pub->unref(); } + Enum->unref(); + Name->unref(); + skipPastLineFoldEnd(); + return nullptr; } - - SourceFile* Parser::parseSourceFile() { - std::vector Elements; - for (;;) { - auto T0 = Tokens.peek(); - if (T0->is()) { - break; - } - auto Element = parseSourceElement(); - if (Element) { - Elements.push_back(Element); - } - } - return new SourceFile(File, Elements); - } - - std::vector Parser::parseAnnotations() { - std::vector Annotations; - for (;;) { - auto T0 = Tokens.peek(); - if (T0->getKind() != NodeKind::At) { - break; - } - auto At = static_cast(T0); - Tokens.get(); - auto T1 = Tokens.peek(); - switch (T1->getKind()) { - case NodeKind::Colon: - { - auto Colon = static_cast(T1); - Tokens.get(); - auto TE = parsePrimitiveTypeExpression(); - if (!TE) { - // TODO - continue; - } - Annotations.push_back(new TypeAssertAnnotation { At, Colon, TE }); - continue; - } - default: - { - // auto Name = static_cast(T1); - // Tokens.get(); - auto E = parseExpression(); - if (!E) { - At->unref(); - skipPastLineFoldEnd(); - continue; - } - checkLineFoldEnd(); - Annotations.push_back(new ExpressionAnnotation { At, E }); - continue; - } - // default: - // DE.add(File, T1, std::vector { NodeKind::Colon, NodeKind::Identifier }); - // At->unref(); - // skipToLineFoldEnd(); - // break; - } -next_annotation:; - } - return Annotations; - } - - void Parser::skipToRBrace() { - unsigned ParenLevel = 0; - unsigned BracketLevel = 0; - unsigned BraceLevel = 0; - unsigned BlockLevel = 0; - for (;;) { - auto T0 = Tokens.peek(); - switch (T0->getKind()) { - case NodeKind::EndOfFile: - return; - case NodeKind::LineFoldEnd: - Tokens.get()->unref(); - if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { - return; - } - break; - case NodeKind::BlockStart: - Tokens.get()->unref(); - BlockLevel++; - break; - case NodeKind::BlockEnd: - Tokens.get()->unref(); - BlockLevel--; - break; - case NodeKind::LParen: - Tokens.get()->unref(); - ParenLevel++; - break; - case NodeKind::LBracket: - Tokens.get()->unref(); - BracketLevel++; - break; - case NodeKind::LBrace: - Tokens.get()->unref(); - BraceLevel++; - break; - case NodeKind::RParen: - Tokens.get()->unref(); - ParenLevel--; - break; - case NodeKind::RBracket: - Tokens.get()->unref(); - BracketLevel--; - break; - case NodeKind::RBrace: - if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { - return; - } - Tokens.get()->unref(); - BraceLevel--; - break; - default: - Tokens.get()->unref(); - break; - } - } - } - - void Parser::skipPastLineFoldEnd() { - unsigned ParenLevel = 0; - unsigned BracketLevel = 0; - unsigned BraceLevel = 0; - unsigned BlockLevel = 0; - for (;;) { - auto T0 = Tokens.get(); - switch (T0->getKind()) { - case NodeKind::EndOfFile: - return; - case NodeKind::LineFoldEnd: - T0->unref(); - if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { - return; - } - break; - case NodeKind::BlockStart: - T0->unref(); - BlockLevel++; - break; - case NodeKind::BlockEnd: - T0->unref(); - BlockLevel--; - break; - case NodeKind::LParen: - T0->unref(); - ParenLevel++; - break; - case NodeKind::LBracket: - T0->unref(); - BracketLevel++; - break; - case NodeKind::LBrace: - T0->unref(); - BraceLevel++; - break; - case NodeKind::RParen: - T0->unref(); - ParenLevel--; - break; - case NodeKind::RBracket: - T0->unref(); - BracketLevel--; - break; - case NodeKind::RBrace: - T0->unref(); - BraceLevel--; - break; - default: - T0->unref(); - break; - } - } - } - - void Parser::checkLineFoldEnd() { + std::vector Members; + for (;;) { +next_member: auto T0 = Tokens.peek(); - if (T0->getKind() == NodeKind::LineFoldEnd) { + if (T0->getKind() == NodeKind::BlockEnd) { Tokens.get()->unref(); - } else { - DE.add(File, T0, std::vector { NodeKind::LineFoldEnd }); + break; + } + auto Name = expectToken(); + if (!Name) { skipPastLineFoldEnd(); + continue; + } + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::BlockStart) { + Tokens.get(); + auto BS = static_cast(T1); + auto Fields = parseRecordDeclarationFields(); + // TODO continue; on error in Fields + Members.push_back(new RecordVariantDeclarationMember { Name, BS, Fields }); + } else { + std::vector Elements; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::LineFoldEnd) { + Tokens.get()->unref(); + break; + } + auto TE = parsePrimitiveTypeExpression(); + if (!TE) { + Name->unref(); + for (auto El: Elements) { + El->unref(); + } + goto next_member; + } + Elements.push_back(TE); + } + Members.push_back(new TupleVariantDeclarationMember { Name, Elements }); } } + checkLineFoldEnd(); + return new VariantDeclaration { Pub, Enum, Name, TVs, BS, Members }; +} + +Node* Parser::parseClassElement() { + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::LetKeyword: + return parseLetDeclaration(); + case NodeKind::TypeKeyword: + // TODO + default: + DE.add(File, T0, std::vector { NodeKind::LetKeyword, NodeKind::TypeKeyword }); + skipPastLineFoldEnd(); + return nullptr; + } +} + +Node* Parser::parseSourceElement() { + auto T0 = peekFirstTokenAfterAnnotationsAndModifiers(); + switch (T0->getKind()) { + case NodeKind::LetKeyword: + return parseLetDeclaration(); + case NodeKind::IfKeyword: + return parseIfStatement(); + case NodeKind::ClassKeyword: + return parseClassDeclaration(); + case NodeKind::InstanceKeyword: + return parseInstanceDeclaration(); + case NodeKind::StructKeyword: + return parseRecordDeclaration(); + case NodeKind::EnumKeyword: + return parseVariantDeclaration(); + default: + return parseExpressionStatement(); + } +} + +SourceFile* Parser::parseSourceFile() { + std::vector Elements; + for (;;) { + auto T0 = Tokens.peek(); + if (T0->is()) { + break; + } + auto Element = parseSourceElement(); + if (Element) { + Elements.push_back(Element); + } + } + return new SourceFile(File, Elements); +} + +std::vector Parser::parseAnnotations() { + std::vector Annotations; + for (;;) { + auto T0 = Tokens.peek(); + if (T0->getKind() != NodeKind::At) { + break; + } + auto At = static_cast(T0); + Tokens.get(); + auto T1 = Tokens.peek(); + switch (T1->getKind()) { + case NodeKind::Colon: + { + auto Colon = static_cast(T1); + Tokens.get(); + auto TE = parsePrimitiveTypeExpression(); + if (!TE) { + // TODO + continue; + } + Annotations.push_back(new TypeAssertAnnotation { At, Colon, TE }); + continue; + } + default: + { + // auto Name = static_cast(T1); + // Tokens.get(); + auto E = parseExpression(); + if (!E) { + At->unref(); + skipPastLineFoldEnd(); + continue; + } + checkLineFoldEnd(); + Annotations.push_back(new ExpressionAnnotation { At, E }); + continue; + } + // default: + // DE.add(File, T1, std::vector { NodeKind::Colon, NodeKind::Identifier }); + // At->unref(); + // skipToLineFoldEnd(); + // break; + } +next_annotation:; + } + return Annotations; +} + +void Parser::skipToRBrace() { + unsigned ParenLevel = 0; + unsigned BracketLevel = 0; + unsigned BraceLevel = 0; + unsigned BlockLevel = 0; + for (;;) { + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::EndOfFile: + return; + case NodeKind::LineFoldEnd: + Tokens.get()->unref(); + if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { + return; + } + break; + case NodeKind::BlockStart: + Tokens.get()->unref(); + BlockLevel++; + break; + case NodeKind::BlockEnd: + Tokens.get()->unref(); + BlockLevel--; + break; + case NodeKind::LParen: + Tokens.get()->unref(); + ParenLevel++; + break; + case NodeKind::LBracket: + Tokens.get()->unref(); + BracketLevel++; + break; + case NodeKind::LBrace: + Tokens.get()->unref(); + BraceLevel++; + break; + case NodeKind::RParen: + Tokens.get()->unref(); + ParenLevel--; + break; + case NodeKind::RBracket: + Tokens.get()->unref(); + BracketLevel--; + break; + case NodeKind::RBrace: + if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { + return; + } + Tokens.get()->unref(); + BraceLevel--; + break; + default: + Tokens.get()->unref(); + break; + } + } +} + +void Parser::skipPastLineFoldEnd() { + unsigned ParenLevel = 0; + unsigned BracketLevel = 0; + unsigned BraceLevel = 0; + unsigned BlockLevel = 0; + for (;;) { + auto T0 = Tokens.get(); + switch (T0->getKind()) { + case NodeKind::EndOfFile: + return; + case NodeKind::LineFoldEnd: + T0->unref(); + if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { + return; + } + break; + case NodeKind::BlockStart: + T0->unref(); + BlockLevel++; + break; + case NodeKind::BlockEnd: + T0->unref(); + BlockLevel--; + break; + case NodeKind::LParen: + T0->unref(); + ParenLevel++; + break; + case NodeKind::LBracket: + T0->unref(); + BracketLevel++; + break; + case NodeKind::LBrace: + T0->unref(); + BraceLevel++; + break; + case NodeKind::RParen: + T0->unref(); + ParenLevel--; + break; + case NodeKind::RBracket: + T0->unref(); + BracketLevel--; + break; + case NodeKind::RBrace: + T0->unref(); + BraceLevel--; + break; + default: + T0->unref(); + break; + } + } +} + +void Parser::checkLineFoldEnd() { + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::LineFoldEnd) { + Tokens.get()->unref(); + } else { + DE.add(File, T0, std::vector { NodeKind::LineFoldEnd }); + skipPastLineFoldEnd(); + } +} } diff --git a/bootstrap/cxx/src/Scanner.cc b/bootstrap/cxx/src/Scanner.cc index b69b1ec98..844f57ef2 100644 --- a/bootstrap/cxx/src/Scanner.cc +++ b/bootstrap/cxx/src/Scanner.cc @@ -13,495 +13,495 @@ namespace bolt { - static inline bool isWhiteSpace(Char Chr) { - switch (Chr) { - case ' ': - case '\n': - case '\r': - case '\t': - return true; - default: - return false; - } +static inline bool isWhiteSpace(Char Chr) { + switch (Chr) { + case ' ': + case '\n': + case '\r': + case '\t': + return true; + default: + return false; } - - static inline bool isOperatorPart(Char Chr) { - switch (Chr) { - case '+': - case '-': - case '*': - case '/': - case '^': - case '&': - case '|': - case '%': - case '$': - case '!': - case '?': - case '>': - case '<': - case '=': - return true; - default: - return false; - } - } - - static bool isDirectiveIdentifierStart(Char Chr) { - return (Chr >= 65 && Chr <= 90) // Uppercase letter - || (Chr >= 96 && Chr <= 122) // Lowercase letter - || Chr == '_'; - } - - static bool isIdentifierPart(Char Chr) { - return (Chr >= 65 && Chr <= 90) // Uppercase letter - || (Chr >= 96 && Chr <= 122) // Lowercase letter - || (Chr >= 48 && Chr <= 57) // Digit - || Chr == '_'; - } - - static int toDigit(Char Chr) { - ZEN_ASSERT(Chr >= 48 && Chr <= 57); - return Chr - 48; - } - - std::unordered_map Keywords = { - { "pub", NodeKind::PubKeyword }, - { "let", NodeKind::LetKeyword }, - { "foreign", NodeKind::ForeignKeyword }, - { "mut", NodeKind::MutKeyword }, - { "return", NodeKind::ReturnKeyword }, - { "type", NodeKind::TypeKeyword }, - { "mod", NodeKind::ModKeyword }, - { "if", NodeKind::IfKeyword }, - { "else", NodeKind::ElseKeyword }, - { "elif", NodeKind::ElifKeyword }, - { "match", NodeKind::MatchKeyword }, - { "class", NodeKind::ClassKeyword }, - { "instance", NodeKind::InstanceKeyword }, - { "struct", NodeKind::StructKeyword }, - { "enum", NodeKind::EnumKeyword }, - }; - - Scanner::Scanner(DiagnosticEngine& DE, TextFile& File, Stream& Chars): - DE(DE), File(File), Chars(Chars) {} - - std::string Scanner::scanIdentifier() { - auto Loc = getCurrentLoc(); - auto C0 = getChar(); - if (!isDirectiveIdentifierStart(C0)) { - DE.add(File, Loc, std::string { C0 }); - return nullptr; - } - ByteString Text { static_cast(C0) }; - for (;;) { - auto C1 = peekChar(); - if (!isIdentifierPart(C1)) { - break; - } - Text.push_back(C1); - getChar(); - } - return Text; } - Token* Scanner::readNullable() { +static inline bool isOperatorPart(Char Chr) { + switch (Chr) { + case '+': + case '-': + case '*': + case '/': + case '^': + case '&': + case '|': + case '%': + case '$': + case '!': + case '?': + case '>': + case '<': + case '=': + return true; + default: + return false; + } +} - TextLoc StartLoc; - Char C0; +static bool isDirectiveIdentifierStart(Char Chr) { + return (Chr >= 65 && Chr <= 90) // Uppercase letter + || (Chr >= 96 && Chr <= 122) // Lowercase letter + || Chr == '_'; +} - for (;;) { - StartLoc = getCurrentLoc(); - C0 = getChar(); - if (isWhiteSpace(C0)) { - continue; - } - if (C0 == '#') { - auto C1 = peekChar(0); - auto C2 = peekChar(1); - if (C1 == '!' && C2 == '!') { - getChar(); - getChar(); - auto Name = scanIdentifier(); - std::string Value; - for (;;) { - C0 = getChar(); - Value.push_back(C0); - if (C0 == '\n' || C0 == EOF) { - break; - } - } - continue; - } +static bool isIdentifierPart(Char Chr) { + return (Chr >= 65 && Chr <= 90) // Uppercase letter + || (Chr >= 96 && Chr <= 122) // Lowercase letter + || (Chr >= 48 && Chr <= 57) // Digit + || Chr == '_'; +} + +static int toDigit(Char Chr) { + ZEN_ASSERT(Chr >= 48 && Chr <= 57); + return Chr - 48; +} + +std::unordered_map Keywords = { + { "pub", NodeKind::PubKeyword }, + { "let", NodeKind::LetKeyword }, + { "foreign", NodeKind::ForeignKeyword }, + { "mut", NodeKind::MutKeyword }, + { "return", NodeKind::ReturnKeyword }, + { "type", NodeKind::TypeKeyword }, + { "mod", NodeKind::ModKeyword }, + { "if", NodeKind::IfKeyword }, + { "else", NodeKind::ElseKeyword }, + { "elif", NodeKind::ElifKeyword }, + { "match", NodeKind::MatchKeyword }, + { "class", NodeKind::ClassKeyword }, + { "instance", NodeKind::InstanceKeyword }, + { "struct", NodeKind::StructKeyword }, + { "enum", NodeKind::EnumKeyword }, +}; + +Scanner::Scanner(DiagnosticEngine& DE, TextFile& File, Stream& Chars): + DE(DE), File(File), Chars(Chars) {} + +std::string Scanner::scanIdentifier() { + auto Loc = getCurrentLoc(); + auto C0 = getChar(); + if (!isDirectiveIdentifierStart(C0)) { + DE.add(File, Loc, std::string { C0 }); + return nullptr; + } + ByteString Text { static_cast(C0) }; + for (;;) { + auto C1 = peekChar(); + if (!isIdentifierPart(C1)) { + break; + } + Text.push_back(C1); + getChar(); + } + return Text; +} + +Token* Scanner::readNullable() { + + TextLoc StartLoc; + Char C0; + + for (;;) { + StartLoc = getCurrentLoc(); + C0 = getChar(); + if (isWhiteSpace(C0)) { + continue; + } + if (C0 == '#') { + auto C1 = peekChar(0); + auto C2 = peekChar(1); + if (C1 == '!' && C2 == '!') { + getChar(); + getChar(); + auto Name = scanIdentifier(); + std::string Value; for (;;) { C0 = getChar(); + Value.push_back(C0); if (C0 == '\n' || C0 == EOF) { break; } } continue; } - break; + for (;;) { + C0 = getChar(); + if (C0 == '\n' || C0 == EOF) { + break; + } + } + continue; + } + break; + } + + switch (C0) { + + case static_cast(EOF): + return new EndOfFile(StartLoc); + + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { + Integer I = toDigit(C0); + for (;;) { + auto C1 = peekChar(); + switch (C1) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + getChar(); + I = I * 10 + toDigit(C1); + break; + default: + goto digit_finish; + } + } +digit_finish: + return new IntegerLiteral(I, StartLoc); } - switch (C0) { + case 'A': + case 'B': + case 'C': + case 'D': + case 'E': + case 'F': + case 'G': + case 'H': + case 'I': + case 'J': + case 'K': + case 'L': + case 'M': + case 'N': + case 'O': + case 'P': + case 'Q': + case 'R': + case 'S': + case 'T': + case 'U': + case 'V': + case 'W': + case 'X': + case 'Y': + case 'Z': + { + ByteString Text { static_cast(C0) }; + for (;;) { + auto C1 = peekChar(); + if (!isIdentifierPart(C1)) { + break; + } + Text.push_back(C1); + getChar(); + } + return new IdentifierAlt(Text, StartLoc); + } - case static_cast(EOF): - return new EndOfFile(StartLoc); + case 'a': + case 'b': + case 'c': + case 'd': + case 'e': + case 'f': + case 'g': + case 'h': + case 'i': + case 'j': + case 'k': + case 'l': + case 'm': + case 'n': + case 'o': + case 'p': + case 'q': + case 'r': + case 's': + case 't': + case 'u': + case 'v': + case 'w': + case 'x': + case 'y': + case 'z': + case '_': + { + ByteString Text { static_cast(C0) }; + for (;;) { + auto C1 = peekChar(); + if (!isIdentifierPart(C1)) { + break; + } + Text.push_back(C1); + getChar(); + } + auto Match = Keywords.find(Text); + if (Match != Keywords.end()) { + switch (Match->second) { + case NodeKind::PubKeyword: + return new PubKeyword(StartLoc); + case NodeKind::LetKeyword: + return new LetKeyword(StartLoc); + case NodeKind::ForeignKeyword: + return new ForeignKeyword(StartLoc); + case NodeKind::MutKeyword: + return new MutKeyword(StartLoc); + case NodeKind::TypeKeyword: + return new TypeKeyword(StartLoc); + case NodeKind::ReturnKeyword: + return new ReturnKeyword(StartLoc); + case NodeKind::IfKeyword: + return new IfKeyword(StartLoc); + case NodeKind::ElifKeyword: + return new ElifKeyword(StartLoc); + case NodeKind::ElseKeyword: + return new ElseKeyword(StartLoc); + case NodeKind::MatchKeyword: + return new MatchKeyword(StartLoc); + case NodeKind::ClassKeyword: + return new ClassKeyword(StartLoc); + case NodeKind::InstanceKeyword: + return new InstanceKeyword(StartLoc); + case NodeKind::StructKeyword: + return new StructKeyword(StartLoc); + case NodeKind::EnumKeyword: + return new EnumKeyword(StartLoc); + default: + ZEN_UNREACHABLE + } + } + return new Identifier(Text, StartLoc); + } - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - { - Integer I = toDigit(C0); - for (;;) { - auto C1 = peekChar(); + case '"': + { + ByteString Text; + bool Escaping = false; + for (;;) { + auto Loc = getCurrentLoc(); + auto C1 = getChar(); + if (Escaping) { switch (C1) { - case '0': - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - getChar(); - I = I * 10 + toDigit(C1); + case 'a': Text.push_back('\a'); break; + case 'b': Text.push_back('\b'); break; + case 'f': Text.push_back('\f'); break; + case 'n': Text.push_back('\n'); break; + case 'r': Text.push_back('\r'); break; + case 't': Text.push_back('\t'); break; + case 'v': Text.push_back('\v'); break; + case '0': Text.push_back('\0'); break; + case '\'': Text.push_back('\''); break; + case '"': Text.push_back('"'); break; + default: + DE.add(File, Loc, String { static_cast(C1) }); + return nullptr; + } + Escaping = false; + } else { + switch (C1) { + case '"': + goto after_string_contents; + case '\\': + Escaping = true; break; default: - goto digit_finish; + Text.push_back(C1); + break; } } -digit_finish: - return new IntegerLiteral(I, StartLoc); } - - case 'A': - case 'B': - case 'C': - case 'D': - case 'E': - case 'F': - case 'G': - case 'H': - case 'I': - case 'J': - case 'K': - case 'L': - case 'M': - case 'N': - case 'O': - case 'P': - case 'Q': - case 'R': - case 'S': - case 'T': - case 'U': - case 'V': - case 'W': - case 'X': - case 'Y': - case 'Z': - { - ByteString Text { static_cast(C0) }; - for (;;) { - auto C1 = peekChar(); - if (!isIdentifierPart(C1)) { - break; - } - Text.push_back(C1); - getChar(); - } - return new IdentifierAlt(Text, StartLoc); - } - - case 'a': - case 'b': - case 'c': - case 'd': - case 'e': - case 'f': - case 'g': - case 'h': - case 'i': - case 'j': - case 'k': - case 'l': - case 'm': - case 'n': - case 'o': - case 'p': - case 'q': - case 'r': - case 's': - case 't': - case 'u': - case 'v': - case 'w': - case 'x': - case 'y': - case 'z': - case '_': - { - ByteString Text { static_cast(C0) }; - for (;;) { - auto C1 = peekChar(); - if (!isIdentifierPart(C1)) { - break; - } - Text.push_back(C1); - getChar(); - } - auto Match = Keywords.find(Text); - if (Match != Keywords.end()) { - switch (Match->second) { - case NodeKind::PubKeyword: - return new PubKeyword(StartLoc); - case NodeKind::LetKeyword: - return new LetKeyword(StartLoc); - case NodeKind::ForeignKeyword: - return new ForeignKeyword(StartLoc); - case NodeKind::MutKeyword: - return new MutKeyword(StartLoc); - case NodeKind::TypeKeyword: - return new TypeKeyword(StartLoc); - case NodeKind::ReturnKeyword: - return new ReturnKeyword(StartLoc); - case NodeKind::IfKeyword: - return new IfKeyword(StartLoc); - case NodeKind::ElifKeyword: - return new ElifKeyword(StartLoc); - case NodeKind::ElseKeyword: - return new ElseKeyword(StartLoc); - case NodeKind::MatchKeyword: - return new MatchKeyword(StartLoc); - case NodeKind::ClassKeyword: - return new ClassKeyword(StartLoc); - case NodeKind::InstanceKeyword: - return new InstanceKeyword(StartLoc); - case NodeKind::StructKeyword: - return new StructKeyword(StartLoc); - case NodeKind::EnumKeyword: - return new EnumKeyword(StartLoc); - default: - ZEN_UNREACHABLE - } - } - return new Identifier(Text, StartLoc); - } - - case '"': - { - ByteString Text; - bool Escaping = false; - for (;;) { - auto Loc = getCurrentLoc(); - auto C1 = getChar(); - if (Escaping) { - switch (C1) { - case 'a': Text.push_back('\a'); break; - case 'b': Text.push_back('\b'); break; - case 'f': Text.push_back('\f'); break; - case 'n': Text.push_back('\n'); break; - case 'r': Text.push_back('\r'); break; - case 't': Text.push_back('\t'); break; - case 'v': Text.push_back('\v'); break; - case '0': Text.push_back('\0'); break; - case '\'': Text.push_back('\''); break; - case '"': Text.push_back('"'); break; - default: - DE.add(File, Loc, String { static_cast(C1) }); - return nullptr; - } - Escaping = false; - } else { - switch (C1) { - case '"': - goto after_string_contents; - case '\\': - Escaping = true; - break; - default: - Text.push_back(C1); - break; - } - } - } after_string_contents: - return new StringLiteral(Text, StartLoc); - } + return new StringLiteral(Text, StartLoc); + } - case '.': - { + case '.': + { + auto C1 = peekChar(); + if (C1 == '.') { + getChar(); + auto C2 = peekChar(); + if (C2 == '.') { + DE.add(File, getCurrentLoc(), String { static_cast(C2) }); + return nullptr; + } + return new DotDot(StartLoc); + } + return new Dot(StartLoc); + } + + case '+': + case '-': + case '*': + case '/': + case '^': + case '&': + case '|': + case '%': + case '$': + case '!': + case '?': + case '>': + case '<': + case '=': + { + ByteString Text { static_cast(C0) }; + for (;;) { auto C1 = peekChar(); - if (C1 == '.') { - getChar(); - auto C2 = peekChar(); - if (C2 == '.') { - DE.add(File, getCurrentLoc(), String { static_cast(C2) }); - return nullptr; - } - return new DotDot(StartLoc); + if (!isOperatorPart(C1)) { + break; } - return new Dot(StartLoc); + Text.push_back(static_cast(C1)); + getChar(); } - - case '+': - case '-': - case '*': - case '/': - case '^': - case '&': - case '|': - case '%': - case '$': - case '!': - case '?': - case '>': - case '<': - case '=': - { - ByteString Text { static_cast(C0) }; - for (;;) { - auto C1 = peekChar(); - if (!isOperatorPart(C1)) { - break; - } - Text.push_back(static_cast(C1)); - getChar(); - } - if (Text == "|") { - return new VBar(StartLoc); - } else if (Text == "->") { - return new RArrow(StartLoc); - } else if (Text == "=>") { - return new RArrowAlt(StartLoc); - } else if (Text == "=") { - return new Equals(StartLoc); - } else if (Text.back() == '=' && Text[Text.size()-2] != '=') { - return new Assignment(Text.substr(0, Text.size()-1), StartLoc); - } - return new CustomOperator(Text, StartLoc); + if (Text == "|") { + return new VBar(StartLoc); + } else if (Text == "->") { + return new RArrow(StartLoc); + } else if (Text == "=>") { + return new RArrowAlt(StartLoc); + } else if (Text == "=") { + return new Equals(StartLoc); + } else if (Text.back() == '=' && Text[Text.size()-2] != '=') { + return new Assignment(Text.substr(0, Text.size()-1), StartLoc); } + return new CustomOperator(Text, StartLoc); + } #define BOLT_SIMPLE_TOKEN(ch, name) case ch: return new name(StartLoc); - BOLT_SIMPLE_TOKEN(',', Comma) - BOLT_SIMPLE_TOKEN(':', Colon) - BOLT_SIMPLE_TOKEN('(', LParen) - BOLT_SIMPLE_TOKEN(')', RParen) - BOLT_SIMPLE_TOKEN('[', LBracket) - BOLT_SIMPLE_TOKEN(']', RBracket) - BOLT_SIMPLE_TOKEN('{', LBrace) - BOLT_SIMPLE_TOKEN('}', RBrace) - BOLT_SIMPLE_TOKEN('~', Tilde) - BOLT_SIMPLE_TOKEN('@', At) + BOLT_SIMPLE_TOKEN(',', Comma) + BOLT_SIMPLE_TOKEN(':', Colon) + BOLT_SIMPLE_TOKEN('(', LParen) + BOLT_SIMPLE_TOKEN(')', RParen) + BOLT_SIMPLE_TOKEN('[', LBracket) + BOLT_SIMPLE_TOKEN(']', RBracket) + BOLT_SIMPLE_TOKEN('{', LBrace) + BOLT_SIMPLE_TOKEN('}', RBrace) + BOLT_SIMPLE_TOKEN('~', Tilde) + BOLT_SIMPLE_TOKEN('@', At) - default: - DE.add(File, StartLoc, String { static_cast(C0) }); - return nullptr; - - } + default: + DE.add(File, StartLoc, String { static_cast(C0) }); + return nullptr; } - Token* Scanner::read() { - for (;;) { - auto T0 = readNullable(); - if (T0) { - // EndOFFile is guaranteed to be produced, so that ends the stream. - return T0; - } +} + +Token* Scanner::read() { + for (;;) { + auto T0 = readNullable(); + if (T0) { + // EndOFFile is guaranteed to be produced, so that ends the stream. + return T0; } } +} - Punctuator::Punctuator(Stream& Tokens): - Tokens(Tokens) { - Frames.push(FrameType::Block); - Locations.push(TextLoc { 0, 0 }); - } +Punctuator::Punctuator(Stream& Tokens): + Tokens(Tokens) { + Frames.push(FrameType::Block); + Locations.push(TextLoc { 0, 0 }); + } - Token* Punctuator::read() { +Token* Punctuator::read() { - auto T0 = Tokens.peek(); + auto T0 = Tokens.peek(); - switch (T0->getKind()) { - case NodeKind::LBrace: - Frames.push(FrameType::Fallthrough); - break; - case NodeKind::EndOfFile: - { - if (Frames.size() == 1) { - return T0; - } - auto Frame = Frames.top(); - Frames.pop(); - switch (Frame) { - case FrameType::Fallthrough: - break; - case FrameType::Block: - return new BlockEnd(T0->getStartLoc()); - case FrameType::LineFold: - return new LineFoldEnd(T0->getStartLoc()); - } - } - default: - break; - } - - auto RefLoc = Locations.top(); - switch (Frames.top()) { - case FrameType::Fallthrough: - { - if (T0->getKind() == NodeKind::RBrace) { - Frames.pop(); - } - Tokens.get(); + switch (T0->getKind()) { + case NodeKind::LBrace: + Frames.push(FrameType::Fallthrough); + break; + case NodeKind::EndOfFile: + { + if (Frames.size() == 1) { return T0; } - case FrameType::LineFold: - { - if (T0->getStartLine() > RefLoc.Line - && T0->getStartColumn() <= RefLoc.Column) { - Frames.pop(); - Locations.pop(); - return new LineFoldEnd(T0->getStartLoc()); - } - if (isa(T0)) { - auto T1 = Tokens.peek(1); - if (T1->getStartLine() > T0->getEndLine()) { - Tokens.get(); - Frames.push(FrameType::Block); - return new BlockStart(T0->getStartLoc()); - } - } - return Tokens.get(); - } - case FrameType::Block: - { - if (T0->getStartColumn() <= RefLoc.Column) { - Frames.pop(); + auto Frame = Frames.top(); + Frames.pop(); + switch (Frame) { + case FrameType::Fallthrough: + break; + case FrameType::Block: return new BlockEnd(T0->getStartLoc()); - } - - Frames.push(FrameType::LineFold); - Locations.push(T0->getStartLoc()); - - return Tokens.get(); + case FrameType::LineFold: + return new LineFoldEnd(T0->getStartLoc()); } } - - ZEN_UNREACHABLE + default: + break; } + auto RefLoc = Locations.top(); + switch (Frames.top()) { + case FrameType::Fallthrough: + { + if (T0->getKind() == NodeKind::RBrace) { + Frames.pop(); + } + Tokens.get(); + return T0; + } + case FrameType::LineFold: + { + if (T0->getStartLine() > RefLoc.Line + && T0->getStartColumn() <= RefLoc.Column) { + Frames.pop(); + Locations.pop(); + return new LineFoldEnd(T0->getStartLoc()); + } + if (isa(T0)) { + auto T1 = Tokens.peek(1); + if (T1->getStartLine() > T0->getEndLine()) { + Tokens.get(); + Frames.push(FrameType::Block); + return new BlockStart(T0->getStartLoc()); + } + } + return Tokens.get(); + } + case FrameType::Block: + { + if (T0->getStartColumn() <= RefLoc.Column) { + Frames.pop(); + return new BlockEnd(T0->getStartLoc()); + } + + Frames.push(FrameType::LineFold); + Locations.push(T0->getStartLoc()); + + return Tokens.get(); + } + } + + ZEN_UNREACHABLE +} + } diff --git a/bootstrap/cxx/src/Text.cc b/bootstrap/cxx/src/Text.cc index f7e38a11b..559eeb012 100644 --- a/bootstrap/cxx/src/Text.cc +++ b/bootstrap/cxx/src/Text.cc @@ -6,48 +6,48 @@ namespace bolt { - TextFile::TextFile(ByteString Path, ByteString Text): - Path(Path), Text(Text) { - LineOffsets.push_back(0); - for (size_t I = 0; I < Text.size(); I++) { - auto Chr = Text[I]; - if (Chr == '\n') { - LineOffsets.push_back(I+1); - } - } - LineOffsets.push_back(Text.size()); - } - - size_t TextFile::getLineCount() const { - return LineOffsets.size(); - } - - size_t TextFile::getStartOffset(size_t Line) const { - return LineOffsets[Line-1]; - } - - size_t TextFile::getLine(size_t Offset) const { - ZEN_ASSERT(Offset < Text.size()); - for (size_t I = 0; I < LineOffsets.size(); ++I) { - if (LineOffsets[I] > Offset) { - return I; +TextFile::TextFile(ByteString Path, ByteString Text): + Path(Path), Text(Text) { + LineOffsets.push_back(0); + for (size_t I = 0; I < Text.size(); I++) { + auto Chr = Text[I]; + if (Chr == '\n') { + LineOffsets.push_back(I+1); } } - ZEN_UNREACHABLE + LineOffsets.push_back(Text.size()); } - size_t TextFile::getColumn(size_t Offset) const { - auto Line = getLine(Offset); - auto StartOffset = getStartOffset(Line); - return Offset - StartOffset + 1 ; - } +size_t TextFile::getLineCount() const { + return LineOffsets.size(); +} - ByteString TextFile::getPath() const { - return Path; - } +size_t TextFile::getStartOffset(size_t Line) const { + return LineOffsets[Line-1]; +} - ByteString TextFile::getText() const { - return Text; +size_t TextFile::getLine(size_t Offset) const { + ZEN_ASSERT(Offset < Text.size()); + for (size_t I = 0; I < LineOffsets.size(); ++I) { + if (LineOffsets[I] > Offset) { + return I; + } } + ZEN_UNREACHABLE +} + +size_t TextFile::getColumn(size_t Offset) const { + auto Line = getLine(Offset); + auto StartOffset = getStartOffset(Line); + return Offset - StartOffset + 1 ; +} + +ByteString TextFile::getPath() const { + return Path; +} + +ByteString TextFile::getText() const { + return Text; +} } diff --git a/bootstrap/cxx/src/Types.cc b/bootstrap/cxx/src/Types.cc index 9429af865..2ff1f318f 100644 --- a/bootstrap/cxx/src/Types.cc +++ b/bootstrap/cxx/src/Types.cc @@ -8,328 +8,328 @@ namespace bolt { - bool TypeclassSignature::operator<(const TypeclassSignature& Other) const { - if (Id < Other.Id) { +bool TypeclassSignature::operator<(const TypeclassSignature& Other) const { + if (Id < Other.Id) { + return true; + } + ZEN_ASSERT(Params.size() == 1); + ZEN_ASSERT(Other.Params.size() == 1); + return Params[0]->asCon().Id < Other.Params[0]->asCon().Id; +} + +bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { + ZEN_ASSERT(Params.size() == 1); + ZEN_ASSERT(Other.Params.size() == 1); + return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id; +} + +bool TypeIndex::operator==(const TypeIndex& Other) const noexcept { + if (Kind != Other.Kind) { + return false; + } + switch (Kind) { + case TypeIndexKind::ArrowParamType: + case TypeIndexKind::TupleElement: + return I == Other.I; + default: return true; - } - ZEN_ASSERT(Params.size() == 1); - ZEN_ASSERT(Other.Params.size() == 1); - return Params[0]->asCon().Id < Other.Params[0]->asCon().Id; } +} - bool TypeclassSignature::operator==(const TypeclassSignature& Other) const { - ZEN_ASSERT(Params.size() == 1); - ZEN_ASSERT(Other.Params.size() == 1); - return Id == Other.Id && Params[0]->asCon().Id == Other.Params[0]->asCon().Id; - } +bool TCon::operator==(const TCon& Other) const { + return Id == Other.Id; +} - bool TypeIndex::operator==(const TypeIndex& Other) const noexcept { - if (Kind != Other.Kind) { +bool TApp::operator==(const TApp& Other) const { + return *Op == *Other.Op && *Arg == *Other.Arg; +} + +bool TVar::operator==(const TVar& Other) const { + return Id == Other.Id; +} + +bool TArrow::operator==(const TArrow& Other) const { + return *ParamType == *Other.ParamType + && *ReturnType == *Other.ReturnType; +} + +bool TTuple::operator==(const TTuple& Other) const { + for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) { + if (*T1 != *T2) { return false; } - switch (Kind) { - case TypeIndexKind::ArrowParamType: - case TypeIndexKind::TupleElement: - return I == Other.I; - default: - return true; + } + return true; +} + +bool TNil::operator==(const TNil& Other) const { + return true; +} + +bool TField::operator==(const TField& Other) const { + return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy; +} + +bool TAbsent::operator==(const TAbsent& Other) const { + return true; +} + +bool TPresent::operator==(const TPresent& Other) const { + return *Ty == *Other.Ty; +} + +bool Type::operator==(const Type& Other) const { + if (Kind != Other.Kind) { + return false; + } + switch (Kind) { + case TypeKind::Var: + return Var == Other.Var; + case TypeKind::Con: + return Con == Other.Con; + case TypeKind::Present: + return Present == Other.Present; + case TypeKind::Absent: + return Absent == Other.Absent; + case TypeKind::Arrow: + return Arrow == Other.Arrow; + case TypeKind::Field: + return Field == Other.Field; + case TypeKind::Nil: + return Nil == Other.Nil; + case TypeKind::Tuple: + return Tuple == Other.Tuple; + case TypeKind::App: + return App == Other.App; + } + ZEN_UNREACHABLE +} + +void Type::visitEachChild(std::function Proc) { + switch (Kind) { + case TypeKind::Var: + case TypeKind::Absent: + case TypeKind::Nil: + case TypeKind::Con: + break; + case TypeKind::Arrow: + { + Proc(Arrow.ParamType); + Proc(Arrow.ReturnType); + break; + } + case TypeKind::Tuple: + { + for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) { + Proc(Tuple.ElementTypes[I]); + } + break; + } + case TypeKind::App: + { + Proc(App.Op); + Proc(App.Arg); + break; + } + case TypeKind::Field: + { + Proc(Field.Ty); + Proc(Field.RestTy); + break; + } + case TypeKind::Present: + { + Proc(Present.Ty); + break; } } +} - bool TCon::operator==(const TCon& Other) const { - return Id == Other.Id; - } - - bool TApp::operator==(const TApp& Other) const { - return *Op == *Other.Op && *Arg == *Other.Arg; - } - - bool TVar::operator==(const TVar& Other) const { - return Id == Other.Id; - } - - bool TArrow::operator==(const TArrow& Other) const { - return *ParamType == *Other.ParamType - && *ReturnType == *Other.ReturnType; - } - - bool TTuple::operator==(const TTuple& Other) const { - for (auto [T1, T2]: zen::zip(ElementTypes, Other.ElementTypes)) { - if (*T1 != *T2) { - return false; - } +Type* Type::rewrite(std::function Fn, bool Recursive) { + auto Ty2 = Fn(this); + if (this != Ty2) { + if (Recursive) { + return Ty2->rewrite(Fn, Recursive); } - return true; + return Ty2; } - - bool TNil::operator==(const TNil& Other) const { - return true; - } - - bool TField::operator==(const TField& Other) const { - return Name == Other.Name && *Ty == *Other.Ty && *RestTy == *Other.RestTy; - } - - bool TAbsent::operator==(const TAbsent& Other) const { - return true; - } - - bool TPresent::operator==(const TPresent& Other) const { - return *Ty == *Other.Ty; - } - - bool Type::operator==(const Type& Other) const { - if (Kind != Other.Kind) { - return false; - } - switch (Kind) { - case TypeKind::Var: - return Var == Other.Var; - case TypeKind::Con: - return Con == Other.Con; - case TypeKind::Present: - return Present == Other.Present; - case TypeKind::Absent: - return Absent == Other.Absent; - case TypeKind::Arrow: - return Arrow == Other.Arrow; - case TypeKind::Field: - return Field == Other.Field; - case TypeKind::Nil: - return Nil == Other.Nil; - case TypeKind::Tuple: - return Tuple == Other.Tuple; - case TypeKind::App: - return App == Other.App; - } - ZEN_UNREACHABLE - } - - void Type::visitEachChild(std::function Proc) { - switch (Kind) { - case TypeKind::Var: - case TypeKind::Absent: - case TypeKind::Nil: - case TypeKind::Con: - break; - case TypeKind::Arrow: - { - Proc(Arrow.ParamType); - Proc(Arrow.ReturnType); - break; - } - case TypeKind::Tuple: - { - for (auto I = 0; I < Tuple.ElementTypes.size(); ++I) { - Proc(Tuple.ElementTypes[I]); - } - break; - } - case TypeKind::App: - { - Proc(App.Op); - Proc(App.Arg); - break; - } - case TypeKind::Field: - { - Proc(Field.Ty); - Proc(Field.RestTy); - break; - } - case TypeKind::Present: - { - Proc(Present.Ty); - break; - } - } - } - - Type* Type::rewrite(std::function Fn, bool Recursive) { - auto Ty2 = Fn(this); - if (this != Ty2) { - if (Recursive) { - return Ty2->rewrite(Fn, Recursive); - } + switch (Kind) { + case TypeKind::Var: return Ty2; + case TypeKind::Arrow: + { + auto Arrow = Ty2->asArrow(); + bool Changed = false; + Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive); + if (NewParamType != Arrow.ParamType) { + Changed = true; + } + auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive); + if (NewRetTy != Arrow.ReturnType) { + Changed = true; + } + return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2; } - switch (Kind) { - case TypeKind::Var: + case TypeKind::Con: + return Ty2; + case TypeKind::App: + { + auto App = Ty2->asApp(); + auto NewOp = App.Op->rewrite(Fn, Recursive); + auto NewArg = App.Arg->rewrite(Fn, Recursive); + if (NewOp == App.Op && NewArg == App.Arg) { return Ty2; - case TypeKind::Arrow: - { - auto Arrow = Ty2->asArrow(); - bool Changed = false; - Type* NewParamType = Arrow.ParamType->rewrite(Fn, Recursive); - if (NewParamType != Arrow.ParamType) { - Changed = true; - } - auto NewRetTy = Arrow.ReturnType->rewrite(Fn, Recursive); - if (NewRetTy != Arrow.ReturnType) { - Changed = true; - } - return Changed ? new Type(TArrow(NewParamType, NewRetTy)) : Ty2; - } - case TypeKind::Con: - return Ty2; - case TypeKind::App: - { - auto App = Ty2->asApp(); - auto NewOp = App.Op->rewrite(Fn, Recursive); - auto NewArg = App.Arg->rewrite(Fn, Recursive); - if (NewOp == App.Op && NewArg == App.Arg) { - return Ty2; - } - return new Type(TApp(NewOp, NewArg)); - } - case TypeKind::Tuple: - { - auto Tuple = Ty2->asTuple(); - bool Changed = false; - std::vector NewElementTypes; - for (auto Ty: Tuple.ElementTypes) { - auto NewElementType = Ty->rewrite(Fn, Recursive); - if (NewElementType != Ty) { - Changed = true; - } - NewElementTypes.push_back(NewElementType); - } - return Changed ? new Type(TTuple(NewElementTypes)) : Ty2; - } - case TypeKind::Nil: - return Ty2; - case TypeKind::Absent: - return Ty2; - case TypeKind::Field: - { - auto Field = Ty2->asField(); - bool Changed = false; - auto NewTy = Field.Ty->rewrite(Fn, Recursive); - if (NewTy != Field.Ty) { - Changed = true; - } - auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive); - if (NewRestTy != Field.RestTy) { - Changed = true; - } - return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2; - } - case TypeKind::Present: - { - auto Present = Ty2->asPresent(); - auto NewTy = Present.Ty->rewrite(Fn, Recursive); - if (NewTy == Present.Ty) { - return Ty2; - } - return new Type(TPresent(NewTy)); } + return new Type(TApp(NewOp, NewArg)); } - ZEN_UNREACHABLE - } - - Type* Type::substitute(const TVSub &Sub) { - return rewrite([&](auto Ty) { - if (Ty->isVar()) { - auto Match = Sub.find(Ty); - return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; - } - return Ty; - }, false); - } - - Type* Type::resolve(const TypeIndex& Index) const noexcept { - switch (Index.Kind) { - case TypeIndexKind::PresentType: - return this->asPresent().Ty; - case TypeIndexKind::AppOpType: - return this->asApp().Op; - case TypeIndexKind::AppArgType: - return this->asApp().Arg; - case TypeIndexKind::TupleElement: - return this->asTuple().ElementTypes[Index.I]; - case TypeIndexKind::ArrowParamType: - return this->asArrow().ParamType; - case TypeIndexKind::ArrowReturnType: - return this->asArrow().ReturnType; - case TypeIndexKind::FieldType: - return this->asField().Ty; - case TypeIndexKind::FieldRestType: - return this->asField().RestTy; - case TypeIndexKind::End: - ZEN_UNREACHABLE - } - ZEN_UNREACHABLE - } - - TVSet Type::getTypeVars() { - TVSet Out; - std::function visit = [&](Type* Ty) { - if (Ty->isVar()) { - Out.emplace(Ty); - return; - } - Ty->visitEachChild(visit); - }; - visit(this); - return Out; - } - - TypeIterator Type::begin() { - return TypeIterator { this, getStartIndex() }; - } - - TypeIterator Type::end() { - return TypeIterator { this, getEndIndex() }; - } - - TypeIndex Type::getStartIndex() const { - switch (Kind) { - case TypeKind::Arrow: - return TypeIndex::forArrowParamType(); - case TypeKind::Tuple: - { - if (asTuple().ElementTypes.empty()) { - return TypeIndex(TypeIndexKind::End); + case TypeKind::Tuple: + { + auto Tuple = Ty2->asTuple(); + bool Changed = false; + std::vector NewElementTypes; + for (auto Ty: Tuple.ElementTypes) { + auto NewElementType = Ty->rewrite(Fn, Recursive); + if (NewElementType != Ty) { + Changed = true; } - return TypeIndex::forTupleElement(0); + NewElementTypes.push_back(NewElementType); } - case TypeKind::Field: - return TypeIndex::forFieldType(); - default: + return Changed ? new Type(TTuple(NewElementTypes)) : Ty2; + } + case TypeKind::Nil: + return Ty2; + case TypeKind::Absent: + return Ty2; + case TypeKind::Field: + { + auto Field = Ty2->asField(); + bool Changed = false; + auto NewTy = Field.Ty->rewrite(Fn, Recursive); + if (NewTy != Field.Ty) { + Changed = true; + } + auto NewRestTy = Field.RestTy->rewrite(Fn, Recursive); + if (NewRestTy != Field.RestTy) { + Changed = true; + } + return Changed ? new Type(TField(Field.Name, NewTy, NewRestTy)) : Ty2; + } + case TypeKind::Present: + { + auto Present = Ty2->asPresent(); + auto NewTy = Present.Ty->rewrite(Fn, Recursive); + if (NewTy == Present.Ty) { + return Ty2; + } + return new Type(TPresent(NewTy)); + } + } + ZEN_UNREACHABLE +} + +Type* Type::substitute(const TVSub &Sub) { + return rewrite([&](auto Ty) { + if (Ty->isVar()) { + auto Match = Sub.find(Ty); + return Match != Sub.end() ? Match->second->substitute(Sub) : Ty; + } + return Ty; + }, false); +} + +Type* Type::resolve(const TypeIndex& Index) const noexcept { + switch (Index.Kind) { + case TypeIndexKind::PresentType: + return this->asPresent().Ty; + case TypeIndexKind::AppOpType: + return this->asApp().Op; + case TypeIndexKind::AppArgType: + return this->asApp().Arg; + case TypeIndexKind::TupleElement: + return this->asTuple().ElementTypes[Index.I]; + case TypeIndexKind::ArrowParamType: + return this->asArrow().ParamType; + case TypeIndexKind::ArrowReturnType: + return this->asArrow().ReturnType; + case TypeIndexKind::FieldType: + return this->asField().Ty; + case TypeIndexKind::FieldRestType: + return this->asField().RestTy; + case TypeIndexKind::End: + ZEN_UNREACHABLE + } + ZEN_UNREACHABLE +} + +TVSet Type::getTypeVars() { + TVSet Out; + std::function visit = [&](Type* Ty) { + if (Ty->isVar()) { + Out.emplace(Ty); + return; + } + Ty->visitEachChild(visit); + }; + visit(this); + return Out; +} + +TypeIterator Type::begin() { + return TypeIterator { this, getStartIndex() }; +} + +TypeIterator Type::end() { + return TypeIterator { this, getEndIndex() }; +} + +TypeIndex Type::getStartIndex() const { + switch (Kind) { + case TypeKind::Arrow: + return TypeIndex::forArrowParamType(); + case TypeKind::Tuple: + { + if (asTuple().ElementTypes.empty()) { return TypeIndex(TypeIndexKind::End); + } + return TypeIndex::forTupleElement(0); } + case TypeKind::Field: + return TypeIndex::forFieldType(); + default: + return TypeIndex(TypeIndexKind::End); } +} - TypeIndex Type::getEndIndex() const { - return TypeIndex(TypeIndexKind::End); - } +TypeIndex Type::getEndIndex() const { + return TypeIndex(TypeIndexKind::End); +} - bool Type::hasTypeVar(Type* TV) const { - switch (Kind) { - case TypeKind::Var: - return Var.Id == TV->asVar().Id; - case TypeKind::Con: - case TypeKind::Absent: - case TypeKind::Nil: - return false; - case TypeKind::App: - return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV); - case TypeKind::Tuple: - for (auto Ty: Tuple.ElementTypes) { - if (Ty->hasTypeVar(TV)) { - return true; - } +bool Type::hasTypeVar(Type* TV) const { + switch (Kind) { + case TypeKind::Var: + return Var.Id == TV->asVar().Id; + case TypeKind::Con: + case TypeKind::Absent: + case TypeKind::Nil: + return false; + case TypeKind::App: + return App.Op->hasTypeVar(TV) || App.Arg->hasTypeVar(TV); + case TypeKind::Tuple: + for (auto Ty: Tuple.ElementTypes) { + if (Ty->hasTypeVar(TV)) { + return true; } - return false; - case TypeKind::Field: - return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV); - case TypeKind::Arrow: - return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV); - case TypeKind::Present: - return Present.Ty->hasTypeVar(TV); - } - ZEN_UNREACHABLE + } + return false; + case TypeKind::Field: + return Field.Ty->hasTypeVar(TV) || Field.RestTy->hasTypeVar(TV); + case TypeKind::Arrow: + return Arrow.ParamType->hasTypeVar(TV) || Arrow.ReturnType->hasTypeVar(TV); + case TypeKind::Present: + return Present.Ty->hasTypeVar(TV); } + ZEN_UNREACHABLE +} }