diff --git a/bootstrap/cxx/include/bolt/CST.hpp b/bootstrap/cxx/include/bolt/CST.hpp index fda3ef6e8..794c74c4e 100644 --- a/bootstrap/cxx/include/bolt/CST.hpp +++ b/bootstrap/cxx/include/bolt/CST.hpp @@ -141,7 +141,9 @@ namespace bolt { TupleTypeExpression, BindPattern, LiteralPattern, - NamedPattern, + RecordPatternField, + NamedRecordPattern, + NamedTuplePattern, TuplePattern, NestedPattern, ListPattern, @@ -1313,16 +1315,68 @@ namespace bolt { }; - class NamedPattern : public Pattern { + class RecordPatternField : public Node { + public: + + Identifier* Name; + Equals* Equals; + Pattern* Pattern; + + inline RecordPatternField( + Identifier* Name, + class Equals* Equals, + class Pattern* Pattern + ): Node(NodeKind::RecordPatternField), + Name(Name), + Equals(Equals), + Pattern(Pattern) {} + + inline RecordPatternField( + Identifier* Name + ): RecordPatternField(Name, nullptr, nullptr) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + + class NamedRecordPattern : public Pattern { + public: + + std::vector> ModulePath; + IdentifierAlt* Name; + LBrace* LBrace; + std::vector> Fields; + RBrace* RBrace; + + inline NamedRecordPattern( + std::vector> ModulePath, + IdentifierAlt* Name, + class LBrace* LBrace, + std::vector> Fields, + class RBrace* RBrace + ): Pattern(NodeKind::NamedRecordPattern), + ModulePath(ModulePath), + Name(Name), + LBrace(LBrace), + Fields(Fields), + RBrace(RBrace) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + + class NamedTuplePattern : public Pattern { public: IdentifierAlt* Name; std::vector Patterns; - inline NamedPattern( + inline NamedTuplePattern( IdentifierAlt* Name, std::vector Patterns - ): Pattern(NodeKind::NamedPattern), + ): Pattern(NodeKind::NamedTuplePattern), Name(Name), Patterns(Patterns) {} diff --git a/bootstrap/cxx/include/bolt/CSTVisitor.hpp b/bootstrap/cxx/include/bolt/CSTVisitor.hpp index 8fa14006f..a46ba1f31 100644 --- a/bootstrap/cxx/include/bolt/CSTVisitor.hpp +++ b/bootstrap/cxx/include/bolt/CSTVisitor.hpp @@ -73,7 +73,9 @@ namespace bolt { BOLT_GEN_CASE(TupleTypeExpression) BOLT_GEN_CASE(BindPattern) BOLT_GEN_CASE(LiteralPattern) - BOLT_GEN_CASE(NamedPattern) + BOLT_GEN_CASE(RecordPatternField) + BOLT_GEN_CASE(NamedRecordPattern) + BOLT_GEN_CASE(NamedTuplePattern) BOLT_GEN_CASE(TuplePattern) BOLT_GEN_CASE(NestedPattern) BOLT_GEN_CASE(ListPattern) @@ -351,7 +353,15 @@ namespace bolt { static_cast(this)->visitPattern(N); } - void visitNamedPattern(NamedPattern* N) { + void visitRecordPatternField(RecordPatternField* N) { + static_cast(this)->visitNode(N); + } + + void visitNamedRecordPattern(NamedRecordPattern* N) { + static_cast(this)->visitPattern(N); + } + + void visitNamedTuplePattern(NamedTuplePattern* N) { static_cast(this)->visitPattern(N); } @@ -563,7 +573,9 @@ namespace bolt { BOLT_GEN_CHILD_CASE(TupleTypeExpression) BOLT_GEN_CHILD_CASE(BindPattern) BOLT_GEN_CHILD_CASE(LiteralPattern) - BOLT_GEN_CHILD_CASE(NamedPattern) + BOLT_GEN_CHILD_CASE(RecordPatternField) + BOLT_GEN_CHILD_CASE(NamedRecordPattern) + BOLT_GEN_CHILD_CASE(NamedTuplePattern) BOLT_GEN_CHILD_CASE(TuplePattern) BOLT_GEN_CHILD_CASE(NestedPattern) BOLT_GEN_CHILD_CASE(ListPattern) @@ -810,7 +822,36 @@ namespace bolt { BOLT_VISIT(N->Literal); } - void visitEachChild(NamedPattern* N) { + void visitEachChild(RecordPatternField* N) { + BOLT_VISIT(N->Name); + if (N->Equals) { + BOLT_VISIT(N->Equals); + } + if (N->Pattern) { + BOLT_VISIT(N->Pattern); + } + } + + void visitEachChild(NamedRecordPattern* N) { + for (auto [Name, Dot]: N->ModulePath) { + BOLT_VISIT(Name); + if (Dot) { + BOLT_VISIT(Dot); + } + } + BOLT_VISIT(N->Name); + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); + } + } + BOLT_VISIT(N->LBrace); + BOLT_VISIT(N->RBrace); + } + + void visitEachChild(NamedTuplePattern* N) { BOLT_VISIT(N->Name); for (auto P: N->Patterns) { BOLT_VISIT(P); diff --git a/bootstrap/cxx/include/bolt/Parser.hpp b/bootstrap/cxx/include/bolt/Parser.hpp index 777b0dbd4..db2882bd7 100644 --- a/bootstrap/cxx/include/bolt/Parser.hpp +++ b/bootstrap/cxx/include/bolt/Parser.hpp @@ -71,7 +71,8 @@ namespace bolt { Token* expectToken(NodeKind Ty); - std::vector parseRecordFields(); + std::vector parseRecordDeclarationFields(); + std::optional>> parseRecordPatternFields(); template T* expectToken() { @@ -97,7 +98,8 @@ namespace bolt { std::vector parseAnnotations(); void checkLineFoldEnd(); - void skipToLineFoldEnd(); + void skipPastLineFoldEnd(); + void skipToRBrace(); public: diff --git a/bootstrap/cxx/src/CST.cc b/bootstrap/cxx/src/CST.cc index c33a87df4..b9ad9730c 100644 --- a/bootstrap/cxx/src/CST.cc +++ b/bootstrap/cxx/src/CST.cc @@ -3,6 +3,7 @@ #include "bolt/CST.hpp" #include "bolt/CSTVisitor.hpp" +#include namespace bolt { @@ -173,9 +174,21 @@ namespace bolt { addSymbol(Y->Name->Text, Decl, SymbolKind::Var); break; } - case NodeKind::NamedPattern: + case NodeKind::NamedRecordPattern: { - auto Y = static_cast(X); + auto Y = static_cast(X); + for (auto [Field, Comma]: Y->Fields) { + if (Field->Pattern) { + visitPattern(Field->Pattern, Decl); + } else { + addSymbol(Field->Name->Text, Decl, SymbolKind::Var); + } + } + break; + } + case NodeKind::NamedTuplePattern: + { + auto Y = static_cast(X); for (auto P: Y->Patterns) { visitPattern(P, Decl); } @@ -464,11 +477,36 @@ namespace bolt { return Literal; } - Token* NamedPattern::getFirstToken() const { + Token* RecordPatternField::getFirstToken() const { return Name; } - Token* NamedPattern::getLastToken() const { + Token* RecordPatternField::getLastToken() const { + if (Pattern) { + return Pattern->getLastToken(); + } + if (Equals) { + return Equals; + } + return Name; + } + + Token* NamedRecordPattern::getFirstToken() const { + if (!ModulePath.empty()) { + return std::get<0>(ModulePath.back()); + } + return Name; + } + + Token* NamedRecordPattern::getLastToken() const { + return RBrace; + } + + Token* NamedTuplePattern::getFirstToken() const { + return Name; + } + + Token* NamedTuplePattern::getLastToken() const { if (Patterns.size()) { return Patterns.back()->getLastToken(); } diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index 029299884..1f9d304b9 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -1113,13 +1113,13 @@ namespace bolt { return Ty; } - case NodeKind::NamedPattern: + case NodeKind::NamedTuplePattern: { - auto P = static_cast(Pattern); + auto P = static_cast(Pattern); auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); - std::vector ParamTypes; + std::vector ElementTypes; for (auto P2: P->Patterns) { - ParamTypes.push_back(inferPattern(P2, Constraints, TVs)); + ElementTypes.push_back(inferPattern(P2, Constraints, TVs)); } if (!Scm) { DE.add(P->Name->getCanonicalText(), P->Name); @@ -1127,7 +1127,32 @@ namespace bolt { } auto Ty = instantiate(Scm, P); auto RetTy = createTypeVar(); - makeEqual(Ty, Type::buildArrow(ParamTypes, RetTy), P); + makeEqual(Ty, Type::buildArrow(ElementTypes, RetTy), P); + return RetTy; + } + + case NodeKind::NamedRecordPattern: + { + auto P = static_cast(Pattern); + auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var); + if (Scm == nullptr) { + DE.add(P->Name->getCanonicalText(), P->Name); + return createTypeVar(); + } + auto RecordTy = new Type(TNil()); + for (auto [Field, Comma]: P->Fields) { + Type* FieldTy; + if (Field->Pattern) { + FieldTy = inferPattern(Field->Pattern, Constraints, TVs); + } else { + FieldTy = createTypeVar(); + addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var); + } + RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy)); + } + auto Ty = instantiate(Scm, P); + auto RetTy = createTypeVar(); + makeEqual(Ty, new Type(TArrow(RecordTy, RetTy)), P); return RetTy; } diff --git a/bootstrap/cxx/src/ConsolePrinter.cc b/bootstrap/cxx/src/ConsolePrinter.cc index 860852db7..243ece772 100644 --- a/bootstrap/cxx/src/ConsolePrinter.cc +++ b/bootstrap/cxx/src/ConsolePrinter.cc @@ -89,7 +89,7 @@ namespace bolt { return "a tuple type expression"; case NodeKind::BindPattern: return "a variable binder"; - case NodeKind::NamedPattern: + case NodeKind::NamedTuplePattern: return "a pattern for a variant member"; case NodeKind::TuplePattern: return "a pattern for a tuple"; diff --git a/bootstrap/cxx/src/Parser.cc b/bootstrap/cxx/src/Parser.cc index 776daf020..4f8fe3425 100644 --- a/bootstrap/cxx/src/Parser.cc +++ b/bootstrap/cxx/src/Parser.cc @@ -150,6 +150,39 @@ finish: return new ListPattern { LBracket, Elements, RBracket }; } + std::optional>> Parser::parseRecordPatternFields() { + std::vector> Fields; + for (;;) { + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::RBrace) { + break; + } + auto Name = expectToken(); + Equals* Equals = nullptr; + Pattern* Pattern = nullptr; + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::Equals) { + Tokens.get(); + Equals = static_cast(T1); + Pattern = parseWidePattern(); + } + auto Field = new RecordPatternField(Name, Equals, Pattern); + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::RBrace) { + Fields.push_back(std::make_tuple(Field, nullptr)); + break; + } + if (T2->getKind() != NodeKind::Comma) { + DE.add(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma }); + return {}; + } + Tokens.get(); + auto Comma = static_cast(T2); + Fields.push_back(std::make_tuple(Field, Comma)); + } + return Fields; + } + Pattern* Parser::parsePrimitivePattern(bool IsNarrow) { auto T0 = Tokens.peek(); switch (T0->getKind()) { @@ -165,7 +198,19 @@ finish: Tokens.get(); auto Name = static_cast(T0); if (IsNarrow) { - return new NamedPattern(Name, {}); + return new NamedTuplePattern(Name, {}); + } + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::LBrace) { + auto LBrace = static_cast(T1); + Tokens.get(); + auto Fields = parseRecordPatternFields(); + if (!Fields) { + skipToRBrace(); + return nullptr; + } + auto RBrace = static_cast(Tokens.get()); + return new NamedRecordPattern({}, Name, LBrace, *Fields, RBrace); } std::vector Patterns; for (;;) { @@ -190,7 +235,7 @@ finish: } Patterns.push_back(P); } - return new NamedPattern { Name, Patterns }; + return new NamedTuplePattern { Name, Patterns }; } case NodeKind::LBracket: return parseListPattern(); @@ -523,20 +568,20 @@ after_tuple_element: } auto Pattern = parseWidePattern(); if (!Pattern) { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } auto RArrowAlt = expectToken(); if (!RArrowAlt) { Pattern->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } auto Expression = parseExpression(); if (!Expression) { Pattern->unref(); RArrowAlt->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } checkLineFoldEnd(); @@ -853,7 +898,7 @@ finish: ExpressionStatement* Parser::parseExpressionStatement() { auto E = parseExpression(); if (!E) { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } checkLineFoldEnd(); @@ -875,7 +920,7 @@ finish: Expression = parseExpression(); if (!Expression) { ReturnKeyword->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } checkLineFoldEnd(); @@ -890,14 +935,14 @@ finish: auto Test = parseExpression(); if (!Test) { IfKeyword->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } auto T1 = expectToken(); if (!T1) { IfKeyword->unref(); Test->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector Then; @@ -980,7 +1025,7 @@ finish: if (Foreign) { Foreign->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } Let = static_cast(T0); @@ -1002,7 +1047,7 @@ finish: if (Mut) { Mut->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } @@ -1034,7 +1079,7 @@ after_params: if (TE) { TA = new TypeAssert(static_cast(T2), TE); } else { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); goto finish; } T2 = Tokens.peek(); @@ -1064,7 +1109,7 @@ after_params: Tokens.get(); auto E = parseExpression(); if (!E) { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); goto finish; } Body = new LetExprBody(static_cast(T2), E); @@ -1195,13 +1240,13 @@ after_vars: InstanceDeclaration* Parser::parseInstanceDeclaration() { auto InstanceKeyword = expectToken(); if (!InstanceKeyword) { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } auto Name = expectToken(); if (!Name) { InstanceKeyword->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector TypeExps; @@ -1217,7 +1262,7 @@ after_vars: for (auto TE: TypeExps) { TE->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } TypeExps.push_back(TE); @@ -1229,7 +1274,7 @@ after_vars: for (auto TE: TypeExps) { TE->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector Elements; @@ -1266,7 +1311,7 @@ after_vars: if (PubKeyword) { PubKeyword->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } auto Name = expectToken(); @@ -1275,7 +1320,7 @@ after_vars: PubKeyword->unref(); } ClassKeyword->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector TypeVars; @@ -1293,7 +1338,7 @@ after_vars: for (auto TV: TypeVars) { TV->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } TypeVars.push_back(TE); @@ -1307,7 +1352,7 @@ after_vars: for (auto TV: TypeVars) { TV->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector Elements; @@ -1333,7 +1378,7 @@ after_vars: ); } - std::vector Parser::parseRecordFields() { + std::vector Parser::parseRecordDeclarationFields() { std::vector Fields; for (;;) { auto T1 = Tokens.peek(); @@ -1343,20 +1388,20 @@ after_vars: } auto Name = expectToken(); if (!Name) { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } auto Colon = expectToken(); if (!Colon) { Name->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } auto TE = parseTypeExpression(); if (!TE) { Name->unref(); Colon->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } checkLineFoldEnd(); @@ -1377,7 +1422,7 @@ after_vars: if (Pub) { Pub->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } auto Name = expectToken(); @@ -1386,7 +1431,7 @@ after_vars: Pub->unref(); } Struct->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector Vars; @@ -1407,10 +1452,10 @@ after_vars: } Struct->unref(); Name->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } - auto Fields = parseRecordFields(); + auto Fields = parseRecordDeclarationFields(); Tokens.get()->unref(); // Always a LineFoldEnd return new RecordDeclaration { Pub, Struct, Name, Vars, BS, Fields }; } @@ -1427,7 +1472,7 @@ after_vars: if (Pub) { Pub->unref(); } - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } auto Name = expectToken(); @@ -1436,7 +1481,7 @@ after_vars: Pub->unref(); } Enum->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector TVs; @@ -1457,7 +1502,7 @@ after_vars: } Enum->unref(); Name->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } std::vector Members; @@ -1470,14 +1515,14 @@ next_member: } auto Name = expectToken(); if (!Name) { - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } auto T1 = Tokens.peek(); if (T1->getKind() == NodeKind::BlockStart) { Tokens.get(); auto BS = static_cast(T1); - auto Fields = parseRecordFields(); + auto Fields = parseRecordDeclarationFields(); // TODO continue; on error in Fields Members.push_back(new RecordVariantDeclarationMember { Name, BS, Fields }); } else { @@ -1514,7 +1559,7 @@ next_member: // TODO default: DE.add(File, T0, std::vector { NodeKind::LetKeyword, NodeKind::TypeKeyword }); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); return nullptr; } } @@ -1584,7 +1629,7 @@ next_member: auto E = parseExpression(); if (!E) { At->unref(); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); continue; } checkLineFoldEnd(); @@ -1602,26 +1647,111 @@ next_annotation:; return Annotations; } - void Parser::skipToLineFoldEnd() { - unsigned Level = 0; + void Parser::skipToRBrace() { + unsigned ParenLevel = 0; + unsigned BracketLevel = 0; + unsigned BraceLevel = 0; + unsigned BlockLevel = 0; + for (;;) { + auto T0 = Tokens.peek(); + switch (T0->getKind()) { + case NodeKind::EndOfFile: + return; + case NodeKind::LineFoldEnd: + Tokens.get()->unref(); + if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { + return; + } + break; + case NodeKind::BlockStart: + Tokens.get()->unref(); + BlockLevel++; + break; + case NodeKind::BlockEnd: + Tokens.get()->unref(); + BlockLevel--; + break; + case NodeKind::LParen: + Tokens.get()->unref(); + ParenLevel++; + break; + case NodeKind::LBracket: + Tokens.get()->unref(); + BracketLevel++; + break; + case NodeKind::LBrace: + Tokens.get()->unref(); + BraceLevel++; + break; + case NodeKind::RParen: + Tokens.get()->unref(); + ParenLevel--; + break; + case NodeKind::RBracket: + Tokens.get()->unref(); + BracketLevel--; + break; + case NodeKind::RBrace: + if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { + return; + } + Tokens.get()->unref(); + BraceLevel--; + break; + default: + Tokens.get()->unref(); + break; + } + } + } + + void Parser::skipPastLineFoldEnd() { + unsigned ParenLevel = 0; + unsigned BracketLevel = 0; + unsigned BraceLevel = 0; + unsigned BlockLevel = 0; for (;;) { auto T0 = Tokens.get(); switch (T0->getKind()) { case NodeKind::EndOfFile: return; - case NodeKind::LineFoldEnd: + case NodeKind::LineFoldEnd: T0->unref(); - if (Level == 0) { + if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) { return; } break; case NodeKind::BlockStart: T0->unref(); - Level++; + BlockLevel++; break; case NodeKind::BlockEnd: T0->unref(); - Level--; + BlockLevel--; + break; + case NodeKind::LParen: + T0->unref(); + ParenLevel++; + break; + case NodeKind::LBracket: + T0->unref(); + BracketLevel++; + break; + case NodeKind::LBrace: + T0->unref(); + BraceLevel++; + break; + case NodeKind::RParen: + T0->unref(); + ParenLevel--; + break; + case NodeKind::RBracket: + T0->unref(); + BracketLevel--; + break; + case NodeKind::RBrace: + T0->unref(); + BraceLevel--; break; default: T0->unref(); @@ -1636,7 +1766,7 @@ next_annotation:; Tokens.get()->unref(); } else { DE.add(File, T0, std::vector { NodeKind::LineFoldEnd }); - skipToLineFoldEnd(); + skipPastLineFoldEnd(); } }