From 719dbfcad44da2c857657f551f8778156c67c13d Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sat, 9 Mar 2024 12:51:35 +0100 Subject: [PATCH] Split LetDeclaration into VariableDeclaration and FunctionDeclaration We only generate VariableDeclaration when we're absolutely sure it is a variable. --- bootstrap/cxx/CMakeLists.txt | 22 +- bootstrap/cxx/include/bolt/CST.hpp | 648 +++++++++++++++++----- bootstrap/cxx/include/bolt/CSTVisitor.hpp | 167 +++++- bootstrap/cxx/include/bolt/Checker.hpp | 4 +- bootstrap/cxx/include/bolt/Common.hpp | 2 + bootstrap/cxx/include/bolt/Evaluator.hpp | 10 +- bootstrap/cxx/include/bolt/Parser.hpp | 2 +- bootstrap/cxx/src/CST.cc | 406 ++++++-------- bootstrap/cxx/src/Checker.cc | 253 +++++---- bootstrap/cxx/src/ConsolePrinter.cc | 7 +- bootstrap/cxx/src/Evaluator.cc | 45 +- bootstrap/cxx/src/Parser.cc | 154 +++-- 12 files changed, 1156 insertions(+), 564 deletions(-) diff --git a/bootstrap/cxx/CMakeLists.txt b/bootstrap/cxx/CMakeLists.txt index 5ae4b36ed..b2795fff4 100644 --- a/bootstrap/cxx/CMakeLists.txt +++ b/bootstrap/cxx/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.10) -project(Bolt CXX) +project(Bolt C CXX) set(CMAKE_CXX_STANDARD 20) @@ -17,6 +17,8 @@ if (CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo" OR CMAKE_BUILD_TYPE STREQUAL "Deb set(BOLT_DEBUG ON) endif() +find_package(LLVM 18.1.0 REQUIRED) + add_library( BoltCore #src/Text.cc @@ -28,6 +30,7 @@ add_library( src/Types.cc src/Checker.cc src/Evaluator.cc + src/Scope.cc ) target_link_directories( BoltCore @@ -61,6 +64,22 @@ target_link_libraries( icuuc ) +add_library( + BoltLLVM + src/LLVMCodeGen.cc +) +llvm_map_components_to_libnames(llvm_libs support core irreader) +target_include_directories(BoltLLVM PRIVATE ${LLVM_INCLUDE_DIRS}) +separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS}) +target_compile_definitions(BoltLLVM PRIVATE ${LLVM_DEFINITIONS_LIST}) + +target_link_libraries( + BoltLLVM + PUBLIC + BoltCore + ${llvm_libs} +) + 
add_executable( bolt src/main.cc @@ -69,6 +88,7 @@ target_link_libraries( bolt PUBLIC BoltCore + BoltLLVM ) if (BOLT_ENABLE_TESTS) diff --git a/bootstrap/cxx/include/bolt/CST.hpp b/bootstrap/cxx/include/bolt/CST.hpp index 4f9295176..4a4a15f43 100644 --- a/bootstrap/cxx/include/bolt/CST.hpp +++ b/bootstrap/cxx/include/bolt/CST.hpp @@ -1,11 +1,13 @@ #ifndef BOLT_CST_HPP #define BOLT_CST_HPP +#include #include #include #include #include +#include "bolt/Common.hpp" #include "zen/config.hpp" #include "bolt/Integer.hpp" @@ -172,7 +174,11 @@ enum class NodeKind { Parameter, LetBlockBody, LetExprBody, - LetDeclaration, + PrefixFunctionDeclaration, + InfixFunctionDeclaration, + SuffixFunctionDeclaration, + NamedFunctionDeclaration, + VariableDeclaration, RecordDeclarationField, RecordDeclaration, VariantDeclaration, @@ -364,31 +370,6 @@ public: }; -/// Any node that can be used as an operator -/// -/// This includes the following nodes: -/// - VBar -/// - CustomOperator -using Operator = Token; - -/// Any node that can be used as a kind of identifier. -/// -/// This includes the following nodes: -/// - Identifier -/// - IdentifierAlt -/// - WrappedOperator -using Symbol = Node; - -inline bool isSymbol(const Node* N) { - return N->getKind() == NodeKind::Identifier - || N->getKind() == NodeKind::IdentifierAlt - || N->getKind() == NodeKind::WrappedOperator; -} - -/// Get the text that is actually represented by a symbol, without all the -/// syntactic sugar. 
-ByteString getCanonicalText(const Symbol* N); - class Equals : public Token { public: @@ -903,6 +884,8 @@ public: std::string getText() const override; + std::string getCanonicalText() const; + static bool classof(const Node* N) { return N->getKind() == NodeKind::CustomOperator; } @@ -935,6 +918,8 @@ public: std::string getText() const override; + ByteString getCanonicalText() const; + bool isTypeVar() const; static bool classof(const Node* N) { @@ -953,6 +938,8 @@ public: std::string getText() const override; + ByteString getCanonicalText() const; + static bool classof(const Node* N) { return N->getKind() == NodeKind::IdentifierAlt; } @@ -1021,6 +1008,182 @@ public: }; +/// Base node for things that can be used as an operator +/// +/// This includes the following nodes: +/// - VBar +/// - CustomOperator +class Operator { + + Node* N; + + Operator(Node* N): + N(N) {} + +public: + + Operator(): N(nullptr) {} /* an empty wrapper holds no node instead of an indeterminate pointer */ + + Operator(VBar* N): + N(N) {} + + Operator(CustomOperator* N): + N(N) {} + + static Operator from_raw_node(Node* N) { + ZEN_ASSERT(isa(N)); + return N; + } + + inline NodeKind getKind() const { + return N->getKind(); + } + + inline bool isVBar() const { + return N->getKind() == NodeKind::VBar; + } + + inline bool isCustomOperator() const { + return N->getKind() == NodeKind::CustomOperator; + } + + VBar* asVBar() const { + return static_cast(N); + } + + CustomOperator* asCustomOperator() const { + return static_cast(N); + } + + operator Node*() const { + return N; + } + + /// Get the name that is actually represented by an operator, without all the + /// syntactic sugar. 
+ ByteString getCanonicalText() const; + + Token* getFirstToken() const; + Token* getLastToken() const; + + inline static bool classof(const Node* N) { + return N->getKind() == NodeKind::VBar + || N->getKind() == NodeKind::CustomOperator; + } + +}; +class WrappedOperator : public Node { +public: + + class LParen* LParen; + Operator Op; + class RParen* RParen; + + WrappedOperator( + class LParen* LParen, + Operator Operator, + class RParen* RParen + ): Node(NodeKind::WrappedOperator), + LParen(LParen), + Op(Operator), + RParen(RParen) {} + + inline Operator getOperator() const { + return Op; + } + + ByteString getCanonicalText() const { + return Op.getCanonicalText(); + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::WrappedOperator; + } + +}; + +/// Base node for things that can be used as a symbol +/// +/// This includes the following nodes: +/// - WrappedOperator +/// - Identifier +/// - IdentifierAlt +class Symbol { + + Node* N; + + Symbol(Node* N): + N(N) {} + +public: + + Symbol(): N(nullptr) {} /* an empty wrapper holds no node instead of an indeterminate pointer */ + + Symbol(WrappedOperator* N): + N(N) {} + + Symbol(Identifier* N): + N(N) {} + + Symbol(IdentifierAlt* N): + N(N) {} + + static Symbol from_raw_node(Node* N) { + ZEN_ASSERT(isa(N)); + return N; + } + + NodeKind getKind() const { + return N->getKind(); + } + + bool isWrappedOperator() const { + return N->getKind() == NodeKind::WrappedOperator; + } + + bool isIdentifier() const { + return N->getKind() == NodeKind::Identifier; + } + + bool isIdentifierAlt() const { + return N->getKind() == NodeKind::IdentifierAlt; + } + + IdentifierAlt* asIdentifierAlt() const { + return cast(N); + } + + Identifier* asIdentifier() const { + return cast(N); + } + + WrappedOperator* asWrappedOperator() const { + return cast(N); + } + + operator Node*() const { + return N; + } + + /// Get the name that is actually represented by a symbol, without all the + /// syntactic sugar. 
+ ByteString getCanonicalText() const; + + Token* getFirstToken() const; + Token* getLastToken() const; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Identifier + || N->getKind() == NodeKind::IdentifierAlt + || N->getKind() == NodeKind::WrappedOperator; + } + +}; + + class Annotation : public Node { public: @@ -1353,31 +1516,6 @@ public: }; -class WrappedOperator : public Symbol { -public: - - class LParen* LParen; - Token* Op; - class RParen* RParen; - - WrappedOperator( - class LParen* LParen, - Token* Operator, - class RParen* RParen - ): Symbol(NodeKind::WrappedOperator), - LParen(LParen), - Op(Operator), - RParen(RParen) {} - - inline Token* getOperator() const { - return Op; - } - - Token* getFirstToken() const override; - Token* getLastToken() const override; - -}; - class Pattern : public Node { protected: @@ -1390,10 +1528,10 @@ protected: class BindPattern : public Pattern { public: - Symbol* Name; + Identifier* Name; BindPattern( - Symbol* Name + Identifier* Name ): Pattern(NodeKind::BindPattern), Name(Name) {} @@ -1608,11 +1746,11 @@ class ReferenceExpression : public Expression { public: std::vector> ModulePath; - Symbol* Name; + Symbol Name; inline ReferenceExpression( std::vector> ModulePath, - Symbol* Name + Symbol Name ): Expression(NodeKind::ReferenceExpression), ModulePath(ModulePath), Name(Name) {} @@ -1620,13 +1758,13 @@ public: inline ReferenceExpression( std::vector Annotations, std::vector> ModulePath, - Symbol* Name + Symbol Name ): Expression(NodeKind::ReferenceExpression, Annotations), ModulePath(ModulePath), Name(Name) {} inline ByteString getNameAsString() const noexcept { - return getCanonicalText(Name); + return Name.getCanonicalText(); } Token* getFirstToken() const override; @@ -1864,12 +2002,12 @@ class InfixExpression : public Expression { public: Expression* Left; - Token* Operator; + class Operator Operator; /* elaborated specifier: the member shares its type's name, which GCC rejects otherwise */ Expression* Right; inline InfixExpression( Expression* Left, - Token* Operator, + class Operator 
Operator, Expression* Right ): Expression(NodeKind::InfixExpression), Left(Left), @@ -1879,7 +2017,7 @@ public: inline InfixExpression( std::vector Annotations, Expression* Left, - Token* Operator, + class Operator Operator, Expression* Right ): Expression(NodeKind::InfixExpression, Annotations), Left(Left), @@ -2110,6 +2248,10 @@ public: Token* getFirstToken() const override; Token* getLastToken() const override; + static bool classof(const Node* N) { + return N->getKind() == NodeKind::Parameter; + } + }; class LetBody : public Node { @@ -2155,7 +2297,7 @@ public: }; -class LetDeclaration : public TypedNode, public AnnotationContainer { +class FunctionDeclaration : public TypedNode, public AnnotationContainer { Scope* TheScope = nullptr; @@ -2165,63 +2307,34 @@ public: bool Visited = false; InferContext* Ctx; - class PubKeyword* PubKeyword; - class ForeignKeyword* ForeignKeyword; - class LetKeyword* LetKeyword; - class MutKeyword* MutKeyword; - class Pattern* Pattern; - std::vector Params; - class TypeAssert* TypeAssert; - LetBody* Body; + FunctionDeclaration(NodeKind Kind, std::vector Annotations = {}): + TypedNode(Kind), AnnotationContainer(Annotations) {} - LetDeclaration( - class PubKeyword* PubKeyword, - class ForeignKeyword* ForeignKeyword, - class LetKeyword* LetKeyword, - class MutKeyword* MutKeyword, - class Pattern* Pattern, - std::vector Params, - class TypeAssert* TypeAssert, - LetBody* Body - ): TypedNode(NodeKind::LetDeclaration), - PubKeyword(PubKeyword), - ForeignKeyword(ForeignKeyword), - LetKeyword(LetKeyword), - MutKeyword(MutKeyword), - Pattern(Pattern), - Params(Params), - TypeAssert(TypeAssert), - Body(Body) {} + virtual bool isPublic() const = 0; - LetDeclaration( - std::vector Annotations, - class PubKeyword* PubKeyword, - class ForeignKeyword* ForeignKeyword, - class LetKeyword* LetKeyword, - class MutKeyword* MutKeyword, - class Pattern* Pattern, - std::vector Params, - class TypeAssert* TypeAssert, - LetBody* Body - ): 
TypedNode(NodeKind::LetDeclaration), - AnnotationContainer(Annotations), - PubKeyword(PubKeyword), - ForeignKeyword(ForeignKeyword), - LetKeyword(LetKeyword), - MutKeyword(MutKeyword), - Pattern(Pattern), - Params(Params), - TypeAssert(TypeAssert), - Body(Body) {} + virtual bool isForeign() const = 0; + + virtual ByteString getNameAsString() const = 0; + + virtual std::vector getParams() const = 0; + + virtual TypeAssert* getTypeAssert() const = 0; + + bool hasTypeAssert() const { + return getTypeAssert(); + } + + virtual LetBody* getBody() const = 0; + + bool hasBody() const { + return getBody(); + } inline Scope* getScope() override { - if (isFunction()) { - if (TheScope == nullptr) { - TheScope = new Scope(this); - } - return TheScope; + if (TheScope == nullptr) { + TheScope = new Scope(this); } - return Parent->getScope(); + return TheScope; } bool isInstance() const noexcept { @@ -2232,33 +2345,310 @@ public: return Parent->getKind() == NodeKind::ClassDeclaration; } - bool isSignature() const noexcept { - return ForeignKeyword; + static bool classof(const Node* N) { + return N->getKind() == NodeKind::PrefixFunctionDeclaration + || N->getKind() == NodeKind::InfixFunctionDeclaration + || N->getKind() == NodeKind::SuffixFunctionDeclaration + || N->getKind() == NodeKind::NamedFunctionDeclaration; } - bool isVariable() const noexcept { - // Variables in classes and instances are never possible, so we reflect this by excluding them here. 
- return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body); +}; + +class PrefixFunctionDeclaration : public FunctionDeclaration { +public: + + class PubKeyword* PubKeyword; + class ForeignKeyword* ForeignKeyword; + class LetKeyword* LetKeyword; + class Operator Name; + Parameter* Param; + class TypeAssert* TypeAssert; + LetBody* Body; + + PrefixFunctionDeclaration( + class std::vector Annotations, + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + Operator Name, + Parameter* Param, + class TypeAssert* TypeAssert, + LetBody* Body + ): FunctionDeclaration(NodeKind::PrefixFunctionDeclaration, Annotations), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), + Name(Name), + Param(Param), + TypeAssert(TypeAssert), + Body(Body) {} + + bool isPublic() const override { + return PubKeyword != nullptr; } - bool isFunction() const noexcept { - return !isSignature() && !isVariable(); + bool isForeign() const override { + return ForeignKeyword != nullptr; } - Symbol* getName() const noexcept { - ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern); - return static_cast(Pattern)->Name; + ByteString getNameAsString() const override { + return Name.getCanonicalText(); } - ByteString getNameAsString() const noexcept { - return getCanonicalText(getName()); + std::vector getParams() const override { + return { Param }; + } + + class TypeAssert* getTypeAssert() const override { + return TypeAssert; + } + + LetBody* getBody() const override { + return Body; } Token* getFirstToken() const override; Token* getLastToken() const override; static bool classof(const Node* N) { - return N->getKind() == NodeKind::LetDeclaration; + return N->getKind() == NodeKind::PrefixFunctionDeclaration; + } + +}; + +class SuffixFunctionDeclaration : public FunctionDeclaration { +public: + + class PubKeyword* PubKeyword; + class ForeignKeyword* 
ForeignKeyword; + class LetKeyword* LetKeyword; + Parameter* Param; + class Operator Name; + class TypeAssert* TypeAssert; + LetBody* Body; + + SuffixFunctionDeclaration( + class std::vector Annotations, + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + Parameter* Param, + Operator Name, + class TypeAssert* TypeAssert, + LetBody* Body + ): FunctionDeclaration(NodeKind::SuffixFunctionDeclaration, Annotations), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), + Param(Param), /* initializers listed in member declaration order (Param, then Name) */ + Name(Name), + TypeAssert(TypeAssert), + Body(Body) {} + + bool isPublic() const override { + return PubKeyword != nullptr; + } + + bool isForeign() const override { + return ForeignKeyword != nullptr; + } + + ByteString getNameAsString() const override { + return Name.getCanonicalText(); + } + + std::vector getParams() const override { + return { Param }; + } + + class TypeAssert* getTypeAssert() const override { + return TypeAssert; + } + + LetBody* getBody() const override { + return Body; + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::SuffixFunctionDeclaration; + } + +}; + +class InfixFunctionDeclaration : public FunctionDeclaration { +public: + + class PubKeyword* PubKeyword; + class ForeignKeyword* ForeignKeyword; + class LetKeyword* LetKeyword; + Parameter* Left; + class Operator Name; + Parameter* Right; + class TypeAssert* TypeAssert; + LetBody* Body; + + InfixFunctionDeclaration( + class std::vector Annotations, + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + Parameter* Left, + class Operator Name, + Parameter* Right, + class TypeAssert* TypeAssert, + LetBody* Body + ): FunctionDeclaration(NodeKind::InfixFunctionDeclaration, Annotations), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), 
+ Left(Left), + Name(Name), + Right(Right), + TypeAssert(TypeAssert), + Body(Body) {} + + bool isPublic() const override { + return PubKeyword != nullptr; + } + + bool isForeign() const override { + return ForeignKeyword != nullptr; + } + + ByteString getNameAsString() const override { + return Name.getCanonicalText(); + } + + std::vector getParams() const override { + return { Left, Right }; + } + + class TypeAssert* getTypeAssert() const override { + return TypeAssert; + } + + LetBody* getBody() const override { + return Body; + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::InfixFunctionDeclaration; + } + +}; + +class NamedFunctionDeclaration : public FunctionDeclaration { +public: + + class PubKeyword* PubKeyword; + class ForeignKeyword* ForeignKeyword; + class LetKeyword* LetKeyword; + class Symbol Name; + std::vector Params; + class TypeAssert* TypeAssert; + LetBody* Body; + + NamedFunctionDeclaration( + class std::vector Annotations, + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + class Symbol Name, + std::vector Params, + class TypeAssert* TypeAssert, + LetBody* Body + ): FunctionDeclaration(NodeKind::NamedFunctionDeclaration, Annotations), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), + Name(Name), + Params(Params), + TypeAssert(TypeAssert), + Body(Body) {} + + bool isPublic() const override { + return PubKeyword != nullptr; + } + + bool isForeign() const override { + return ForeignKeyword != nullptr; + } + + ByteString getNameAsString() const override { + return Name.getCanonicalText(); + } + + std::vector getParams() const override { + return Params; + } + + class TypeAssert* getTypeAssert() const override { + return TypeAssert; + } + + LetBody* getBody() const override { + return Body; + } + Token* getFirstToken() const override; + Token* 
getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::NamedFunctionDeclaration; + } + +}; + +class VariableDeclaration : public TypedNode, public AnnotationContainer { +public: + + class PubKeyword* PubKeyword; + class ForeignKeyword* ForeignKeyword; + class LetKeyword* LetKeyword; + class MutKeyword* MutKeyword; + class Pattern* Pattern; + std::vector Params; + class TypeAssert* TypeAssert; + LetBody* Body; + + VariableDeclaration( + class std::vector Annotations, + class PubKeyword* PubKeyword, + class ForeignKeyword* ForeignKeyword, + class LetKeyword* LetKeyword, + class MutKeyword* MutKeyword, + class Pattern* Pattern, + class TypeAssert* TypeAssert, + LetBody* Body + ): TypedNode(NodeKind::VariableDeclaration), + AnnotationContainer(Annotations), + PubKeyword(PubKeyword), + ForeignKeyword(ForeignKeyword), + LetKeyword(LetKeyword), + MutKeyword(MutKeyword), + Pattern(Pattern), + TypeAssert(TypeAssert), + Body(Body) {} + + Symbol getName() const noexcept { + ZEN_ASSERT(Pattern->getKind() == NodeKind::BindPattern); + return static_cast(Pattern)->Name; + } + + ByteString getNameAsString() const noexcept { + return getName().getCanonicalText(); + } + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::VariableDeclaration; } }; @@ -2555,7 +2945,11 @@ template<> inline NodeKind getNodeType() { return NodeKind::TypeAsse template<> inline NodeKind getNodeType() { return NodeKind::Parameter; } template<> inline NodeKind getNodeType() { return NodeKind::LetBlockBody; } template<> inline NodeKind getNodeType() { return NodeKind::LetExprBody; } -template<> inline NodeKind getNodeType() { return NodeKind::LetDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::PrefixFunctionDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::InfixFunctionDeclaration; } +template<> inline 
NodeKind getNodeType() { return NodeKind::SuffixFunctionDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::NamedFunctionDeclaration; } +template<> inline NodeKind getNodeType() { return NodeKind::VariableDeclaration; } template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclarationField; } template<> inline NodeKind getNodeType() { return NodeKind::RecordDeclaration; } template<> inline NodeKind getNodeType() { return NodeKind::ClassDeclaration; } diff --git a/bootstrap/cxx/include/bolt/CSTVisitor.hpp b/bootstrap/cxx/include/bolt/CSTVisitor.hpp index 90228669d..3d3fd4ae1 100644 --- a/bootstrap/cxx/include/bolt/CSTVisitor.hpp +++ b/bootstrap/cxx/include/bolt/CSTVisitor.hpp @@ -15,8 +15,8 @@ public: void visit(Node* N) { #define BOLT_GEN_CASE(name) \ - case NodeKind::name: \ - return static_cast(this)->visit ## name(static_cast(N)); + case NodeKind::name: \ + return static_cast(this)->visit ## name(static_cast(N)); switch (N->getKind()) { BOLT_GEN_CASE(VBar) @@ -104,7 +104,11 @@ public: BOLT_GEN_CASE(Parameter) BOLT_GEN_CASE(LetBlockBody) BOLT_GEN_CASE(LetExprBody) - BOLT_GEN_CASE(LetDeclaration) + BOLT_GEN_CASE(PrefixFunctionDeclaration) + BOLT_GEN_CASE(InfixFunctionDeclaration) + BOLT_GEN_CASE(SuffixFunctionDeclaration) + BOLT_GEN_CASE(NamedFunctionDeclaration) + BOLT_GEN_CASE(VariableDeclaration) BOLT_GEN_CASE(RecordDeclaration) BOLT_GEN_CASE(RecordDeclarationField) BOLT_GEN_CASE(VariantDeclaration) @@ -116,6 +120,35 @@ public: } } + void dispatchSymbol(const Symbol& S) { + switch (S.getKind()) { + case NodeKind::Identifier: + visit(S.asIdentifier()); + break; + case NodeKind::IdentifierAlt: + visit(S.asIdentifierAlt()); + break; + case NodeKind::WrappedOperator: + visit(S.asWrappedOperator()); + break; + default: + ZEN_UNREACHABLE + } + } + + void dispatchOperator(const Operator& O) { + switch (O.getKind()) { + case NodeKind::VBar: + visit(O.asVBar()); + break; + case NodeKind::CustomOperator: + visit(O.asCustomOperator()); 
+ break; + default: + ZEN_UNREACHABLE + } + } + protected: void visitNode(Node* N) { @@ -494,7 +527,27 @@ protected: static_cast(this)->visitLetBody(N); } - void visitLetDeclaration(LetDeclaration* N) { + void visitFunctionDeclaration(FunctionDeclaration* N) { + static_cast(this)->visitNode(N); + } + + void visitPrefixFunctionDeclaration(PrefixFunctionDeclaration* N) { + static_cast(this)->visitFunctionDeclaration(N); + } + + void visitInfixFunctionDeclaration(InfixFunctionDeclaration* N) { + static_cast(this)->visitFunctionDeclaration(N); + } + + void visitSuffixFunctionDeclaration(SuffixFunctionDeclaration* N) { + static_cast(this)->visitFunctionDeclaration(N); + } + + void visitNamedFunctionDeclaration(NamedFunctionDeclaration* N) { + static_cast(this)->visitFunctionDeclaration(N); + } + + void visitVariableDeclaration(VariableDeclaration* N) { static_cast(this)->visitNode(N); } @@ -629,7 +682,11 @@ public: BOLT_GEN_CHILD_CASE(Parameter) BOLT_GEN_CHILD_CASE(LetBlockBody) BOLT_GEN_CHILD_CASE(LetExprBody) - BOLT_GEN_CHILD_CASE(LetDeclaration) + BOLT_GEN_CHILD_CASE(PrefixFunctionDeclaration) + BOLT_GEN_CHILD_CASE(InfixFunctionDeclaration) + BOLT_GEN_CHILD_CASE(SuffixFunctionDeclaration) + BOLT_GEN_CHILD_CASE(NamedFunctionDeclaration) + BOLT_GEN_CHILD_CASE(VariableDeclaration) BOLT_GEN_CHILD_CASE(RecordDeclaration) BOLT_GEN_CHILD_CASE(RecordDeclarationField) BOLT_GEN_CHILD_CASE(VariantDeclaration) @@ -642,6 +699,8 @@ public: } #define BOLT_VISIT(node) static_cast(this)->visit(node) +#define BOLT_VISIT_SYMBOL(node) static_cast(this)->dispatchSymbol(node) +#define BOLT_VISIT_OPERATOR(node) static_cast(this)->dispatchOperator(node) void visitEachChild(VBar* N) { } @@ -771,7 +830,7 @@ public: void visitEachChild(WrappedOperator* N) { BOLT_VISIT(N->LParen); - BOLT_VISIT(N->Op); + BOLT_VISIT_OPERATOR(N->Op); BOLT_VISIT(N->RParen); } @@ -972,7 +1031,7 @@ public: BOLT_VISIT(Name); BOLT_VISIT(Dot); } - BOLT_VISIT(N->Name); + BOLT_VISIT_SYMBOL(N->Name); } void 
visitEachChild(MatchCase* N) { @@ -1049,7 +1108,7 @@ public: BOLT_VISIT(A); } BOLT_VISIT(N->Left); - BOLT_VISIT(N->Operator); + BOLT_VISIT_OPERATOR(N->Operator); BOLT_VISIT(N->Right); } @@ -1140,7 +1199,7 @@ public: BOLT_VISIT(N->Expression); } - void visitEachChild(LetDeclaration* N) { + void visitEachChild(PrefixFunctionDeclaration* N) { for (auto A: N->Annotations) { BOLT_VISIT(A); } @@ -1151,6 +1210,96 @@ public: BOLT_VISIT(N->ForeignKeyword); } BOLT_VISIT(N->LetKeyword); + BOLT_VISIT_OPERATOR(N->Name); /* prefix form: the operator precedes its operand */ + BOLT_VISIT(N->Param); + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); + } + if (N->Body) { + BOLT_VISIT(N->Body); + } + } + + void visitEachChild(InfixFunctionDeclaration* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + if (N->ForeignKeyword) { + BOLT_VISIT(N->ForeignKeyword); + } + BOLT_VISIT(N->LetKeyword); + BOLT_VISIT(N->Left); + BOLT_VISIT_OPERATOR(N->Name); + BOLT_VISIT(N->Right); + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); + } + if (N->Body) { + BOLT_VISIT(N->Body); + } + } + + void visitEachChild(SuffixFunctionDeclaration* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + if (N->ForeignKeyword) { + BOLT_VISIT(N->ForeignKeyword); + } + BOLT_VISIT(N->LetKeyword); + BOLT_VISIT(N->Param); /* suffix form: the operand precedes its operator */ + BOLT_VISIT_OPERATOR(N->Name); + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); + } + if (N->Body) { + BOLT_VISIT(N->Body); + } + } + + void visitEachChild(NamedFunctionDeclaration* N) { + for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + if (N->ForeignKeyword) { + BOLT_VISIT(N->ForeignKeyword); + } + BOLT_VISIT(N->LetKeyword); + BOLT_VISIT_SYMBOL(N->Name); + for (auto Param: N->Params) { + BOLT_VISIT(Param); + } + if (N->TypeAssert) { + BOLT_VISIT(N->TypeAssert); + } + if (N->Body) { + BOLT_VISIT(N->Body); + } + } + + void visitEachChild(VariableDeclaration* N) { 
+ for (auto A: N->Annotations) { + BOLT_VISIT(A); + } + if (N->PubKeyword) { + BOLT_VISIT(N->PubKeyword); + } + if (N->ForeignKeyword) { + BOLT_VISIT(N->ForeignKeyword); + } + BOLT_VISIT(N->LetKeyword); + if (N->MutKeyword) { + BOLT_VISIT(N->MutKeyword); + } BOLT_VISIT(N->Pattern); for (auto Param: N->Params) { BOLT_VISIT(Param); diff --git a/bootstrap/cxx/include/bolt/Checker.hpp b/bootstrap/cxx/include/bolt/Checker.hpp index 279a78034..9f10a22a2 100644 --- a/bootstrap/cxx/include/bolt/Checker.hpp +++ b/bootstrap/cxx/include/bolt/Checker.hpp @@ -241,7 +241,7 @@ class Checker { /// Type inference void forwardDeclare(Node* Node); - void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); + void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); Type* inferExpression(Expression* Expression); Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true); @@ -249,7 +249,7 @@ class Checker { Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); void infer(Node* node); - void inferFunctionDeclaration(LetDeclaration* N); + void inferFunctionDeclaration(FunctionDeclaration* N); void inferConstraintExpression(ConstraintExpression* C); /// Factory methods diff --git a/bootstrap/cxx/include/bolt/Common.hpp b/bootstrap/cxx/include/bolt/Common.hpp index 930ae4517..cecd29e2c 100644 --- a/bootstrap/cxx/include/bolt/Common.hpp +++ b/bootstrap/cxx/include/bolt/Common.hpp @@ -1,6 +1,8 @@ #pragma once +#include "zen/config.hpp" + namespace bolt { class LanguageConfig { diff --git a/bootstrap/cxx/include/bolt/Evaluator.hpp b/bootstrap/cxx/include/bolt/Evaluator.hpp index f984a5489..0614ed764 100644 --- a/bootstrap/cxx/include/bolt/Evaluator.hpp +++ b/bootstrap/cxx/include/bolt/Evaluator.hpp @@ -29,7 +29,7 @@ class Value { union { ByteString S; Integer I; - LetDeclaration* D; + FunctionDeclaration* D; NativeFunction F; Tuple T; }; @@ 
-45,7 +45,7 @@ public: Value(Integer I): Kind(ValueKind::Integer), I(I) {} - Value(LetDeclaration* D): + Value(FunctionDeclaration* D): Kind(ValueKind::SourceFunction), D(D) {} Value(NativeFunction F): @@ -67,7 +67,7 @@ public: new (&I) Tuple(V.T); break; case ValueKind::SourceFunction: - new (&D) LetDeclaration*(V.D); + new (&D) FunctionDeclaration*(V.D); break; case ValueKind::NativeFunction: new (&F) NativeFunction(V.F); @@ -90,7 +90,7 @@ public: new (&I) Tuple(Other.T); break; case ValueKind::SourceFunction: - new (&D) LetDeclaration*(Other.D); + new (&D) FunctionDeclaration*(Other.D); break; case ValueKind::NativeFunction: new (&F) NativeFunction(Other.F); @@ -112,7 +112,7 @@ public: return S; } - inline LetDeclaration* getDeclaration() { + inline FunctionDeclaration* getDeclaration() { ZEN_ASSERT(Kind == ValueKind::SourceFunction); return D; } diff --git a/bootstrap/cxx/include/bolt/Parser.hpp b/bootstrap/cxx/include/bolt/Parser.hpp index 478c425fc..9ddbf3d07 100644 --- a/bootstrap/cxx/include/bolt/Parser.hpp +++ b/bootstrap/cxx/include/bolt/Parser.hpp @@ -130,7 +130,7 @@ public: Node* parseLetBodyElement(); - LetDeclaration* parseLetDeclaration(); + Node* parseLetDeclaration(); Node* parseClassElement(); diff --git a/bootstrap/cxx/src/CST.cc b/bootstrap/cxx/src/CST.cc index 9623124a7..0aafe4569 100644 --- a/bootstrap/cxx/src/CST.cc +++ b/bootstrap/cxx/src/CST.cc @@ -59,216 +59,6 @@ ByteString TextFile::getText() const { return Text; } -Scope::Scope(Node* Source): - Source(Source) { - scan(Source); - } - -void Scope::addSymbol(ByteString Name, Node* Decl, SymbolKind Kind) { - Mapping.emplace(Name, std::make_tuple(Decl, Kind)); -} - -void Scope::scan(Node* X) { - switch (X->getKind()) { - case NodeKind::SourceFile: - { - auto File = static_cast(X); - for (auto Element: File->Elements) { - scanChild(Element); - } - break; - } - case NodeKind::MatchCase: - { - auto Case = static_cast(X); - visitPattern(Case->Pattern, Case); - break; - } - case 
NodeKind::LetDeclaration: - { - auto Decl = static_cast(X); - ZEN_ASSERT(Decl->isFunction()); - for (auto Param: Decl->Params) { - visitPattern(Param->Pattern, Param); - } - if (Decl->Body) { - scanChild(Decl->Body); - } - break; - } - default: - ZEN_UNREACHABLE - } -} - -void Scope::scanChild(Node* X) { - switch (X->getKind()) { - case NodeKind::LetExprBody: - case NodeKind::ExpressionStatement: - case NodeKind::IfStatement: - case NodeKind::ReturnStatement: - break; - case NodeKind::LetBlockBody: - { - auto Block = static_cast(X); - for (auto Element: Block->Elements) { - scanChild(Element); - } - break; - } - case NodeKind::InstanceDeclaration: - // We ignore let-declarations inside instance-declarations for now - break; - case NodeKind::ClassDeclaration: - { - auto Decl = static_cast(X); - addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Class); - for (auto Element: Decl->Elements) { - scanChild(Element); - } - break; - } - case NodeKind::LetDeclaration: - { - auto Decl = static_cast(X); - // No matter if it is a function or a variable, by visiting the pattern - // we add all relevant bindings to the current scope. 
- visitPattern(Decl->Pattern, Decl); - break; - } - case NodeKind::RecordDeclaration: - { - auto Decl = static_cast(X); - addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type); - break; - } - case NodeKind::VariantDeclaration: - { - auto Decl = static_cast(X); - addSymbol(getCanonicalText(Decl->Name), Decl, SymbolKind::Type); - for (auto Member: Decl->Members) { - switch (Member->getKind()) { - case NodeKind::TupleVariantDeclarationMember: - { - auto T = static_cast(Member); - addSymbol(getCanonicalText(T->Name), Decl, SymbolKind::Constructor); - break; - } - case NodeKind::RecordVariantDeclarationMember: - { - auto R = static_cast(Member); - addSymbol(getCanonicalText(R->Name), Decl, SymbolKind::Constructor); - break; - } - default: - ZEN_UNREACHABLE - } - } - break; - } - default: - ZEN_UNREACHABLE - } -} - -void Scope::visitPattern(Pattern* X, Node* Decl) { - switch (X->getKind()) { - case NodeKind::BindPattern: - { - auto Y = static_cast(X); - addSymbol(getCanonicalText(Y->Name), Decl, SymbolKind::Var); - break; - } - case NodeKind::RecordPattern: - { - auto Y = static_cast(X); - for (auto [Field, Comma]: Y->Fields) { - if (Field->Pattern) { - visitPattern(Field->Pattern, Decl); - } else if (Field->Name) { - addSymbol(Field->Name->Text, Decl, SymbolKind::Var); - } - } - break; - } - case NodeKind::NamedRecordPattern: - { - auto Y = static_cast(X); - for (auto [Field, Comma]: Y->Fields) { - if (Field->Pattern) { - visitPattern(Field->Pattern, Decl); - } else if (Field->Name) { - addSymbol(Field->Name->Text, Decl, SymbolKind::Var); - } - } - break; - } - case NodeKind::NamedTuplePattern: - { - auto Y = static_cast(X); - for (auto P: Y->Patterns) { - visitPattern(P, Decl); - } - break; - } - case NodeKind::NestedPattern: - { - auto Y = static_cast(X); - visitPattern(Y->P, Decl); - break; - } - case NodeKind::TuplePattern: - { - auto Y = static_cast(X); - for (auto [Element, Comma]: Y->Elements) { - visitPattern(Element, Decl); - } - break; - } - case 
NodeKind::ListPattern: - { - auto Y = static_cast(X); - for (auto [Element, Separator]: Y->Elements) { - visitPattern(Element, Decl); - } - break; - } - case NodeKind::LiteralPattern: - break; - default: - ZEN_UNREACHABLE - } -} - -Node* Scope::lookupDirect(SymbolPath Path, SymbolKind Kind) { - ZEN_ASSERT(Path.Modules.empty()); - auto Match = Mapping.find(Path.Name); - if (Match != Mapping.end() && std::get<1>(Match->second) == Kind) { - return std::get<0>(Match->second); - } - return nullptr; -} - -Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) { - ZEN_ASSERT(Path.Modules.empty()); - auto Curr = this; - do { - auto Found = Curr->lookupDirect(Path, Kind); - if (Found) { - return Found; - } - Curr = Curr->getParentScope(); - } while (Curr != nullptr); - return nullptr; -} - -Scope* Scope::getParentScope() { - if (Source->Parent == nullptr) { - return nullptr; - } - return Source->Parent->getScope(); -} - const SourceFile* Node::getSourceFile() const { const Node* CurrNode = this; for (;;) { @@ -503,29 +293,11 @@ Token* WrappedOperator::getLastToken() const { } Token* BindPattern::getFirstToken() const { - switch (Name->getKind()) { - case NodeKind::Identifier: - return static_cast(Name); - case NodeKind::IdentifierAlt: - return static_cast(Name); - case NodeKind::WrappedOperator: - return static_cast(Name)->LParen; - default: - ZEN_UNREACHABLE - } + return Name; } Token* BindPattern::getLastToken() const { - switch (Name->getKind()) { - case NodeKind::Identifier: - return static_cast(Name); - case NodeKind::IdentifierAlt: - return static_cast(Name); - case NodeKind::WrappedOperator: - return static_cast(Name)->RParen; - default: - ZEN_UNREACHABLE - } + return Name; } Token* LiteralPattern::getFirstToken() const { @@ -608,26 +380,26 @@ Token* ReferenceExpression::getFirstToken() const { if (!ModulePath.empty()) { return std::get<0>(ModulePath.front()); } - switch (Name->getKind()) { + switch (Name.getKind()) { case NodeKind::Identifier: - return 
static_cast(Name); + return Name.asIdentifier(); case NodeKind::IdentifierAlt: - return static_cast(Name); + return Name.asIdentifierAlt(); case NodeKind::WrappedOperator: - return static_cast(Name)->LParen; + return Name.asWrappedOperator()->getFirstToken(); default: ZEN_UNREACHABLE } } Token* ReferenceExpression::getLastToken() const { - switch (Name->getKind()) { + switch (Name.getKind()) { case NodeKind::Identifier: - return static_cast(Name); + return Name.asIdentifier(); case NodeKind::IdentifierAlt: - return static_cast(Name); + return Name.asIdentifierAlt(); case NodeKind::WrappedOperator: - return static_cast(Name)->RParen; + return Name.asWrappedOperator()->getLastToken(); default: ZEN_UNREACHABLE } @@ -805,7 +577,7 @@ Token* LetExprBody::getLastToken() const { return Expression->getLastToken(); } -Token* LetDeclaration::getFirstToken() const { +Token* PrefixFunctionDeclaration::getFirstToken() const { if (PubKeyword) { return PubKeyword; } @@ -815,17 +587,97 @@ Token* LetDeclaration::getFirstToken() const { return LetKeyword; } -Token* LetDeclaration::getLastToken() const { +Token* PrefixFunctionDeclaration::getLastToken() const { if (Body) { return Body->getLastToken(); } if (TypeAssert) { return TypeAssert->getLastToken(); } - if (Params.size()) { + return Param->getLastToken(); +} + +Token* InfixFunctionDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + if (ForeignKeyword) { + return ForeignKeyword; + } + return LetKeyword; +} + +Token* InfixFunctionDeclaration::getLastToken() const { + if (Body) { + return Body->getLastToken(); + } + if (TypeAssert) { + return TypeAssert->getLastToken(); + } + return Right->getLastToken(); +} + +Token* SuffixFunctionDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + if (ForeignKeyword) { + return ForeignKeyword; + } + return LetKeyword; +} + +Token* SuffixFunctionDeclaration::getLastToken() const { + if (Body) { + return Body->getLastToken(); + } + if 
(TypeAssert) { + return TypeAssert->getLastToken(); + } + return Name.getLastToken(); +} + +Token* NamedFunctionDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + if (ForeignKeyword) { + return ForeignKeyword; + } + return LetKeyword; +} + +Token* NamedFunctionDeclaration::getLastToken() const { + if (Body) { + return Body->getLastToken(); + } + if (TypeAssert) { + return TypeAssert->getLastToken(); + } + if (!Params.empty()) { return Params.back()->getLastToken(); } + return Name.getLastToken(); +} + +Token* VariableDeclaration::getFirstToken() const { + if (PubKeyword) { + return PubKeyword; + } + if (ForeignKeyword) { + return ForeignKeyword; + } + return LetKeyword; +} + +Token* VariableDeclaration::getLastToken() const { return Pattern->getLastToken(); + if (TypeAssert) { + return TypeAssert->getLastToken(); + } + if (Body) { + return Body->getLastToken(); + } } Token* RecordDeclarationField::getFirstToken() const { @@ -1093,23 +945,81 @@ std::string InstanceKeyword::getText() const { return "instance"; } -ByteString getCanonicalText(const Symbol* N) { +ByteString Identifier::getCanonicalText() const { + return Text; +} + +ByteString IdentifierAlt::getCanonicalText() const { + return Text; +} + +ByteString CustomOperator::getCanonicalText() const { + return Text; +} + +ByteString Symbol::getCanonicalText() const { switch (N->getKind()) { case NodeKind::Identifier: - return static_cast(N)->Text; + return static_cast(N)->getCanonicalText(); case NodeKind::IdentifierAlt: - return static_cast(N)->Text; + return static_cast(N)->getCanonicalText(); case NodeKind::CustomOperator: - return static_cast(N)->Text; + return static_cast(N)->getCanonicalText(); case NodeKind::VBar: return static_cast(N)->getText(); case NodeKind::WrappedOperator: - return static_cast(N)->getOperator()->getText(); + return static_cast(N)->getCanonicalText(); default: ZEN_UNREACHABLE } } +Token* Symbol::getFirstToken() const { + switch (N->getKind()) { + case 
NodeKind::Identifier: + return static_cast(N); + case NodeKind::IdentifierAlt: + return static_cast(N); + case NodeKind::WrappedOperator: + return static_cast(N)->getFirstToken(); + default: + ZEN_UNREACHABLE + } +} + +Token* Symbol::getLastToken() const { + switch (N->getKind()) { + case NodeKind::Identifier: + return static_cast(N); + case NodeKind::IdentifierAlt: + return static_cast(N); + case NodeKind::WrappedOperator: + return static_cast(N)->getLastToken(); + default: + ZEN_UNREACHABLE + } +} + + +ByteString Operator::getCanonicalText() const { + switch (N->getKind()) { + case NodeKind::CustomOperator: + return static_cast(N)->getCanonicalText(); + case NodeKind::VBar: + return static_cast(N)->getText(); + default: + ZEN_UNREACHABLE + } +} + +Token* Operator::getFirstToken() const { + return static_cast(N); +} + +Token* Operator::getLastToken() const { + return static_cast(N); +} + LiteralValue StringLiteral::getValue() { return Text; } @@ -1121,9 +1031,9 @@ LiteralValue IntegerLiteral::getValue() { SymbolPath ReferenceExpression::getSymbolPath() const { std::vector ModuleNames; for (auto [Name, Dot]: ModulePath) { - ModuleNames.push_back(getCanonicalText(Name)); + ModuleNames.push_back(Name->getCanonicalText()); } - return SymbolPath { ModuleNames, getCanonicalText(Name) }; + return SymbolPath { ModuleNames, Name.getCanonicalText() }; } } diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index f28fc9b46..72344ee07 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -251,9 +251,9 @@ void Checker::forwardDeclare(Node* X) { inferTypeExpression(TE); } - auto Match = InstanceMap.find(getCanonicalText(Decl->Name)); + auto Match = InstanceMap.find(Decl->Name->getCanonicalText()); if (Match == InstanceMap.end()) { - InstanceMap.emplace(getCanonicalText(Decl->Name), std::vector { Decl }); + InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl }); } else { Match->second.push_back(Decl); } @@ -265,13 
+265,15 @@ void Checker::forwardDeclare(Node* X) { break; } - case NodeKind::LetDeclaration: + case NodeKind::PrefixFunctionDeclaration: + case NodeKind::InfixFunctionDeclaration: + case NodeKind::SuffixFunctionDeclaration: + case NodeKind::NamedFunctionDeclaration: + break; + + case NodeKind::VariableDeclaration: { - // Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression() - auto Decl = static_cast(X); - if (!Decl->isVariable()) { - break; - } + auto Decl = static_cast(X); Type* Ty; if (Decl->TypeAssert) { Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); @@ -290,13 +292,13 @@ void Checker::forwardDeclare(Node* X) { std::vector Vars; for (auto TE: Decl->TVs) { - auto TV = createRigidVar(getCanonicalText(TE->Name)); + auto TV = createRigidVar(TE->Name->getCanonicalText()); Decl->Ctx->TVs->emplace(TV); - Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); + Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type); Vars.push_back(TV); } - Type* Ty = createConType(getCanonicalText(Decl->Name)); + Type* Ty = createConType(Decl->Name->getCanonicalText()); // Build the type that is actually returned by constructor functions auto RetTy = Ty; @@ -305,7 +307,7 @@ void Checker::forwardDeclare(Node* X) { } // Must be added early so we can create recursive types - Decl->Ctx->Parent->Env.add(getCanonicalText(Decl->Name), new Forall(Ty), SymKind::Type); + Decl->Ctx->Parent->Env.add(Decl->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); for (auto Member: Decl->Members) { switch (Member->getKind()) { @@ -318,7 +320,7 @@ void Checker::forwardDeclare(Node* X) { ParamTypes.push_back(inferTypeExpression(Element, false)); } Decl->Ctx->Parent->Env.add( - getCanonicalText(TupleMember->Name), + TupleMember->Name->getCanonicalText(), new Forall( Decl->Ctx->TVs, Decl->Ctx->Constraints, @@ -351,13 +353,13 @@ void Checker::forwardDeclare(Node* X) { std::vector Vars; for (auto 
TE: Decl->Vars) { - auto TV = createRigidVar(getCanonicalText(TE->Name)); + auto TV = createRigidVar(TE->Name->getCanonicalText()); Decl->Ctx->TVs->emplace(TV); - Decl->Ctx->Env.add(getCanonicalText(TE->Name), new Forall(TV), SymKind::Type); + Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type); Vars.push_back(TV); } - auto Name = getCanonicalText(Decl->Name); + auto Name = Decl->Name->getCanonicalText(); auto Ty = createConType(Name); // Must be added early so we can create recursive types @@ -373,7 +375,7 @@ void Checker::forwardDeclare(Node* X) { for (auto Field: Decl->Fields) { FieldsTy = new Type( TField( - getCanonicalText(Field->Name), + Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), FieldsTy ) @@ -435,13 +437,11 @@ void Checker::initialize(Node* N) { Contexts.pop(); } - void visitLetDeclaration(LetDeclaration* Let) { - if (Let->isFunction()) { - Let->Ctx = createDerivedContext(); - Contexts.push(Let->Ctx); - visitEachChild(Let); - Contexts.pop(); - } + void visitFunctionDeclaration(FunctionDeclaration* Func) { + Func->Ctx = createDerivedContext(); + Contexts.push(Func->Ctx); + visitEachChild(Func); + Contexts.pop(); } // void visitVariableDeclaration(VariableDeclaration* Var) { @@ -456,22 +456,18 @@ void Checker::initialize(Node* N) { } -void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { - - if (!Let->isFunction()) { - return; - } +void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) { // std::cerr << "declare " << Let->getNameAsString() << std::endl; setContext(Let->Ctx); auto addClassVars = [&](ClassDeclaration* Class, bool IsRigid) { - auto Id = getCanonicalText(Class->Name); + auto Id = Class->Name->getCanonicalText(); auto Ctx = &getContext(); std::vector Out; for (auto TE: Class->TypeVars) { - auto Name = getCanonicalText(TE->Name); + auto Name = 
TE->Name->getCanonicalText(); auto TV = IsRigid ? createRigidVar(Name) : createTypeVar(); TV->asVar().Context.emplace(Id); Ctx->Env.add(Name, new Forall(TV), SymKind::Type); @@ -493,8 +489,8 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, // Otherwise, the type is not further specified and we create a new // unification variable. Type* Ty; - if (Let->TypeAssert) { - Ty = inferTypeExpression(Let->TypeAssert->TypeExpression); + if (Let->hasTypeAssert()) { + Ty = inferTypeExpression(Let->getTypeAssert()->TypeExpression); } else { Ty = createTypeVar(); } @@ -507,9 +503,33 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, if (Let->isInstance()) { auto Instance = static_cast(Let->Parent); - auto Class = cast(Instance->getScope()->lookup({ {}, getCanonicalText(Instance->Name) }, SymbolKind::Class)); - // TODO check if `Class` is nullptr - auto SigLet = cast(Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var)); + auto Class = cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); + + if (Class == nullptr) { + // TODO print diagnostic + // DE.add(Instance->Name->getCanonicalText()); + goto after_isinstance; + } + + auto Decl = Class->getScope()->lookupDirect({ {}, Let->getNameAsString() }, SymbolKind::Var); + + if (Decl == nullptr) { + + // TODO print diagnostic + // DE.add(Let->getNameAsStrings), Let->getName()); + goto after_isinstance; + + } + + if (!isa(Decl)) { + + // TODO print diagnostic + // DE.add(Decl); + goto after_isinstance; + + } + + auto FuncDecl = cast(Decl); auto Params = addClassVars(Class, false); @@ -536,23 +556,25 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, // It would be very strange if there was no type assert in the type // class let-declaration but we rather not let the compiler crash if that happens. 
- if (SigLet->TypeAssert) { + if (FuncDecl->hasTypeAssert()) { // Note that we can't do SigLet->TypeAssert->TypeExpression->getType() // because we need to re-generate the type within the local context of // this let-declaration. // TODO make CEqual accept multiple nodes - makeEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression), Let); + makeEqual(Ty, inferTypeExpression(FuncDecl->getTypeAssert()->TypeExpression), Let); } } - if (Let->Body) { - switch (Let->Body->getKind()) { +after_isinstance: + + if (Let->hasBody()) { + switch (Let->getBody()->getKind()) { case NodeKind::LetExprBody: break; case NodeKind::LetBlockBody: { - auto Block = static_cast(Let->Body); + auto Block = static_cast(Let->getBody()); Let->Ctx->ReturnType = createTypeVar(); for (auto Element: Block->Elements) { forwardDeclare(Element); @@ -570,11 +592,7 @@ void Checker::forwardDeclareFunctionDeclaration(LetDeclaration* Let, TVSet* TVs, } -void Checker::inferFunctionDeclaration(LetDeclaration* Decl) { - - if (!Decl->isFunction()) { - return; - } +void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) { // std::cerr << "infer " << Decl->getNameAsString() << std::endl; @@ -584,21 +602,21 @@ void Checker::inferFunctionDeclaration(LetDeclaration* Decl) { std::vector ParamTypes; Type* RetType; - for (auto Param: Decl->Params) { + for (auto Param: Decl->getParams()) { ParamTypes.push_back(inferPattern(Param->Pattern)); } - if (Decl->Body) { - switch (Decl->Body->getKind()) { + if (Decl->hasBody()) { + switch (Decl->getBody()->getKind()) { case NodeKind::LetExprBody: { - auto Expr = static_cast(Decl->Body); + auto Expr = static_cast(Decl->getBody()); RetType = inferExpression(Expr->Expression); break; } case NodeKind::LetBlockBody: { - auto Block = static_cast(Decl->Body); + auto Block = static_cast(Decl->getBody()); RetType = Decl->Ctx->ReturnType; for (auto Element: Block->Elements) { infer(Element); @@ -680,29 +698,34 @@ void Checker::infer(Node* N) { break; } - case 
NodeKind::LetDeclaration: + case NodeKind::PrefixFunctionDeclaration: + case NodeKind::InfixFunctionDeclaration: + case NodeKind::SuffixFunctionDeclaration: + case NodeKind::NamedFunctionDeclaration: { - // Function declarations are handled separately in inferFunctionDeclaration() - auto Decl = static_cast(N); + auto Decl = static_cast(N); if (Decl->Visited) { break; } - if (Decl->isFunction()) { - Decl->IsCycleActive = true; - Decl->Visited = true; - inferFunctionDeclaration(Decl); - Decl->IsCycleActive = false; - } else if (Decl->isVariable()) { - auto Ty = Decl->getType(); - if (Decl->Body) { - ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); - auto E = static_cast(Decl->Body); - auto Ty2 = inferExpression(E->Expression); - makeEqual(Ty, Ty2, Decl); - } - auto Ty3 = inferPattern(Decl->Pattern); - makeEqual(Ty, Ty3, Decl); + Decl->IsCycleActive = true; + Decl->Visited = true; + inferFunctionDeclaration(Decl); + Decl->IsCycleActive = false; + break; + } + + case NodeKind::VariableDeclaration: + { + auto Decl = static_cast(N); + auto Ty = Decl->getType(); + if (Decl->Body) { + ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody); + auto E = static_cast(Decl->Body); + auto Ty2 = inferExpression(E->Expression); + makeEqual(Ty, Ty2, Decl); } + auto Ty3 = inferPattern(Decl->Pattern); + makeEqual(Ty, Ty3, Decl); break; } @@ -801,7 +824,7 @@ void Checker::inferConstraintExpression(ConstraintExpression* C) { std::vector Types; for (auto TE: D->TEs) { auto Ty = inferTypeExpression(TE); - Ty->asVar().Provided->emplace(getCanonicalText(D->Name)); + Ty->asVar().Provided->emplace(D->Name->getCanonicalText()); Types.push_back(Ty); } break; @@ -824,10 +847,10 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { case NodeKind::ReferenceTypeExpression: { auto RefTE = static_cast(N); - auto Scm = lookup(getCanonicalText(RefTE->Name), SymKind::Type); + auto Scm = lookup(RefTE->Name->getCanonicalText(), SymKind::Type); Type* Ty; if (Scm == 
nullptr) { - DE.add(getCanonicalText(RefTE->Name), RefTE->Name); + DE.add(RefTE->Name->getCanonicalText(), RefTE->Name); Ty = createTypeVar(); } else { Ty = instantiate(Scm, RefTE); @@ -850,13 +873,13 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { case NodeKind::VarTypeExpression: { auto VarTE = static_cast(N); - auto Ty = lookupMono(getCanonicalText(VarTE->Name), SymKind::Type); + auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type); if (Ty == nullptr) { if (!AutoVars || Config.typeVarsRequireForall()) { - DE.add(getCanonicalText(VarTE->Name), VarTE->Name); + DE.add(VarTE->Name->getCanonicalText(), VarTE->Name); } - Ty = createRigidVar(getCanonicalText(VarTE->Name)); - addBinding(getCanonicalText(VarTE->Name), new Forall(Ty), SymKind::Type); + Ty = createRigidVar(VarTE->Name->getCanonicalText()); + addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); } ZEN_ASSERT(Ty->isVar()); N->setType(Ty); @@ -868,7 +891,7 @@ Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { auto RecTE = static_cast(N); auto Ty = RecTE->Rest ? 
inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); for (auto [Field, Comma]: RecTE->Fields) { - Ty = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); + Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); } N->setType(Ty); return Ty; @@ -980,7 +1003,7 @@ Type* Checker::inferExpression(Expression* X) { Ty = new Type(TNil()); for (auto [Field, Comma]: Record->Fields) { Ty = new Type(TField( - getCanonicalText(Field->Name), + Field->Name->getCanonicalText(), new Type(TPresent(inferExpression(Field->getExpression()))), Ty )); @@ -999,11 +1022,12 @@ Type* Checker::inferExpression(Expression* X) { case NodeKind::ReferenceExpression: { auto Ref = static_cast(X); + auto Name = Ref->Name.getCanonicalText(); ZEN_ASSERT(Ref->ModulePath.empty()); - if (Ref->Name->is()) { - auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); + if (Ref->Name.isIdentifierAlt()) { + auto Scm = lookup(Name, SymKind::Var); if (!Scm) { - DE.add(getCanonicalText(Ref->Name), Ref->Name); + DE.add(Name, Ref->Name); Ty = createTypeVar(); break; } @@ -1012,12 +1036,12 @@ Type* Checker::inferExpression(Expression* X) { } auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); if (!Target) { - DE.add(getCanonicalText(Ref->Name), Ref->Name); + DE.add(Name, Ref->Name); Ty = createTypeVar(); break; } - if (Target->getKind() == NodeKind::LetDeclaration) { - auto Let = static_cast(Target); + if (isa(Target)) { + auto Let = static_cast(Target); if (Let->IsCycleActive) { Ty = Let->getType(); break; @@ -1026,7 +1050,7 @@ Type* Checker::inferExpression(Expression* X) { infer(Let); } } - auto Scm = lookup(getCanonicalText(Ref->Name), SymKind::Var); + auto Scm = lookup(Name, SymKind::Var); ZEN_ASSERT(Scm); Ty = instantiate(Scm, X); break; @@ -1048,9 +1072,9 @@ Type* Checker::inferExpression(Expression* X) { case NodeKind::InfixExpression: { auto Infix = 
static_cast(X); - auto Scm = lookup(Infix->Operator->getText(), SymKind::Var); + auto Scm = lookup(Infix->Operator.getCanonicalText(), SymKind::Var); if (Scm == nullptr) { - DE.add(Infix->Operator->getText(), Infix->Operator); + DE.add(Infix->Operator.getCanonicalText(), Infix->Operator); Ty = createTypeVar(); break; } @@ -1091,7 +1115,7 @@ Type* Checker::inferExpression(Expression* X) { auto K = static_cast(Member->Name); Ty = createTypeVar(); auto RestTy = createTypeVar(); - makeEqual(new Type(TField(getCanonicalText(K), Ty, RestTy)), ExprTy, Member); + makeEqual(new Type(TField(K->getCanonicalText(), Ty, RestTy)), ExprTy, Member); break; } default: @@ -1138,20 +1162,20 @@ Type* Checker::inferPattern( { auto P = static_cast(Pattern); auto Ty = createTypeVar(); - addBinding(getCanonicalText(P->Name), new Forall(TVs, Constraints, Ty), SymKind::Var); + addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Ty), SymKind::Var); return Ty; } case NodeKind::NamedTuplePattern: { auto P = static_cast(Pattern); - auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); + auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); std::vector ElementTypes; for (auto P2: P->Patterns) { ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); } if (!Scm) { - DE.add(getCanonicalText(P->Name), P->Name); + DE.add(P->Name->getCanonicalText(), P->Name); return createTypeVar(); } auto Ty = instantiate(Scm, P); @@ -1181,9 +1205,9 @@ Type* Checker::inferPattern( FieldTy = inferPattern(Field->Pattern, Constraints, TVs); } else { FieldTy = createTypeVar(); - addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); + addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var); } - RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); + RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy)); } return 
RecordTy; } @@ -1191,9 +1215,9 @@ Type* Checker::inferPattern( case NodeKind::NamedRecordPattern: { auto P = static_cast(Pattern); - auto Scm = lookup(getCanonicalText(P->Name), SymKind::Var); + auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); if (Scm == nullptr) { - DE.add(getCanonicalText(P->Name), P->Name); + DE.add(P->Name->getCanonicalText(), P->Name); return createTypeVar(); } auto RestField = getRestField(P->Fields); @@ -1214,9 +1238,9 @@ Type* Checker::inferPattern( FieldTy = inferPattern(Field->Pattern, Constraints, TVs); } else { FieldTy = createTypeVar(); - addBinding(getCanonicalText(Field->Name), new Forall(TVs, Constraints, FieldTy), SymKind::Var); + addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var); } - RecordTy = new Type(TField(getCanonicalText(Field->Name), new Type(TPresent(FieldTy)), RecordTy)); + RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy)); } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); @@ -1287,7 +1311,14 @@ void Checker::populate(SourceFile* SF) { std::stack Stack; - void visitLetDeclaration(LetDeclaration* N) { + void visitFunctionDeclaration(FunctionDeclaration* N) { + RefGraph.addVertex(N); + Stack.push(N); + visitEachChild(N); + Stack.pop(); + } + + void visitVariableDeclaration(VariableDeclaration* N) { RefGraph.addVertex(N); Stack.push(N); visitEachChild(N); @@ -1295,22 +1326,26 @@ void Checker::populate(SourceFile* SF) { } void visitReferenceExpression(ReferenceExpression* N) { - auto Y = static_cast(N); - auto Def = Y->getScope()->lookup(Y->getSymbolPath()); - // Name lookup failures will be reported directly in inferExpression(). - if (Def == nullptr || Def->getKind() != NodeKind::LetDeclaration) { + auto Ref = static_cast(N); + auto Def = Ref->getScope()->lookup(Ref->getSymbolPath()); + if (Def == nullptr) { + // Name lookup failures will be reported directly in inferExpression(). 
return; } + ZEN_ASSERT(isa(Def) || isa(Def) || isa(Def)); // This case ensures that a deeply nested structure that references a // parameter of a parent node but is not referenced itself is correctly handled. // Note that the edge goes from the parent let to the parameter. This is normal. - if (Def->getKind() == NodeKind::Parameter) { - RefGraph.addEdge(Stack.top(), Def->Parent); + // if (Def->getKind() == NodeKind::Parameter) { + // RefGraph.addEdge(Stack.top(), Def->Parent); + // return; + // } + if (Stack.empty()) { + // An empty stack means we are traversing the toplevel of the source + // file, in which case we don't have anything to connect with. return; } - if (!Stack.empty()) { - RefGraph.addEdge(Def, Stack.top()); - } + RefGraph.addEdge(Def, Stack.top()); } }; @@ -1353,10 +1388,10 @@ void Checker::check(SourceFile *SF) { auto TVs = new TVSet; auto Constraints = new ConstraintSet; for (auto N: Nodes) { - if (N->getKind() != NodeKind::LetDeclaration) { + if (!isa(N)) { continue; } - auto Decl = static_cast(N); + auto Decl = static_cast(N); forwardDeclareFunctionDeclaration(Decl, TVs, Constraints); } } diff --git a/bootstrap/cxx/src/ConsolePrinter.cc b/bootstrap/cxx/src/ConsolePrinter.cc index 919b80d02..5da51c905 100644 --- a/bootstrap/cxx/src/ConsolePrinter.cc +++ b/bootstrap/cxx/src/ConsolePrinter.cc @@ -155,7 +155,12 @@ static std::string describe(NodeKind Type) { return "'class'"; case NodeKind::InstanceKeyword: return "'instance'"; - case NodeKind::LetDeclaration: + case NodeKind::PrefixFunctionDeclaration: + case NodeKind::InfixFunctionDeclaration: + case NodeKind::SuffixFunctionDeclaration: + case NodeKind::NamedFunctionDeclaration: + return "a let-declaration"; + case NodeKind::VariableDeclaration: return "a let-declaration"; case NodeKind::CallExpression: return "a call-expression"; diff --git a/bootstrap/cxx/src/Evaluator.cc b/bootstrap/cxx/src/Evaluator.cc index 13817435e..d8914db69 100644 --- a/bootstrap/cxx/src/Evaluator.cc +++
b/bootstrap/cxx/src/Evaluator.cc @@ -11,7 +11,7 @@ Value Evaluator::evaluateExpression(Expression* X, Env& Env) { case NodeKind::ReferenceExpression: { auto RE = static_cast(X); - return Env.lookup(getCanonicalText(RE->Name)); + return Env.lookup(RE->Name.getCanonicalText()); // auto Decl = RE->getScope()->lookup(RE->getSymbolPath()); // ZEN_ASSERT(Decl && Decl->getKind() == NodeKind::FunctionDeclaration); // return static_cast(Decl); @@ -48,7 +48,7 @@ void Evaluator::assignPattern(Pattern* P, Value& V, Env& E) { case NodeKind::BindPattern: { auto BP = static_cast(P); - E.add(getCanonicalText(BP->Name), V); + E.add(BP->Name->getCanonicalText(), V); break; } default: @@ -62,12 +62,12 @@ Value Evaluator::apply(Value Op, std::vector Args) { { auto Fn = Op.getDeclaration(); Env NewEnv; - for (auto [Param, Arg]: zen::zip(Fn->Params, Args)) { + for (auto [Param, Arg]: zen::zip(Fn->getParams(), Args)) { assignPattern(Param->Pattern, Arg, NewEnv); } - switch (Fn->Body->getKind()) { + switch (Fn->getBody()->getKind()) { case NodeKind::LetExprBody: - return evaluateExpression(static_cast(Fn->Body)->Expression, NewEnv); + return evaluateExpression(static_cast(Fn->getBody())->Expression, NewEnv); default: ZEN_UNREACHABLE } @@ -98,23 +98,28 @@ void Evaluator::evaluate(Node* N, Env& E) { evaluateExpression(ES->Expression, E); break; } - case NodeKind::LetDeclaration: + case NodeKind::PrefixFunctionDeclaration: + case NodeKind::InfixFunctionDeclaration: + case NodeKind::SuffixFunctionDeclaration: + case NodeKind::NamedFunctionDeclaration: { - auto Decl = static_cast(N); - if (Decl->isFunction()) { - E.add(Decl->getNameAsString(), Decl); - } else { - Value V; - if (Decl->Body) { - switch (Decl->Body->getKind()) { - case NodeKind::LetExprBody: - { - auto Body = static_cast(Decl->Body); - V = evaluateExpression(Body->Expression, E); - } - default: - ZEN_UNREACHABLE + auto Decl = static_cast(N); + E.add(Decl->getNameAsString(), Decl); + break; + } + case 
NodeKind::VariableDeclaration: + { + auto Decl = static_cast(N); + Value V; + if (Decl->Body) { + switch (Decl->Body->getKind()) { + case NodeKind::LetExprBody: + { + auto Body = static_cast(Decl->Body); + V = evaluateExpression(Body->Expression, E); } + default: + ZEN_UNREACHABLE } } break; diff --git a/bootstrap/cxx/src/Parser.cc b/bootstrap/cxx/src/Parser.cc index 3d4b77429..8de8939df 100644 --- a/bootstrap/cxx/src/Parser.cc +++ b/bootstrap/cxx/src/Parser.cc @@ -32,16 +32,6 @@ namespace bolt { -bool isOperator(Token* T) { - switch (T->getKind()) { - case NodeKind::VBar: - case NodeKind::CustomOperator: - return true; - default: - return false; - } -} - std::optional OperatorTable::getInfix(Token* T) { auto Match = Mapping.find(T->getText()); if (Match == Mapping.end() || !Match->second.isInfix()) { @@ -828,7 +818,7 @@ Expression* Parser::parsePrimitiveExpression() { DE.add(File, T3, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt }); return nullptr; } - return new ReferenceExpression(Annotations, ModulePath, static_cast(T3)); + return new ReferenceExpression(Annotations, ModulePath, Symbol::from_raw_node(T3)); } case NodeKind::LParen: { @@ -1025,7 +1015,7 @@ Expression* Parser::parseInfixOperatorAfterExpression(Expression* Left, int MinP } Right = NewRight; } - Left = new InfixExpression(Left, T0, Right); + Left = new InfixExpression(Left, Operator::from_raw_node(T0), Right); } return Left; } @@ -1141,17 +1131,31 @@ IfStatement* Parser::parseIfStatement() { return new IfStatement(Parts); } -LetDeclaration* Parser::parseLetDeclaration() { +enum class LetMode { + Prefix, + Infix, + Suffix, + Wrapped, + VarOrNamed, +}; + +Node* Parser::parseLetDeclaration() { auto Annotations = parseAnnotations(); PubKeyword* Pub = nullptr; ForeignKeyword* Foreign = nullptr; LetKeyword* Let; MutKeyword* Mut = nullptr; + Operator Op; + Symbol Sym; Pattern* Name; + Parameter* Param; + Parameter* Left; + Parameter* Right; std::vector Params; TypeAssert* TA = nullptr; 
LetBody* Body = nullptr; + LetMode Mode; auto T0 = Tokens.get(); if (T0->getKind() == NodeKind::PubKeyword) { @@ -1183,38 +1187,46 @@ LetDeclaration* Parser::parseLetDeclaration() { auto T2 = Tokens.peek(0); auto T3 = Tokens.peek(1); auto T4 = Tokens.peek(2); - if (isOperator(T2)) { + if (isa(T2)) { + // Prefix function declaration Tokens.get(); auto P1 = parseNarrowPattern(); - Params.push_back(new Parameter(P1, nullptr)); - Name = new BindPattern(T2); + Param = new Parameter(P1, nullptr); + Op = Operator::from_raw_node(T2); + Mode = LetMode::Prefix; goto after_params; - } else if (isOperator(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) { + } else if (isa(T3) && (T4->getKind() == NodeKind::Colon || T4->getKind() == NodeKind::Equals || T4->getKind() == NodeKind::BlockStart || T4->getKind() == NodeKind::LineFoldEnd)) { + // Suffix function declaration auto P1 = parseNarrowPattern(); - Params.push_back(new Parameter(P1, nullptr)); + Param = new Parameter(P1, nullptr); Tokens.get(); - Name = new BindPattern(T3); + Op = Operator::from_raw_node(T3); + Mode = LetMode::Suffix; goto after_params; - } else if (T2->getKind() == NodeKind::LParen && isOperator(T3) && T4->getKind() == NodeKind::RParen) { + } else if (T2->getKind() == NodeKind::LParen && isa(T3) && T4->getKind() == NodeKind::RParen) { + // Wrapped operator function declaration Tokens.get(); Tokens.get(); Tokens.get(); - Name = new BindPattern( - new WrappedOperator( - static_cast(T2), - T3, - static_cast(T3) - ) + Sym = new WrappedOperator( + static_cast(T2), + Operator::from_raw_node(T3), + static_cast(T3) ); - } else if (isOperator(T3)) { + Mode = LetMode::Wrapped; + } else if (isa(T3)) { + // Infix function declaration auto P1 = parseNarrowPattern(); - Params.push_back(new Parameter(P1, nullptr)); + Left = new Parameter(P1, nullptr); Tokens.get(); auto P2 = parseNarrowPattern(); -
Params.push_back(new Parameter(P2, nullptr)); - Name = new BindPattern(T3); + Right = new Parameter(P2, nullptr); + Op = Operator::from_raw_node(T3); + Mode = LetMode::Infix; goto after_params; } else { + // Variable declaration or named function declaration + Mode = LetMode::VarOrNamed; Name = parseNarrowPattern(); if (!Name) { if (Pub) { @@ -1313,17 +1325,77 @@ after_params: finish: - return new LetDeclaration( - Annotations, - Pub, - Foreign, - Let, - Mut, - Name, - Params, - TA, - Body - ); + switch (Mode) { + case LetMode::Prefix: + return new PrefixFunctionDeclaration( + Annotations, + Pub, + Foreign, + Let, + Op, + Param, + TA, + Body + ); + case LetMode::Suffix: + return new SuffixFunctionDeclaration( + Annotations, + Pub, + Foreign, + Let, + Param, + Op, + TA, + Body + ); + case LetMode::Infix: + return new InfixFunctionDeclaration( + Annotations, + Pub, + Foreign, + Let, + Left, + Op, + Right, + TA, + Body + ); + case LetMode::Wrapped: + return new NamedFunctionDeclaration( + Annotations, + Pub, + Foreign, + Let, + Sym, + Params, + TA, + Body + ); + case LetMode::VarOrNamed: + if (Name->getKind() != NodeKind::BindPattern || Mut) { + // TODO assert Params is empty + return new VariableDeclaration( + Annotations, + Pub, + Foreign, + Let, + Mut, + Name, + TA, + Body + ); + } + return new NamedFunctionDeclaration( + Annotations, + Pub, + Foreign, + Let, + Name->as()->Name, + Params, + TA, + Body + ); + } } Node* Parser::parseLetBodyElement() {