Add support for literal patterns

This commit is contained in:
Sam Vervaeck 2023-05-21 17:36:44 +02:00
parent 17d21d234b
commit 56cbfc6fbe
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
6 changed files with 108 additions and 23 deletions

View file

@ -5,6 +5,7 @@
#include <istream>
#include <iterator>
#include <unordered_map>
#include <variant>
#include <vector>
#include "zen/config.hpp"
@ -70,6 +71,7 @@ namespace bolt {
ArrowTypeExpression,
VarTypeExpression,
BindPattern,
LiteralPattern,
ReferenceExpression,
MatchCase,
MatchExpression,
@ -760,32 +762,53 @@ namespace bolt {
};
class StringLiteral : public Token {
using Value = std::variant<ByteString, Integer>;
class Literal : public Token {
public:
inline Literal(NodeKind Kind, TextLoc StartLoc):
Token(Kind, StartLoc) {}
virtual Value getValue() = 0;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::StringLiteral
|| N->getKind() == NodeKind::IntegerLiteral;
}
};
class StringLiteral : public Literal {
public:
ByteString Text;
StringLiteral(ByteString Text, TextLoc StartLoc):
Token(NodeKind::StringLiteral, StartLoc), Text(Text) {}
Literal(NodeKind::StringLiteral, StartLoc), Text(Text) {}
std::string getText() const override;
Value getValue() override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::StringLiteral;
}
};
class IntegerLiteral : public Token {
class IntegerLiteral : public Literal {
public:
Integer Value;
Integer V;
IntegerLiteral(Integer Value, TextLoc StartLoc):
Token(NodeKind::IntegerLiteral, StartLoc), Value(Value) {}
Literal(NodeKind::IntegerLiteral, StartLoc), V(Value) {}
std::string getText() const override;
Value getValue() override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::IntegerLiteral;
}
@ -977,6 +1000,20 @@ namespace bolt {
};
class LiteralPattern : public Pattern {
public:
class Literal* Literal;
LiteralPattern(class Literal* Literal):
Pattern(NodeKind::LiteralPattern),
Literal(Literal) {}
Token* getFirstToken() override;
Token* getLastToken() override;
};
class Expression : public TypedNode {
protected:
@ -1074,10 +1111,10 @@ namespace bolt {
class ConstantExpression : public Expression {
public:
class Token* Token;
class Literal* Token;
ConstantExpression(
class Token* Token
class Literal* Token
): Expression(NodeKind::ConstantExpression),
Token(Token) {}

View file

@ -101,6 +101,8 @@ namespace bolt {
return static_cast<D*>(this)->visitVarTypeExpression(static_cast<VarTypeExpression*>(N));
case NodeKind::BindPattern:
return static_cast<D*>(this)->visitBindPattern(static_cast<BindPattern*>(N));
case NodeKind::LiteralPattern:
return static_cast<D*>(this)->visitLiteralPattern(static_cast<LiteralPattern*>(N));
case NodeKind::ReferenceExpression:
return static_cast<D*>(this)->visitReferenceExpression(static_cast<ReferenceExpression*>(N));
case NodeKind::MatchCase:
@ -350,6 +352,10 @@ namespace bolt {
visitPattern(N);
}
void visitLiteralPattern(LiteralPattern* N) {
visitPattern(N);
}
void visitExpression(Expression* N) {
visitNode(N);
}
@ -589,6 +595,9 @@ namespace bolt {
case NodeKind::BindPattern:
visitEachChild(static_cast<BindPattern*>(N));
break;
case NodeKind::LiteralPattern:
visitEachChild(static_cast<LiteralPattern*>(N));
break;
case NodeKind::ReferenceExpression:
visitEachChild(static_cast<ReferenceExpression*>(N));
break;
@ -823,6 +832,10 @@ namespace bolt {
BOLT_VISIT(N->Name);
}
void visitEachChild(LiteralPattern* N) {
BOLT_VISIT(N->Literal);
}
void visitEachChild(ReferenceExpression* N) {
for (auto [Name, Dot]: N->ModulePath) {
BOLT_VISIT(Name);

View file

@ -1,4 +1,5 @@
#pragma once
#include "zen/config.hpp"
@ -355,6 +356,7 @@ namespace bolt {
Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE);
Type* inferLiteral(Literal* Lit);
void inferBindings(Pattern* Pattern, Type* T, ConstraintSet* Constraints, TVSet* TVs);
void inferBindings(Pattern* Pattern, Type* T);

View file

@ -221,6 +221,14 @@ namespace bolt {
return Name;
}
Token* LiteralPattern::getFirstToken() {
return Literal;
}
Token* LiteralPattern::getLastToken() {
return Literal;
}
Token* ReferenceExpression::getFirstToken() {
if (!ModulePath.empty()) {
return std::get<0>(ModulePath.front());
@ -586,7 +594,7 @@ namespace bolt {
}
std::string IntegerLiteral::getText() const {
return std::to_string(Value);
return std::to_string(V);
}
std::string DotDot::getText() const {
@ -613,6 +621,14 @@ namespace bolt {
return Text;
}
Value StringLiteral::getValue() {
return Text;
}
Value IntegerLiteral::getValue() {
return V;
}
SymbolPath ReferenceExpression::getSymbolPath() const {
std::vector<ByteString> ModuleNames;
for (auto [Name, Dot]: ModulePath) {

View file

@ -723,18 +723,7 @@ namespace bolt {
case NodeKind::ConstantExpression:
{
auto Const = static_cast<ConstantExpression*>(X);
Type* Ty = nullptr;
switch (Const->Token->getKind()) {
case NodeKind::IntegerLiteral:
Ty = lookupMono("Int");
break;
case NodeKind::StringLiteral:
Ty = lookupMono("String");
break;
default:
ZEN_UNREACHABLE
}
ZEN_ASSERT(Ty != nullptr);
auto Ty = inferLiteral(Const->Token);
X->setType(Ty);
return Ty;
}
@ -815,7 +804,15 @@ namespace bolt {
case NodeKind::BindPattern:
{
addBinding(static_cast<BindPattern*>(Pattern)->Name->getCanonicalText(), new Forall(TVs, Constraints, Type));
auto P = static_cast<BindPattern*>(Pattern);
addBinding(P->Name->getCanonicalText(), new Forall(TVs, Constraints, Type));
break;
}
case NodeKind::LiteralPattern:
{
auto P = static_cast<LiteralPattern*>(Pattern);
addConstraint(new CEqual(inferLiteral(P->Literal), Type, P));
break;
}
@ -830,6 +827,22 @@ namespace bolt {
inferBindings(Pattern, Type, new ConstraintSet, new TVSet);
}
Type* Checker::inferLiteral(Literal* L) {
Type* Ty;
switch (L->getKind()) {
case NodeKind::IntegerLiteral:
Ty = lookupMono("Int");
break;
case NodeKind::StringLiteral:
Ty = lookupMono("String");
break;
default:
ZEN_UNREACHABLE
}
ZEN_ASSERT(Ty != nullptr);
return Ty;
}
void collectTypeclasses(LetDeclaration* Decl, std::vector<TypeclassSignature>& Out) {
if (llvm::isa<ClassDeclaration>(Decl->Parent)) {
auto Class = llvm::cast<ClassDeclaration>(Decl->Parent);

View file

@ -89,11 +89,15 @@ namespace bolt {
Pattern* Parser::parsePattern() {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
case NodeKind::StringLiteral:
case NodeKind::IntegerLiteral:
Tokens.get();
return new LiteralPattern(static_cast<Literal*>(T0));
case NodeKind::Identifier:
Tokens.get();
return new BindPattern(static_cast<Identifier*>(T0));
default:
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier });
throw UnexpectedTokenDiagnostic(File, T0, std::vector { NodeKind::Identifier, NodeKind::StringLiteral, NodeKind::IntegerLiteral });
}
}
@ -269,7 +273,7 @@ after_constraints:
case NodeKind::IntegerLiteral:
case NodeKind::StringLiteral:
Tokens.get();
return new ConstantExpression(T0);
return new ConstantExpression(static_cast<Literal*>(T0));
default:
throw UnexpectedTokenDiagnostic(File, T0, { NodeKind::MatchKeyword, NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::IntegerLiteral, NodeKind::StringLiteral });
}