diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index dcbe99333..67ca881b5 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -140,7 +140,9 @@ namespace bolt { BindPattern, LiteralPattern, NamedPattern, + TuplePattern, NestedPattern, + ListPattern, ReferenceExpression, MatchCase, MatchExpression, @@ -1244,6 +1246,27 @@ namespace bolt { }; + class TuplePattern : public Pattern { + public: + + LParen* LParen; + std::vector> Elements; + RParen* RParen; + + inline TuplePattern( + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): Pattern(NodeKind::TuplePattern), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + class NestedPattern : public Pattern { public: @@ -1265,6 +1288,28 @@ namespace bolt { }; + class ListPattern : public Pattern { + public: + + class LBracket* LBracket; + std::vector> Elements; + class RBracket* RBracket; + + inline ListPattern( + class LBracket* LBracket, + std::vector> Elements, + class RBracket* RBracket + ): Pattern(NodeKind::ListPattern), + LBracket(LBracket), + Elements(Elements), + RBracket(RBracket) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + + class Expression : public TypedNode { protected: diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 74b1d9923..c45015643 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -70,7 +70,9 @@ namespace bolt { BOLT_GEN_CASE(BindPattern) BOLT_GEN_CASE(LiteralPattern) BOLT_GEN_CASE(NamedPattern) + BOLT_GEN_CASE(TuplePattern) BOLT_GEN_CASE(NestedPattern) + BOLT_GEN_CASE(ListPattern) BOLT_GEN_CASE(ReferenceExpression) BOLT_GEN_CASE(MatchCase) BOLT_GEN_CASE(MatchExpression) @@ -334,10 +336,18 @@ namespace bolt { visitPattern(N); } + void visitTuplePattern(TuplePattern* N) { + visitPattern(N); + } + void visitNestedPattern(NestedPattern* N) { visitPattern(N); } + void visitListPattern(ListPattern* N) { + visitPattern(N); + } + void visitExpression(Expression* N) { visitNode(N); } @@ -536,7 +546,9 @@ namespace bolt { BOLT_GEN_CHILD_CASE(BindPattern) BOLT_GEN_CHILD_CASE(LiteralPattern) BOLT_GEN_CHILD_CASE(NamedPattern) + BOLT_GEN_CHILD_CASE(TuplePattern) BOLT_GEN_CHILD_CASE(NestedPattern) + BOLT_GEN_CHILD_CASE(ListPattern) BOLT_GEN_CHILD_CASE(ReferenceExpression) BOLT_GEN_CHILD_CASE(MatchCase) BOLT_GEN_CHILD_CASE(MatchExpression) @@ -774,12 +786,34 @@ namespace bolt { } } + void visitEachChild(TuplePattern* N) { + BOLT_VISIT(N->LParen); + for (auto [P, Comma]: N->Elements) { + BOLT_VISIT(P); + if (Comma) { + BOLT_VISIT(Comma); + } + } + BOLT_VISIT(N->RParen); + } + void visitEachChild(NestedPattern* N) { BOLT_VISIT(N->LParen); BOLT_VISIT(N->P); BOLT_VISIT(N->RParen); } + void visitEachChild(ListPattern* N) { + BOLT_VISIT(N->LBracket); + for (auto [Element, Separator]: N->Elements) { + BOLT_VISIT(Element); + if (Separator) { + BOLT_VISIT(Separator); + } + } + BOLT_VISIT(N->RBracket); + } + void visitEachChild(ReferenceExpression* N) { for (auto [Name, Dot]: N->ModulePath) { BOLT_VISIT(Name); diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index a01641db5..78eb4406c 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -175,6 +175,7 @@ namespace bolt { size_t NextTypeVarId = 0; Type* BoolType; + Type* ListType; Type* IntType; Type* StringType; diff --git a/include/bolt/Parser.hpp b/include/bolt/Parser.hpp index 9be4f7caa..86b26555f 100644 --- a/include/bolt/Parser.hpp +++ b/include/bolt/Parser.hpp @@ -103,8 +103,10 @@ namespace bolt { TypeExpression* parseTypeExpression(); - Pattern* parsePrimitivePattern(); - Pattern* parsePattern(); + ListPattern* parseListPattern(); + Pattern* parsePrimitivePattern(bool IsNarrow); + Pattern* parseWidePattern(); + Pattern* parseNarrowPattern(); Parameter* parseParam(); diff --git a/src/CST.cc b/src/CST.cc index a18e956e9..5cb02fbbb 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -163,6 +163,22 @@ namespace bolt { visitPattern(Y->P, Decl); break; } + case NodeKind::TuplePattern: + { + auto Y = static_cast(X); + for (auto [Element, Comma]: Y->Elements) { + visitPattern(Element, Decl); + } + break; + } + case NodeKind::ListPattern: + { + auto Y = static_cast(X); + for (auto [Element, Separator]: Y->Elements) { + visitPattern(Element, Decl); + } + break; + } case NodeKind::LiteralPattern: break; default: @@ -278,8 +294,8 @@ namespace bolt { struct UnrefVisitor : public CSTVisitor { void visit(Node* N) { - N->unref(); visitEachChild(N); + N->unref(); } }; @@ -412,6 +428,14 @@ namespace bolt { return Name; } + Token* TuplePattern::getFirstToken() const { + return LParen; + } + + Token* TuplePattern::getLastToken() const { + return RParen; + } + Token* NestedPattern::getFirstToken() const { return LParen; } @@ -420,6 +444,14 @@ namespace bolt { return RParen; } + Token* ListPattern::getFirstToken() const { + return LBracket; + } + + Token* ListPattern::getLastToken() const { + return RBracket; + } + Token* ReferenceExpression::getFirstToken() const { if (!ModulePath.empty()) { return std::get<0>(ModulePath.front()); diff --git a/src/Checker.cc b/src/Checker.cc index 56814602a..cce82d596 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -18,6 +18,8 @@ // TODO Add a pattern that only performs a type assert +// TODO create the constraint in addConstraint, not the other way round + #include #include #include @@ -101,6 +103,7 @@ namespace bolt { BoolType = createConType("Bool"); IntType = createConType("Int"); StringType = createConType("String"); + ListType = createConType("List"); } Scheme* Checker::lookup(ByteString Name) { @@ -1075,6 +1078,26 @@ namespace bolt { return RetTy; } + case NodeKind::TuplePattern: + { + auto P = static_cast(Pattern); + std::vector ElementTypes; + for (auto [Element, Comma]: P->Elements) { + ElementTypes.push_back(inferPattern(Element)); + } + return new TTuple(ElementTypes); + } + + case NodeKind::ListPattern: + { + auto P = static_cast(Pattern); + auto ElementType = createTypeVar(); + for (auto [Element, Separator]: P->Elements) { + addConstraint(new CEqual(ElementType, inferPattern(Element), P)); + } + return new TApp(ListType, ElementType); + } + case NodeKind::NestedPattern: { auto P = static_cast(Pattern); @@ -1292,6 +1315,7 @@ namespace bolt { addBinding("String", new Forall(StringType)); addBinding("Int", new Forall(IntType)); addBinding("Bool", new Forall(BoolType)); + addBinding("List", new Forall(ListType)); addBinding("True", new Forall(BoolType)); addBinding("False", new Forall(BoolType)); auto A = createTypeVar(); diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 2c45fcf42..34548f749 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -128,6 +128,8 @@ namespace bolt { return "an if-statement"; case NodeKind::IfStatementPart: return "a branch of an if-statement"; + case NodeKind::ListPattern: + return "a list pattern"; default: ZEN_UNREACHABLE } diff --git a/src/Parser.cc b/src/Parser.cc index 45002c9ca..f41c107d2 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -102,7 +102,49 @@ namespace bolt { return T; } - Pattern* Parser::parsePrimitivePattern() { + ListPattern* Parser::parseListPattern() { + auto LBracket = expectToken(); + if (!LBracket) { + return nullptr; + } + std::vector> Elements; + RBracket* RBracket; + auto T0 = Tokens.peek(); + if (T0->getKind() == NodeKind::RBracket) { + Tokens.get(); + RBracket = static_cast(T0); + goto finish; + } + for (;;) { + auto P = parseWidePattern(); + if (!P) { + LBracket->unref(); + for (auto [Element, Separator]: Elements) { + Element->unref(); + Separator->unref(); + } + return nullptr; + } + auto T1 = Tokens.peek(); + switch (T1->getKind()) { + case NodeKind::Comma: + Tokens.get(); + Elements.push_back(std::make_tuple(P, static_cast(T1))); + break; + case NodeKind::RBracket: + Tokens.get(); + Elements.push_back(std::make_tuple(P, nullptr)); + RBracket = static_cast(T1); + goto finish; + default: + DE.add(File, T1, std::vector { NodeKind::Comma, NodeKind::RBracket }); + } + } +finish: + return new ListPattern { LBracket, Elements, RBracket }; + } + + Pattern* Parser::parsePrimitivePattern(bool IsNarrow) { auto T0 = Tokens.peek(); switch (T0->getKind()) { case NodeKind::StringLiteral: @@ -113,60 +155,102 @@ namespace bolt { Tokens.get(); return new BindPattern(static_cast(T0)); case NodeKind::IdentifierAlt: + { Tokens.get(); - return new NamedPattern(static_cast(T0), {}); + auto Name = static_cast(T0); + if (IsNarrow) { + return new NamedPattern(Name, {}); + } + std::vector Patterns; + for (;;) { + auto T2 = Tokens.peek(); + if (T2->getKind() == NodeKind::RParen + || T2->getKind() == NodeKind::RBracket + || T2->getKind() == NodeKind::RBrace + || T2->getKind() == NodeKind::Comma + || T2->getKind() == NodeKind::Colon + || T2->getKind() == NodeKind::Equals + || T2->getKind() == NodeKind::BlockStart + || T2->getKind() == NodeKind::RArrowAlt) { + break; + } + auto P = parseNarrowPattern(); + if (!P) { + Name->unref(); + for (auto P: Patterns) { + P->unref(); + } + return nullptr; + } + Patterns.push_back(P); + } + return new NamedPattern { Name, Patterns }; + } + case NodeKind::LBracket: + return parseListPattern(); case NodeKind::LParen: { Tokens.get(); auto LParen = static_cast(T0); - auto T1 = Tokens.peek(); + std::vector> Elements; RParen* RParen; - if (T1->getKind() == NodeKind::IdentifierAlt) { - Tokens.get(); - auto Name = static_cast(T1); - std::vector Patterns; - for (;;) { - auto T2 = Tokens.peek(); - if (T2->getKind() == NodeKind::RParen) { - Tokens.get(); - RParen = static_cast(T2); - break; - } - auto P = parsePrimitivePattern(); - if (!P) { - LParen->unref(); - for (auto P: Patterns) { - P->unref(); - } - return nullptr; - } - Patterns.push_back(P); - } - return new NestedPattern { LParen, new NamedPattern { Name, Patterns }, RParen }; - } else { - auto P = parsePattern(); + for (;;) { + auto P = parseWidePattern(); if (!P) { LParen->unref(); + for (auto [P, Comma]: Elements) { + P->unref(); + Comma->unref(); + } + // TODO maybe skip to next comma? return nullptr; } - auto RParen = expectToken(); - if (!RParen) { + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::Comma) { + Tokens.get(); + Elements.push_back(std::make_tuple(P, static_cast(T1))); + } else if (T1->getKind() == NodeKind::RParen) { + Tokens.get(); + RParen = static_cast(T1); + Elements.push_back(std::make_tuple(P, nullptr)); + break; + } else { + DE.add(File, T1, std::vector { NodeKind::Comma, NodeKind::RParen }); LParen->unref(); - P->unref(); + for (auto [P, Comma]: Elements) { + P->unref(); + Comma->unref(); + } + // TODO maybe skip to next comma? return nullptr; + } - return new NestedPattern { LParen, P, RParen }; } + if (Elements.size() == 1) { + return new NestedPattern { LParen, std::get<0>(Elements.front()), RParen }; + } + return new TuplePattern(LParen, Elements, RParen); } default: // Tokens.get(); - DE.add(File, T0, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::StringLiteral, NodeKind::IntegerLiteral }); + DE.add(File, T0, std::vector { + NodeKind::Identifier, + NodeKind::IdentifierAlt, + NodeKind::StringLiteral, + NodeKind::IntegerLiteral, + NodeKind::LParen, + NodeKind::LBracket + }); return nullptr; } } - Pattern* Parser::parsePattern() { - return parsePrimitivePattern(); + Pattern* Parser::parseWidePattern() { + return parsePrimitivePattern(false); + } + + Pattern* Parser::parseNarrowPattern() { + return parsePrimitivePattern(true); } TypeExpression* Parser::parseTypeExpression() { @@ -431,7 +515,7 @@ after_tuple_element: Tokens.get()->unref(); break; } - auto Pattern = parsePattern(); + auto Pattern = parseWidePattern(); if (!Pattern) { skipToLineFoldEnd(); continue; @@ -863,7 +947,7 @@ VariableDeclaration* Parser::parseVariableDeclaration() { Tokens.get(); } - auto P = parsePattern(); + auto P = parseWidePattern(); if (!P) { if (Pub) { Pub->unref(); @@ -993,7 +1077,7 @@ finish: case NodeKind::Colon: goto after_params; default: - auto P = parsePattern(); + auto P = parseNarrowPattern(); if (!P) { Tokens.get(); P = new BindPattern(new Identifier("_"));