diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 13e15fd80..56ef96b7e 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "zen/config.hpp" @@ -70,6 +71,7 @@ namespace bolt { ArrowTypeExpression, VarTypeExpression, BindPattern, + LiteralPattern, ReferenceExpression, MatchCase, MatchExpression, @@ -760,32 +762,53 @@ namespace bolt { }; - class StringLiteral : public Token { + using Value = std::variant; + + class Literal : public Token { + public: + + inline Literal(NodeKind Kind, TextLoc StartLoc): + Token(Kind, StartLoc) {} + + virtual Value getValue() = 0; + + static bool classof(const Node* N) { + return N->getKind() == NodeKind::StringLiteral + || N->getKind() == NodeKind::IntegerLiteral; + } + + }; + + class StringLiteral : public Literal { public: ByteString Text; StringLiteral(ByteString Text, TextLoc StartLoc): - Token(NodeKind::StringLiteral, StartLoc), Text(Text) {} + Literal(NodeKind::StringLiteral, StartLoc), Text(Text) {} std::string getText() const override; + Value getValue() override; + static bool classof(const Node* N) { return N->getKind() == NodeKind::StringLiteral; } }; - class IntegerLiteral : public Token { + class IntegerLiteral : public Literal { public: - Integer Value; + Integer V; IntegerLiteral(Integer Value, TextLoc StartLoc): - Token(NodeKind::IntegerLiteral, StartLoc), Value(Value) {} + Literal(NodeKind::IntegerLiteral, StartLoc), V(Value) {} std::string getText() const override; + Value getValue() override; + static bool classof(const Node* N) { return N->getKind() == NodeKind::IntegerLiteral; } @@ -977,6 +1000,20 @@ namespace bolt { }; + class LiteralPattern : public Pattern { + public: + + class Literal* Literal; + + LiteralPattern(class Literal* Literal): + Pattern(NodeKind::LiteralPattern), + Literal(Literal) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + }; + class Expression : public TypedNode { protected: @@ -1074,10 +1111,10 @@ namespace bolt { class ConstantExpression : public Expression { public: - class Token* Token; + class Literal* Token; ConstantExpression( - class Token* Token + class Literal* Token ): Expression(NodeKind::ConstantExpression), Token(Token) {} diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index b8cd578c7..e6fb99ec4 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -101,6 +101,8 @@ namespace bolt { return static_cast(this)->visitVarTypeExpression(static_cast(N)); case NodeKind::BindPattern: return static_cast(this)->visitBindPattern(static_cast(N)); + case NodeKind::LiteralPattern: + return static_cast(this)->visitLiteralPattern(static_cast(N)); case NodeKind::ReferenceExpression: return static_cast(this)->visitReferenceExpression(static_cast(N)); case NodeKind::MatchCase: @@ -350,6 +352,10 @@ namespace bolt { visitPattern(N); } + void visitLiteralPattern(LiteralPattern* N) { + visitPattern(N); + } + void visitExpression(Expression* N) { visitNode(N); } @@ -589,6 +595,9 @@ namespace bolt { case NodeKind::BindPattern: visitEachChild(static_cast(N)); break; + case NodeKind::LiteralPattern: + visitEachChild(static_cast(N)); + break; case NodeKind::ReferenceExpression: visitEachChild(static_cast(N)); break; @@ -823,6 +832,10 @@ namespace bolt { BOLT_VISIT(N->Name); } + void visitEachChild(LiteralPattern* N) { + BOLT_VISIT(N->Literal); + } + void visitEachChild(ReferenceExpression* N) { for (auto [Name, Dot]: N->ModulePath) { BOLT_VISIT(Name); diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index b96458e44..e6a6a4df5 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -1,4 +1,5 @@ + #pragma once #include "zen/config.hpp" @@ -355,6 +356,7 @@ namespace bolt { Type* inferExpression(Expression* Expression); Type* inferTypeExpression(TypeExpression* TE); + Type* inferLiteral(Literal* Lit); void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs); void inferBindings(Pattern* Pattern, Type* T); diff --git a/src/CST.cc b/src/CST.cc index 25d0bce78..a39c5d3fa 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -221,6 +221,14 @@ namespace bolt { return Name; } + Token* LiteralPattern::getFirstToken() { + return Literal; + } + + Token* LiteralPattern::getLastToken() { + return Literal; + } + Token* ReferenceExpression::getFirstToken() { if (!ModulePath.empty()) { return std::get<0>(ModulePath.front()); @@ -586,7 +594,7 @@ namespace bolt { } std::string IntegerLiteral::getText() const { - return std::to_string(Value); + return std::to_string(V); } std::string DotDot::getText() const { @@ -613,6 +621,14 @@ namespace bolt { return Text; } + Value StringLiteral::getValue() { + return Text; + } + + Value IntegerLiteral::getValue() { + return V; + } + SymbolPath ReferenceExpression::getSymbolPath() const { std::vector ModuleNames; for (auto [Name, Dot]: ModulePath) { diff --git a/src/Checker.cc b/src/Checker.cc index 5fe1b2b25..4a5529017 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -723,18 +723,7 @@ namespace bolt { case NodeKind::ConstantExpression: { auto Const = static_cast(X); - Type* Ty = nullptr; - switch (Const->Token->getKind()) { - case NodeKind::IntegerLiteral: - Ty = lookupMono("Int"); - break; - case NodeKind::StringLiteral: - Ty = lookupMono("String"); - break; - default: - ZEN_UNREACHABLE - } - ZEN_ASSERT(Ty != nullptr); + auto Ty = inferLiteral(Const->Token); X->setType(Ty); return Ty; } @@ -815,7 +804,15 @@ namespace bolt { case NodeKind::BindPattern: { - addBinding(static_cast(Pattern)->Name->getCanonicalText(), new Forall(TVs, Constraints, Type)); + auto P = static_cast(Pattern); + addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Type)); + break; + } + + case NodeKind::LiteralPattern: + { + auto P = static_cast(Pattern); + addConstraint(new CEqual(inferLiteral(P->Literal), Type, P)); break; } @@ -830,6 +827,22 @@ namespace bolt { inferBindings(Pattern, Type, new ConstraintSet, new TVSet); } + Type* Checker::inferLiteral(Literal* L) { + Type* Ty; + switch (L->getKind()) { + case NodeKind::IntegerLiteral: + Ty = lookupMono("Int"); + break; + case NodeKind::StringLiteral: + Ty = lookupMono("String"); + break; + default: + ZEN_UNREACHABLE + } + ZEN_ASSERT(Ty != nullptr); + return Ty; + } + void collectTypeclasses(LetDeclaration* Decl, std::vector& Out) { if (llvm::isa(Decl->Parent)) { auto Class = llvm::cast(Decl->Parent); diff --git a/src/Parser.cc b/src/Parser.cc index b0049923c..ad58cfb46 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -89,11 +89,15 @@ namespace bolt { Pattern* Parser::parsePattern() { auto T0 = Tokens.peek(); switch (T0->getKind()) { + case NodeKind::StringLiteral: + case NodeKind::IntegerLiteral: + Tokens.get(); + return new LiteralPattern(static_cast(T0)); case NodeKind::Identifier: Tokens.get(); return new BindPattern(static_cast(T0)); default: - throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier }); + throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier, NodeKind::StringLiteral, NodeKind::IntegerLiteral }); } } @@ -269,7 +273,7 @@ after_constraints: case NodeKind::IntegerLiteral: case NodeKind::StringLiteral: Tokens.get(); - return new ConstantExpression(T0); + return new ConstantExpression(static_cast(T0)); default: throw UnexpectedTokenDiagnostic(File, T0, { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::IntegerLiteral, NodeKind::StringLiteral }); }