Some fixes and new pattern syntax

- Fix critical bug in deallocator
 - Add TuplePattern
 - Add ListPattern
 - Make Checker hold a ListType
This commit is contained in:
Sam Vervaeck 2023-05-31 12:38:29 +02:00
parent fa294b826e
commit 717a2a663a
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
8 changed files with 264 additions and 40 deletions

View file

@ -140,7 +140,9 @@ namespace bolt {
BindPattern,
LiteralPattern,
NamedPattern,
TuplePattern,
NestedPattern,
ListPattern,
ReferenceExpression,
MatchCase,
MatchExpression,
@ -1244,6 +1246,27 @@ namespace bolt {
};
class TuplePattern : public Pattern {
public:
LParen* LParen;
std::vector<std::tuple<Pattern*, Comma*>> Elements;
RParen* RParen;
inline TuplePattern(
class LParen* LParen,
std::vector<std::tuple<Pattern*, Comma*>> Elements,
class RParen* RParen
): Pattern(NodeKind::TuplePattern),
LParen(LParen),
Elements(Elements),
RParen(RParen) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class NestedPattern : public Pattern {
public:
@ -1265,6 +1288,28 @@ namespace bolt {
};
class ListPattern : public Pattern {
public:
class LBracket* LBracket;
std::vector<std::tuple<Pattern*, Comma*>> Elements;
class RBracket* RBracket;
inline ListPattern(
class LBracket* LBracket,
std::vector<std::tuple<Pattern*, Comma*>> Elements,
class RBracket* RBracket
): Pattern(NodeKind::ListPattern),
LBracket(LBracket),
Elements(Elements),
RBracket(RBracket) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class Expression : public TypedNode {
protected:

View file

@ -70,7 +70,9 @@ namespace bolt {
BOLT_GEN_CASE(BindPattern)
BOLT_GEN_CASE(LiteralPattern)
BOLT_GEN_CASE(NamedPattern)
BOLT_GEN_CASE(TuplePattern)
BOLT_GEN_CASE(NestedPattern)
BOLT_GEN_CASE(ListPattern)
BOLT_GEN_CASE(ReferenceExpression)
BOLT_GEN_CASE(MatchCase)
BOLT_GEN_CASE(MatchExpression)
@ -334,10 +336,18 @@ namespace bolt {
visitPattern(N);
}
void visitTuplePattern(TuplePattern* N) {
visitPattern(N);
}
void visitNestedPattern(NestedPattern* N) {
visitPattern(N);
}
void visitListPattern(ListPattern* N) {
visitPattern(N);
}
void visitExpression(Expression* N) {
visitNode(N);
}
@ -536,7 +546,9 @@ namespace bolt {
BOLT_GEN_CHILD_CASE(BindPattern)
BOLT_GEN_CHILD_CASE(LiteralPattern)
BOLT_GEN_CHILD_CASE(NamedPattern)
BOLT_GEN_CHILD_CASE(TuplePattern)
BOLT_GEN_CHILD_CASE(NestedPattern)
BOLT_GEN_CHILD_CASE(ListPattern)
BOLT_GEN_CHILD_CASE(ReferenceExpression)
BOLT_GEN_CHILD_CASE(MatchCase)
BOLT_GEN_CHILD_CASE(MatchExpression)
@ -774,12 +786,34 @@ namespace bolt {
}
}
void visitEachChild(TuplePattern* N) {
BOLT_VISIT(N->LParen);
for (auto [P, Comma]: N->Elements) {
BOLT_VISIT(P);
if (Comma) {
BOLT_VISIT(Comma);
}
}
BOLT_VISIT(N->RParen);
}
void visitEachChild(NestedPattern* N) {
BOLT_VISIT(N->LParen);
BOLT_VISIT(N->P);
BOLT_VISIT(N->RParen);
}
void visitEachChild(ListPattern* N) {
BOLT_VISIT(N->LBracket);
for (auto [Element, Separator]: N->Elements) {
BOLT_VISIT(Element);
if (Separator) {
BOLT_VISIT(Separator);
}
}
BOLT_VISIT(N->RBracket);
}
void visitEachChild(ReferenceExpression* N) {
for (auto [Name, Dot]: N->ModulePath) {
BOLT_VISIT(Name);

View file

@ -175,6 +175,7 @@ namespace bolt {
size_t NextTypeVarId = 0;
Type* BoolType;
Type* ListType;
Type* IntType;
Type* StringType;

View file

@ -103,8 +103,10 @@ namespace bolt {
TypeExpression* parseTypeExpression();
Pattern* parsePrimitivePattern();
Pattern* parsePattern();
ListPattern* parseListPattern();
Pattern* parsePrimitivePattern(bool IsNarrow);
Pattern* parseWidePattern();
Pattern* parseNarrowPattern();
Parameter* parseParam();

View file

@ -163,6 +163,22 @@ namespace bolt {
visitPattern(Y->P, Decl);
break;
}
case NodeKind::TuplePattern:
{
auto Y = static_cast<TuplePattern*>(X);
for (auto [Element, Comma]: Y->Elements) {
visitPattern(Element, Decl);
}
break;
}
case NodeKind::ListPattern:
{
auto Y = static_cast<ListPattern*>(X);
for (auto [Element, Separator]: Y->Elements) {
visitPattern(Element, Decl);
}
break;
}
case NodeKind::LiteralPattern:
break;
default:
@ -278,8 +294,8 @@ namespace bolt {
struct UnrefVisitor : public CSTVisitor<UnrefVisitor> {
void visit(Node* N) {
N->unref();
visitEachChild(N);
N->unref();
}
};
@ -412,6 +428,14 @@ namespace bolt {
return Name;
}
Token* TuplePattern::getFirstToken() const {
return LParen;
}
Token* TuplePattern::getLastToken() const {
return RParen;
}
Token* NestedPattern::getFirstToken() const {
return LParen;
}
@ -420,6 +444,14 @@ namespace bolt {
return RParen;
}
Token* ListPattern::getFirstToken() const {
return LBracket;
}
Token* ListPattern::getLastToken() const {
return RBracket;
}
Token* ReferenceExpression::getFirstToken() const {
if (!ModulePath.empty()) {
return std::get<0>(ModulePath.front());

View file

@ -18,6 +18,8 @@
// TODO Add a pattern that only performs a type assert
// TODO create the constraint in addConstraint, not the other way round
#include <algorithm>
#include <iterator>
#include <stack>
@ -101,6 +103,7 @@ namespace bolt {
BoolType = createConType("Bool");
IntType = createConType("Int");
StringType = createConType("String");
ListType = createConType("List");
}
Scheme* Checker::lookup(ByteString Name) {
@ -1075,6 +1078,26 @@ namespace bolt {
return RetTy;
}
case NodeKind::TuplePattern:
{
auto P = static_cast<TuplePattern*>(Pattern);
std::vector<Type*> ElementTypes;
for (auto [Element, Comma]: P->Elements) {
ElementTypes.push_back(inferPattern(Element));
}
return new TTuple(ElementTypes);
}
case NodeKind::ListPattern:
{
auto P = static_cast<ListPattern*>(Pattern);
auto ElementType = createTypeVar();
for (auto [Element, Separator]: P->Elements) {
addConstraint(new CEqual(ElementType, inferPattern(Element), P));
}
return new TApp(ListType, ElementType);
}
case NodeKind::NestedPattern:
{
auto P = static_cast<NestedPattern*>(Pattern);
@ -1292,6 +1315,7 @@ namespace bolt {
addBinding("String", new Forall(StringType));
addBinding("Int", new Forall(IntType));
addBinding("Bool", new Forall(BoolType));
addBinding("List", new Forall(ListType));
addBinding("True", new Forall(BoolType));
addBinding("False", new Forall(BoolType));
auto A = createTypeVar();

View file

@ -128,6 +128,8 @@ namespace bolt {
return "an if-statement";
case NodeKind::IfStatementPart:
return "a branch of an if-statement";
case NodeKind::ListPattern:
return "a list pattern";
default:
ZEN_UNREACHABLE
}

View file

@ -102,7 +102,49 @@ namespace bolt {
return T;
}
Pattern* Parser::parsePrimitivePattern() {
ListPattern* Parser::parseListPattern() {
auto LBracket = expectToken<class LBracket>();
if (!LBracket) {
return nullptr;
}
std::vector<std::tuple<Pattern*, Comma*>> Elements;
RBracket* RBracket;
auto T0 = Tokens.peek();
if (T0->getKind() == NodeKind::RBracket) {
Tokens.get();
RBracket = static_cast<class RBracket*>(T0);
goto finish;
}
for (;;) {
auto P = parseWidePattern();
if (!P) {
LBracket->unref();
for (auto [Element, Separator]: Elements) {
Element->unref();
Separator->unref();
}
return nullptr;
}
auto T1 = Tokens.peek();
switch (T1->getKind()) {
case NodeKind::Comma:
Tokens.get();
Elements.push_back(std::make_tuple(P, static_cast<Comma*>(T1)));
break;
case NodeKind::RBracket:
Tokens.get();
Elements.push_back(std::make_tuple(P, nullptr));
RBracket = static_cast<class RBracket*>(T1);
goto finish;
default:
DE.add<UnexpectedTokenDiagnostic>(File, T1, std::vector { NodeKind::Comma, NodeKind::RBracket });
}
}
finish:
return new ListPattern { LBracket, Elements, RBracket };
}
Pattern* Parser::parsePrimitivePattern(bool IsNarrow) {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
case NodeKind::StringLiteral:
@ -113,28 +155,28 @@ namespace bolt {
Tokens.get();
return new BindPattern(static_cast<Identifier*>(T0));
case NodeKind::IdentifierAlt:
Tokens.get();
return new NamedPattern(static_cast<IdentifierAlt*>(T0), {});
case NodeKind::LParen:
{
Tokens.get();
auto LParen = static_cast<class LParen*>(T0);
auto T1 = Tokens.peek();
RParen* RParen;
if (T1->getKind() == NodeKind::IdentifierAlt) {
Tokens.get();
auto Name = static_cast<IdentifierAlt*>(T1);
auto Name = static_cast<IdentifierAlt*>(T0);
if (IsNarrow) {
return new NamedPattern(Name, {});
}
std::vector<Pattern*> Patterns;
for (;;) {
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::RParen) {
Tokens.get();
RParen = static_cast<class RParen*>(T2);
if (T2->getKind() == NodeKind::RParen
|| T2->getKind() == NodeKind::RBracket
|| T2->getKind() == NodeKind::RBrace
|| T2->getKind() == NodeKind::Comma
|| T2->getKind() == NodeKind::Colon
|| T2->getKind() == NodeKind::Equals
|| T2->getKind() == NodeKind::BlockStart
|| T2->getKind() == NodeKind::RArrowAlt) {
break;
}
auto P = parsePrimitivePattern();
auto P = parseNarrowPattern();
if (!P) {
LParen->unref();
Name->unref();
for (auto P: Patterns) {
P->unref();
}
@ -142,31 +184,73 @@ namespace bolt {
}
Patterns.push_back(P);
}
return new NestedPattern { LParen, new NamedPattern { Name, Patterns }, RParen };
} else {
auto P = parsePattern();
return new NamedPattern { Name, Patterns };
}
case NodeKind::LBracket:
return parseListPattern();
case NodeKind::LParen:
{
Tokens.get();
auto LParen = static_cast<class LParen*>(T0);
std::vector<std::tuple<Pattern*, Comma*>> Elements;
RParen* RParen;
for (;;) {
auto P = parseWidePattern();
if (!P) {
LParen->unref();
return nullptr;
}
auto RParen = expectToken<class RParen>();
if (!RParen) {
LParen->unref();
for (auto [P, Comma]: Elements) {
P->unref();
Comma->unref();
}
// TODO maybe skip to next comma?
return nullptr;
}
return new NestedPattern { LParen, P, RParen };
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::Comma) {
Tokens.get();
Elements.push_back(std::make_tuple(P, static_cast<Comma*>(T1)));
} else if (T1->getKind() == NodeKind::RParen) {
Tokens.get();
RParen = static_cast<class RParen*>(T1);
Elements.push_back(std::make_tuple(P, nullptr));
break;
} else {
DE.add<UnexpectedTokenDiagnostic>(File, T1, std::vector { NodeKind::Comma, NodeKind::RParen });
LParen->unref();
for (auto [P, Comma]: Elements) {
P->unref();
Comma->unref();
}
// TODO maybe skip to next comma?
return nullptr;
}
}
if (Elements.size() == 1) {
return new NestedPattern { LParen, std::get<0>(Elements.front()), RParen };
}
return new TuplePattern(LParen, Elements, RParen);
}
default:
// Tokens.get();
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::Identifier, NodeKind::IdentifierAlt, NodeKind::StringLiteral, NodeKind::IntegerLiteral });
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector {
NodeKind::Identifier,
NodeKind::IdentifierAlt,
NodeKind::StringLiteral,
NodeKind::IntegerLiteral,
NodeKind::LParen,
NodeKind::LBracket
});
return nullptr;
}
}
Pattern* Parser::parsePattern() {
return parsePrimitivePattern();
Pattern* Parser::parseWidePattern() {
return parsePrimitivePattern(false);
}
Pattern* Parser::parseNarrowPattern() {
return parsePrimitivePattern(true);
}
TypeExpression* Parser::parseTypeExpression() {
@ -431,7 +515,7 @@ after_tuple_element:
Tokens.get()->unref();
break;
}
auto Pattern = parsePattern();
auto Pattern = parseWidePattern();
if (!Pattern) {
skipToLineFoldEnd();
continue;
@ -863,7 +947,7 @@ VariableDeclaration* Parser::parseVariableDeclaration() {
Tokens.get();
}
auto P = parsePattern();
auto P = parseWidePattern();
if (!P) {
if (Pub) {
Pub->unref();
@ -993,7 +1077,7 @@ finish:
case NodeKind::Colon:
goto after_params;
default:
auto P = parsePattern();
auto P = parseNarrowPattern();
if (!P) {
Tokens.get();
P = new BindPattern(new Identifier("_"));