diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 2911db023..cfb896ab2 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -51,6 +51,7 @@ namespace bolt { ElifKeyword, IfKeyword, ElseKeyword, + MatchKeyword, Invalid, EndOfFile, BlockStart, @@ -70,6 +71,8 @@ namespace bolt { VarTypeExpression, BindPattern, ReferenceExpression, + MatchCase, + MatchExpression, NestedExpression, ConstantExpression, CallExpression, @@ -588,6 +591,20 @@ namespace bolt { }; + class MatchKeyword : public Token { + public: + + inline MatchKeyword(TextLoc StartLoc): + Token(NodeKind::MatchKeyword, StartLoc) {} + + std::string getText() const override; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::MatchKeyword; + } + + }; + class Invalid : public Token { public: @@ -969,6 +986,52 @@ namespace bolt { }; + class MatchCase : public Node { + public: + + class Pattern* Pattern; + class RArrowAlt* RArrowAlt; + class Expression* Expression; + + inline MatchCase( + class Pattern* Pattern, + class RArrowAlt* RArrowAlt, + class Expression* Expression + ): Node(NodeKind::MatchCase), + Pattern(Pattern), + RArrowAlt(RArrowAlt), + Expression(Expression) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + }; + + + class MatchExpression : public Expression { + public: + + class MatchKeyword* MatchKeyword; + Expression* Value; + class BlockStart* BlockStart; + std::vector Cases; + + inline MatchExpression( + class MatchKeyword* MatchKeyword, + Expression* Value, + class BlockStart* BlockStart, + std::vector Cases + ): Expression(NodeKind::MatchExpression), + MatchKeyword(MatchKeyword), + Value(Value), + BlockStart(BlockStart), + Cases(Cases) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + }; + class NestedExpression : public Expression { public: diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 26583f271..b8cd578c7 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -63,6 +63,8 @@ namespace bolt { return static_cast(this)->visitIfKeyword(static_cast(N)); case NodeKind::ElseKeyword: return static_cast(this)->visitElseKeyword(static_cast(N)); + case NodeKind::MatchKeyword: + return static_cast(this)->visitMatchKeyword(static_cast(N)); case NodeKind::Invalid: return static_cast(this)->visitInvalid(static_cast(N)); case NodeKind::EndOfFile: @@ -101,6 +103,10 @@ namespace bolt { return static_cast(this)->visitBindPattern(static_cast(N)); case NodeKind::ReferenceExpression: return static_cast(this)->visitReferenceExpression(static_cast(N)); + case NodeKind::MatchCase: + return static_cast(this)->visitMatchCase(static_cast(N)); + case NodeKind::MatchExpression: + return static_cast(this)->visitMatchExpression(static_cast(N)); case NodeKind::NestedExpression: return static_cast(this)->visitNestedExpression(static_cast(N)); case NodeKind::ConstantExpression: @@ -256,6 +262,10 @@ namespace bolt { visitToken(N); } + void visitMatchKeyword(MatchKeyword* N) { + visitToken(N); + } + void visitInvalid(Invalid* N) { visitToken(N); } @@ -348,6 +358,14 @@ namespace bolt { visitExpression(N); } + void visitMatchCase(MatchCase* N) { + visitNode(N); + } + + void visitMatchExpression(MatchExpression* N) { + visitExpression(N); + } + void visitNestedExpression(NestedExpression* N) { visitExpression(N); } @@ -514,6 +532,9 @@ namespace bolt { case NodeKind::ElseKeyword: visitEachChild(static_cast(N)); break; + case NodeKind::MatchKeyword: + visitEachChild(static_cast(N)); + break; case NodeKind::Invalid: visitEachChild(static_cast(N)); break; @@ -571,6 +592,12 @@ namespace bolt { case NodeKind::ReferenceExpression: visitEachChild(static_cast(N)); break; + case NodeKind::MatchCase: + visitEachChild(static_cast(N)); + break; + case NodeKind::MatchExpression: + visitEachChild(static_cast(N)); + break; case NodeKind::NestedExpression: visitEachChild(static_cast(N)); break; @@ -713,6 +740,9 @@ namespace bolt { void visitEachChild(ElseKeyword* N) { } + void visitEachChild(MatchKeyword* N) { + } + void visitEachChild(Invalid* N) { } @@ -801,6 +831,23 @@ namespace bolt { BOLT_VISIT(N->Name); } + void visitEachChild(MatchCase* N) { + BOLT_VISIT(N->Pattern); + BOLT_VISIT(N->RArrowAlt); + BOLT_VISIT(N->Expression); + } + + void visitEachChild(MatchExpression* N) { + BOLT_VISIT(N->MatchKeyword); + if (N->Value) { + BOLT_VISIT(N->Value); + } + BOLT_VISIT(N->BlockStart); + for (auto Case: N->Cases) { + BOLT_VISIT(Case); + } + } + void visitEachChild(NestedExpression* N) { BOLT_VISIT(N->LParen); BOLT_VISIT(N->Inner); diff --git a/src/CST.cc b/src/CST.cc index 42ce0bbe7..d9f8899f2 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -232,6 +232,25 @@ namespace bolt { return Name; } + Token* MatchCase::getFirstToken() { + return Pattern->getFirstToken(); + } + + Token* MatchCase::getLastToken() { + return Expression->getLastToken(); + } + + Token* MatchExpression::getFirstToken() { + return MatchKeyword; + } + + Token* MatchExpression::getLastToken() { + if (!Cases.empty()) { + return Cases.back()->getLastToken(); + } + return BlockStart; + } + Token* NestedExpression::getFirstToken() { return LParen; } @@ -514,6 +533,10 @@ namespace bolt { return "elif"; } + std::string MatchKeyword::getText() const { + return "match"; + } + std::string ModKeyword::getText() const { return "mod"; } diff --git a/src/Parser.cc b/src/Parser.cc index 677e6cb6f..53d86a5d0 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -231,6 +231,33 @@ after_constraints: auto T2 = static_cast(expectToken(NodeKind::RParen)); return new NestedExpression(static_cast(T0), E, T2); } + case NodeKind::MatchKeyword: + { + Tokens.get(); + auto T1 = Tokens.peek(); + Expression* Value = nullptr; + BlockStart* BlockStart; + if (llvm::isa(T1)) { + BlockStart = static_cast(T1); + } else { + Value = parseExpression(); + BlockStart = expectToken(); + } + std::vector Cases; + for (;;) { + auto T2 = Tokens.peek(); + if (llvm::isa(T2)) { + Tokens.get(); + break; + } + auto Pattern = parsePattern(); + auto RArrowAlt = expectToken(); + auto Expression = parseExpression(); + expectToken(); + Cases.push_back(new MatchCase { Pattern, RArrowAlt, Expression }); + } + return new MatchExpression(static_cast(T0), Value, BlockStart, Cases); + } case NodeKind::IntegerLiteral: case NodeKind::StringLiteral: Tokens.get(); diff --git a/src/Scanner.cc b/src/Scanner.cc index 6e889b2c5..db58d3de8 100644 --- a/src/Scanner.cc +++ b/src/Scanner.cc @@ -69,6 +69,7 @@ namespace bolt { { "if", NodeKind::IfKeyword }, { "else", NodeKind::ElseKeyword }, { "elif", NodeKind::ElifKeyword }, + { "match", NodeKind::MatchKeyword }, { "class", NodeKind::ClassKeyword }, { "instance", NodeKind::InstanceKeyword }, }; @@ -235,6 +236,8 @@ digit_finish: return new ElifKeyword(StartLoc); case NodeKind::ElseKeyword: return new ElseKeyword(StartLoc); + case NodeKind::MatchKeyword: + return new MatchKeyword(StartLoc); case NodeKind::ClassKeyword: return new ClassKeyword(StartLoc); case NodeKind::InstanceKeyword: