diff --git a/bootstrap/cxx/include/bolt/CST.hpp b/bootstrap/cxx/include/bolt/CST.hpp index 794c74c4e..bac763ec1 100644 --- a/bootstrap/cxx/include/bolt/CST.hpp +++ b/bootstrap/cxx/include/bolt/CST.hpp @@ -87,6 +87,7 @@ namespace bolt { }; enum class NodeKind { + VBar, Equals, Colon, Comma, @@ -132,6 +133,8 @@ namespace bolt { TypeAssertAnnotation, TypeclassConstraintExpression, EqualityConstraintExpression, + RecordTypeExpressionField, + RecordTypeExpression, QualifiedTypeExpression, ReferenceTypeExpression, ArrowTypeExpression, @@ -363,6 +366,20 @@ namespace bolt { }; + class VBar : public Token { + public: + + inline VBar(TextLoc StartLoc): + Token(NodeKind::VBar, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::VBar; + } + + }; + class Colon : public Token { public: @@ -1085,6 +1102,54 @@ namespace bolt { }; + class RecordTypeExpressionField : public Node { + public: + + Identifier* Name; + Colon* Colon; + TypeExpression* TE; + + inline RecordTypeExpressionField( + Identifier* Name, + class Colon* Colon, + TypeExpression* TE + ): Node(NodeKind::RecordTypeExpressionField), + Name(Name), + Colon(Colon), + TE(TE) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + + class RecordTypeExpression : public TypeExpression { + public: + + LBrace* LBrace; + std::vector> Fields; + VBar* VBar; + TypeExpression* Rest; + RBrace* RBrace; + + inline RecordTypeExpression( + class LBrace* LBrace, + std::vector> Fields, + class VBar* VBar, + TypeExpression* Rest, + class RBrace* RBrace + ): TypeExpression(NodeKind::RecordTypeExpression), + LBrace(LBrace), + Fields(Fields), + VBar(VBar), + Rest(Rest), + RBrace(RBrace) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + class VarTypeExpression; class TypeclassConstraintExpression : public ConstraintExpression { @@ -2079,7 +2144,7 @@ namespace bolt { bool isVariable() const noexcept { // Variables in classes and instances are never possible, so we reflect this by excluding them here. - return !isSignature() && !isClass() && !isInstance() && (Pattern->getKind() != NodeKind::BindPattern || !Body); + return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body); } bool isFunction() const noexcept { diff --git a/bootstrap/cxx/include/bolt/CSTVisitor.hpp b/bootstrap/cxx/include/bolt/CSTVisitor.hpp index a46ba1f31..532f2a671 100644 --- a/bootstrap/cxx/include/bolt/CSTVisitor.hpp +++ b/bootstrap/cxx/include/bolt/CSTVisitor.hpp @@ -19,6 +19,7 @@ namespace bolt { return static_cast(this)->visit ## name(static_cast(N)); switch (N->getKind()) { + BOLT_GEN_CASE(VBar) BOLT_GEN_CASE(Equals) BOLT_GEN_CASE(Colon) BOLT_GEN_CASE(Comma) @@ -64,6 +65,8 @@ namespace bolt { BOLT_GEN_CASE(TypeAssertAnnotation) BOLT_GEN_CASE(TypeclassConstraintExpression) BOLT_GEN_CASE(EqualityConstraintExpression) + BOLT_GEN_CASE(RecordTypeExpressionField) + BOLT_GEN_CASE(RecordTypeExpression) BOLT_GEN_CASE(QualifiedTypeExpression) BOLT_GEN_CASE(ReferenceTypeExpression) BOLT_GEN_CASE(ArrowTypeExpression) @@ -121,6 +124,10 @@ namespace bolt { static_cast(this)->visitNode(N); } + void visitVBar(VBar* N) { + static_cast(this)->visitToken(N); + } + void visitEquals(Equals* N) { static_cast(this)->visitToken(N); } @@ -313,6 +320,14 @@ namespace bolt { static_cast(this)->visitNode(N); } + void visitRecordTypeExpressionField(RecordTypeExpressionField * N) { + static_cast(this)->visitNode(N); + } + + void visitRecordTypeExpression(RecordTypeExpression* N) { + static_cast(this)->visitTypeExpression(N); + } + void visitQualifiedTypeExpression(QualifiedTypeExpression* N) { static_cast(this)->visitTypeExpression(N); } @@ -519,6 +534,7 @@ namespace bolt { break; switch (N->getKind()) { + BOLT_GEN_CHILD_CASE(VBar) BOLT_GEN_CHILD_CASE(Equals) BOLT_GEN_CHILD_CASE(Colon) BOLT_GEN_CHILD_CASE(Comma) @@ -564,6 +580,8 @@ namespace bolt { BOLT_GEN_CHILD_CASE(TypeAssertAnnotation) BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression) BOLT_GEN_CHILD_CASE(EqualityConstraintExpression) + BOLT_GEN_CHILD_CASE(RecordTypeExpressionField) + BOLT_GEN_CHILD_CASE(RecordTypeExpression) BOLT_GEN_CHILD_CASE(QualifiedTypeExpression) BOLT_GEN_CHILD_CASE(ReferenceTypeExpression) BOLT_GEN_CHILD_CASE(ArrowTypeExpression) @@ -613,6 +631,9 @@ namespace bolt { #define BOLT_VISIT(node) static_cast(this)->visit(node) + void visitEachChild(VBar* N) { + } + void visitEachChild(Equals* N) { } @@ -760,6 +781,29 @@ namespace bolt { BOLT_VISIT(N->Right); } + void visitEachChild(RecordTypeExpressionField* N) { + BOLT_VISIT(N->Name); + BOLT_VISIT(N->Colon); + BOLT_VISIT(N->TE); + } + + void visitEachChild(RecordTypeExpression* N) { + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); + } + } + if (N->VBar) { + BOLT_VISIT(N->VBar); + } + if (N->Rest) { + BOLT_VISIT(N->Rest); + } + BOLT_VISIT(N->RBrace); + } + void visitEachChild(QualifiedTypeExpression* N) { for (auto [CE, Comma]: N->Constraints) { BOLT_VISIT(CE); diff --git a/bootstrap/cxx/src/CST.cc b/bootstrap/cxx/src/CST.cc index b9ad9730c..6a1d4aafd 100644 --- a/bootstrap/cxx/src/CST.cc +++ b/bootstrap/cxx/src/CST.cc @@ -393,6 +393,22 @@ namespace bolt { return Left->getLastToken(); } + Token* RecordTypeExpressionField::getFirstToken() const { + return Name; + } + + Token* RecordTypeExpressionField::getLastToken() const { + return TE->getLastToken(); + } + + Token* RecordTypeExpression::getFirstToken() const { + return LBrace; + } + + Token* RecordTypeExpression::getLastToken() const { + return RBrace; + } + Token* QualifiedTypeExpression::getFirstToken() const { if (!Constraints.empty()) { return std::get<0>(Constraints.front())->getFirstToken(); @@ -840,6 +856,10 @@ namespace bolt { return nullptr; } + std::string VBar::getText() const { + return "|"; + } + std::string Equals::getText() const { return "="; } diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index 1f9d304b9..59c213052 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -677,9 +677,12 @@ namespace bolt { case NodeKind::LetDeclaration: { - // Function declarations are handled separately in inferLetDeclaration() + // Function declarations are handled separately in inferFunctionDeclaration() auto Decl = static_cast(N); - if (Decl->isFunction() && !Decl->Visited) { + if (Decl->Visited) { + break; + } + if (Decl->isFunction()) { Decl->IsCycleActive = true; Decl->Visited = true; inferFunctionDeclaration(Decl); @@ -854,6 +857,17 @@ namespace bolt { return Ty; } + case NodeKind::RecordTypeExpression: + { + auto RecTE = static_cast(N); + auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil()); + for (auto [Field, Comma]: RecTE->Fields) { + Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty)); + } + N->setType(Ty); + return Ty; + } + case NodeKind::TupleTypeExpression: { auto TupleTE = static_cast(N); diff --git a/bootstrap/cxx/src/Parser.cc b/bootstrap/cxx/src/Parser.cc index 4f8fe3425..9482c98c6 100644 --- a/bootstrap/cxx/src/Parser.cc +++ b/bootstrap/cxx/src/Parser.cc @@ -1,6 +1,7 @@ // TODO check for memory leaks everywhere a nullptr is returned +#include #include #include "bolt/Common.hpp" @@ -385,6 +386,9 @@ after_constraints: LParen->unref(); for (auto [CE, Comma]: Constraints) { CE->unref(); + if (Comma) { + Comma->unref(); + } } RParen->unref(); RArrowAlt->unref(); @@ -398,6 +402,93 @@ after_constraints: switch (T0->getKind()) { case NodeKind::Identifier: return parseVarTypeExpression(); + case NodeKind::LBrace: + { + Tokens.get(); + auto LBrace = static_cast(T0); + std::vector> Fields; + VBar* VBar = nullptr; + TypeExpression* Rest = nullptr; + for (;;) { + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::RBrace) { + break; + } + auto Name = expectToken(); + if (Name == nullptr) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + return nullptr; + } + auto Colon = expectToken(); + if (Colon == nullptr) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + return nullptr; + } + auto TE = parseTypeExpression(); + if (TE == nullptr) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Name->unref(); + Colon->unref(); + return nullptr; + } + auto Field = new RecordTypeExpressionField(Name, Colon, TE); + auto T3 = Tokens.peek(); + if (T3->getKind() == NodeKind::RBrace) { + Fields.push_back(std::make_tuple(Field, nullptr)); + break; + } + if (T3->getKind() == NodeKind::VBar) { + Tokens.get(); + VBar = static_cast(T3); + Rest = parseTypeExpression(); + if (!Rest) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Field->unref(); + return nullptr; + } + auto T4 = Tokens.peek(); + if (T4->getKind() != NodeKind::RBrace) { + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Field->unref(); + Rest->unref(); + DE.add(File, T4, std::vector { NodeKind::RBrace }); + return nullptr; + } + break; + } + if (T3->getKind() == NodeKind::Comma) { + Tokens.get(); + auto Comma = static_cast(T3); + Fields.push_back(std::make_tuple(Field, Comma)); + continue; + } + DE.add(File, T3, std::vector { NodeKind::RBrace, NodeKind::Comma, NodeKind::VBar }); + for (auto [Field, Comma]: Fields) { + Field->unref(); + Comma->unref(); + } + Field->unref(); + return nullptr; + } + auto RBrace = static_cast(Tokens.get()); + return new RecordTypeExpression(LBrace, Fields, VBar, Rest, RBrace); + } case NodeKind::LParen: { Tokens.get(); @@ -488,7 +579,16 @@ after_tuple_element: for (;;) { auto T1 = Tokens.peek(); auto Kind = T1->getKind(); - if (Kind == NodeKind::RArrow || Kind == NodeKind::Equals || Kind == NodeKind::BlockStart || Kind == NodeKind::LineFoldEnd || Kind == NodeKind::EndOfFile || Kind == NodeKind::RParen) { + if (Kind == NodeKind::Comma + || Kind == NodeKind::RArrow + || Kind == NodeKind::Equals + || Kind == NodeKind::BlockStart + || Kind == NodeKind::LineFoldEnd + || Kind == NodeKind::EndOfFile + || Kind == NodeKind::RParen + || Kind == NodeKind::RBracket + || Kind == NodeKind::RBrace + || Kind == NodeKind::VBar) { break; } auto TE = parsePrimitiveTypeExpression(); diff --git a/bootstrap/cxx/src/Scanner.cc b/bootstrap/cxx/src/Scanner.cc index cc3989a6a..f404b2d05 100644 --- a/bootstrap/cxx/src/Scanner.cc +++ b/bootstrap/cxx/src/Scanner.cc @@ -379,7 +379,9 @@ after_string_contents: Text.push_back(static_cast(C1)); getChar(); } - if (Text == "->") { + if (Text == "|") { + return new VBar(StartLoc); + } else if (Text == "->") { return new RArrow(StartLoc); } else if (Text == "=>") { return new RArrowAlt(StartLoc);