From 5ac162cd72c6ef1b678d21df76e8b7d97f8835ce Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Tue, 23 May 2023 22:36:01 +0200 Subject: [PATCH] Add support for nested/tuple type expressions --- include/bolt/CST.hpp | 44 +++++++++++++++++++++++++++++++++++++ include/bolt/CSTVisitor.hpp | 35 +++++++++++++++++++++++++++++ src/CST.cc | 16 ++++++++++++++ src/Checker.cc | 20 +++++++++++++++++ src/Parser.cc | 35 ++++++++++++++++++++++++++++- 5 files changed, 149 insertions(+), 1 deletion(-) diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 9e533b5a2..665f5db1e 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -70,6 +70,8 @@ namespace bolt { ReferenceTypeExpression, ArrowTypeExpression, VarTypeExpression, + NestedTypeExpression, + TupleTypeExpression, BindPattern, LiteralPattern, ReferenceExpression, @@ -1005,6 +1007,48 @@ namespace bolt { }; + class NestedTypeExpression : public TypeExpression { + public: + + LParen* LParen; + TypeExpression* TE; + RParen* RParen; + + inline NestedTypeExpression( + class LParen* LParen, + TypeExpression* TE, + class RParen* RParen + ): TypeExpression(NodeKind::NestedTypeExpression), + LParen(LParen), + TE(TE), + RParen(RParen) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + }; + + class TupleTypeExpression : public TypeExpression { + public: + + LParen* LParen; + std::vector> Elements; + RParen* RParen; + + inline TupleTypeExpression( + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): TypeExpression(NodeKind::TupleTypeExpression), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + }; + class Pattern : public Node { protected: diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 8a9bba81d..19b5efc88 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -99,6 +99,10 @@ namespace bolt { return static_cast(this)->visitArrowTypeExpression(static_cast(N)); case NodeKind::VarTypeExpression: return static_cast(this)->visitVarTypeExpression(static_cast(N)); + case NodeKind::NestedTypeExpression: + return static_cast(this)->visitNestedTypeExpression(static_cast(N)); + case NodeKind::TupleTypeExpression: + return static_cast(this)->visitTupleTypeExpression(static_cast(N)); case NodeKind::BindPattern: return static_cast(this)->visitBindPattern(static_cast(N)); case NodeKind::LiteralPattern: @@ -348,6 +352,14 @@ namespace bolt { visitTypeExpression(N); } + void visitNestedTypeExpression(NestedTypeExpression* N) { + visitTypeExpression(N); + } + + void visitTupleTypeExpression(TupleTypeExpression* N) { + visitTypeExpression(N); + } + void visitPattern(Pattern* N) { visitNode(N); } @@ -604,6 +616,12 @@ namespace bolt { case NodeKind::VarTypeExpression: visitEachChild(static_cast(N)); break; + case NodeKind::NestedTypeExpression: + visitEachChild(static_cast(N)); + break; + case NodeKind::TupleTypeExpression: + visitEachChild(static_cast(N)); + break; case NodeKind::BindPattern: visitEachChild(static_cast(N)); break; @@ -846,6 +864,23 @@ namespace bolt { BOLT_VISIT(N->Name); } + void visitEachChild(NestedTypeExpression* N) { + BOLT_VISIT(N->LParen); + BOLT_VISIT(N->TE); + BOLT_VISIT(N->RParen); + } + + void visitEachChild(TupleTypeExpression* N) { + BOLT_VISIT(N->LParen); + for (auto [TE, Comma]: N->Elements) { + if (Comma) { + BOLT_VISIT(Comma); + } + BOLT_VISIT(TE); + } + BOLT_VISIT(N->RParen); + } + void visitEachChild(BindPattern* N) { BOLT_VISIT(N->Name); } diff --git a/src/CST.cc b/src/CST.cc index ff4a3a574..7082195ea 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -229,6 +229,22 @@ namespace bolt { return Name; } + Token* NestedTypeExpression::getLastToken() { + return LParen; + } + + Token* NestedTypeExpression::getFirstToken() { + return RParen; + } + + Token* TupleTypeExpression::getLastToken() { + return LParen; + } + + Token* TupleTypeExpression::getFirstToken() { + return RParen; + } + Token* BindPattern::getFirstToken() { return Name; } diff --git a/src/Checker.cc b/src/Checker.cc index 103b6eec1..3c61cf49d 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -585,6 +585,26 @@ namespace bolt { return Ty; } + case NodeKind::TupleTypeExpression: + { + auto TupleTE = static_cast(N); + std::vector ElementTypes; + for (auto [TE, Comma]: TupleTE->Elements) { + ElementTypes.push_back(inferTypeExpression(TE)); + } + auto Ty = new TTuple(ElementTypes); + N->setType(Ty); + return Ty; + } + + case NodeKind::NestedTypeExpression: + { + auto NestedTE = static_cast(N); + auto Ty = inferTypeExpression(NestedTE->TE); + N->setType(Ty); + return Ty; + } + case NodeKind::ArrowTypeExpression: { auto ArrowTE = static_cast(N); diff --git a/src/Parser.cc b/src/Parser.cc index 16b95ba9e..0a5806901 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -173,6 +173,39 @@ after_constraints: switch (T0->getKind()) { case NodeKind::Identifier: return parseVarTypeExpression(); + case NodeKind::LParen: + { + Tokens.get(); + auto LParen = static_cast(T0); + std::vector> Elements; + RParen* RParen; + for (;;) { + auto T1 = Tokens.peek(); + if (llvm::isa(T1)) { + Tokens.get(); + RParen = static_cast(T1); + break; + } + auto TE = parseTypeExpression(); + auto T2 = Tokens.get(); + switch (T2->getKind()) { + case NodeKind::RParen: + RParen = static_cast(T1); + Elements.push_back({ TE, nullptr }); + goto after_tuple_element; + case NodeKind::Comma: + Elements.push_back({ TE, static_cast(T2) }); + continue; + default: + throw UnexpectedTokenDiagnostic(File, T2, { NodeKind::Comma, NodeKind::RParen }); + } + } +after_tuple_element: + if (Elements.size() == 1) { + return new NestedTypeExpression { LParen, std::get<0>(Elements.front()), RParen }; + } + return new TupleTypeExpression { LParen, Elements, RParen }; + } case NodeKind::IdentifierAlt: { std::vector> ModulePath; @@ -547,7 +580,7 @@ after_params: return parseExpressionStatement(); } } -# + ConstraintExpression* Parser::parseConstraintExpression() { bool HasTilde = false; for (std::size_t I = 0; ; I++) {