Add NamedRecordPattern and rename NamedPattern to NamedTuplePattern

This commit is contained in:
Sam Vervaeck 2024-01-21 05:40:35 +01:00
parent 7ac3c39164
commit 1b5c32fe29
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
7 changed files with 353 additions and 63 deletions

View file

@ -141,7 +141,9 @@ namespace bolt {
TupleTypeExpression,
BindPattern,
LiteralPattern,
NamedPattern,
RecordPatternField,
NamedRecordPattern,
NamedTuplePattern,
TuplePattern,
NestedPattern,
ListPattern,
@ -1313,16 +1315,68 @@ namespace bolt {
};
class NamedPattern : public Pattern {
class RecordPatternField : public Node {
public:
Identifier* Name;
Equals* Equals;
Pattern* Pattern;
inline RecordPatternField(
Identifier* Name,
class Equals* Equals,
class Pattern* Pattern
): Node(NodeKind::RecordPatternField),
Name(Name),
Equals(Equals),
Pattern(Pattern) {}
inline RecordPatternField(
Identifier* Name
): RecordPatternField(Name, nullptr, nullptr) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class NamedRecordPattern : public Pattern {
public:
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath;
IdentifierAlt* Name;
LBrace* LBrace;
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields;
RBrace* RBrace;
inline NamedRecordPattern(
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath,
IdentifierAlt* Name,
class LBrace* LBrace,
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields,
class RBrace* RBrace
): Pattern(NodeKind::NamedRecordPattern),
ModulePath(ModulePath),
Name(Name),
LBrace(LBrace),
Fields(Fields),
RBrace(RBrace) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class NamedTuplePattern : public Pattern {
public:
IdentifierAlt* Name;
std::vector<Pattern*> Patterns;
inline NamedPattern(
inline NamedTuplePattern(
IdentifierAlt* Name,
std::vector<Pattern*> Patterns
): Pattern(NodeKind::NamedPattern),
): Pattern(NodeKind::NamedTuplePattern),
Name(Name),
Patterns(Patterns) {}

View file

@ -73,7 +73,9 @@ namespace bolt {
BOLT_GEN_CASE(TupleTypeExpression)
BOLT_GEN_CASE(BindPattern)
BOLT_GEN_CASE(LiteralPattern)
BOLT_GEN_CASE(NamedPattern)
BOLT_GEN_CASE(RecordPatternField)
BOLT_GEN_CASE(NamedRecordPattern)
BOLT_GEN_CASE(NamedTuplePattern)
BOLT_GEN_CASE(TuplePattern)
BOLT_GEN_CASE(NestedPattern)
BOLT_GEN_CASE(ListPattern)
@ -351,7 +353,15 @@ namespace bolt {
static_cast<D*>(this)->visitPattern(N);
}
void visitNamedPattern(NamedPattern* N) {
void visitRecordPatternField(RecordPatternField* N) {
static_cast<D*>(this)->visitNode(N);
}
void visitNamedRecordPattern(NamedRecordPattern* N) {
static_cast<D*>(this)->visitPattern(N);
}
void visitNamedTuplePattern(NamedTuplePattern* N) {
static_cast<D*>(this)->visitPattern(N);
}
@ -563,7 +573,9 @@ namespace bolt {
BOLT_GEN_CHILD_CASE(TupleTypeExpression)
BOLT_GEN_CHILD_CASE(BindPattern)
BOLT_GEN_CHILD_CASE(LiteralPattern)
BOLT_GEN_CHILD_CASE(NamedPattern)
BOLT_GEN_CHILD_CASE(RecordPatternField)
BOLT_GEN_CHILD_CASE(NamedRecordPattern)
BOLT_GEN_CHILD_CASE(NamedTuplePattern)
BOLT_GEN_CHILD_CASE(TuplePattern)
BOLT_GEN_CHILD_CASE(NestedPattern)
BOLT_GEN_CHILD_CASE(ListPattern)
@ -810,7 +822,36 @@ namespace bolt {
BOLT_VISIT(N->Literal);
}
void visitEachChild(NamedPattern* N) {
void visitEachChild(RecordPatternField* N) {
BOLT_VISIT(N->Name);
if (N->Equals) {
BOLT_VISIT(N->Equals);
}
if (N->Pattern) {
BOLT_VISIT(N->Pattern);
}
}
void visitEachChild(NamedRecordPattern* N) {
for (auto [Name, Dot]: N->ModulePath) {
BOLT_VISIT(Name);
if (Dot) {
BOLT_VISIT(Dot);
}
}
BOLT_VISIT(N->Name);
BOLT_VISIT(N->LBrace);
for (auto [Field, Comma]: N->Fields) {
BOLT_VISIT(Field);
if (Comma) {
BOLT_VISIT(Comma);
}
}
BOLT_VISIT(N->LBrace);
BOLT_VISIT(N->RBrace);
}
void visitEachChild(NamedTuplePattern* N) {
BOLT_VISIT(N->Name);
for (auto P: N->Patterns) {
BOLT_VISIT(P);

View file

@ -71,7 +71,8 @@ namespace bolt {
Token* expectToken(NodeKind Ty);
std::vector<RecordDeclarationField*> parseRecordFields();
std::vector<RecordDeclarationField*> parseRecordDeclarationFields();
std::optional<std::vector<std::tuple<RecordPatternField*, Comma*>>> parseRecordPatternFields();
template<typename T>
T* expectToken() {
@ -97,7 +98,8 @@ namespace bolt {
std::vector<Annotation*> parseAnnotations();
void checkLineFoldEnd();
void skipToLineFoldEnd();
void skipPastLineFoldEnd();
void skipToRBrace();
public:

View file

@ -3,6 +3,7 @@
#include "bolt/CST.hpp"
#include "bolt/CSTVisitor.hpp"
#include <variant>
namespace bolt {
@ -173,9 +174,21 @@ namespace bolt {
addSymbol(Y->Name->Text, Decl, SymbolKind::Var);
break;
}
case NodeKind::NamedPattern:
case NodeKind::NamedRecordPattern:
{
auto Y = static_cast<NamedPattern*>(X);
auto Y = static_cast<NamedRecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedTuplePattern:
{
auto Y = static_cast<NamedTuplePattern*>(X);
for (auto P: Y->Patterns) {
visitPattern(P, Decl);
}
@ -464,11 +477,36 @@ namespace bolt {
return Literal;
}
Token* NamedPattern::getFirstToken() const {
Token* RecordPatternField::getFirstToken() const {
return Name;
}
Token* NamedPattern::getLastToken() const {
Token* RecordPatternField::getLastToken() const {
if (Pattern) {
return Pattern->getLastToken();
}
if (Equals) {
return Equals;
}
return Name;
}
Token* NamedRecordPattern::getFirstToken() const {
if (!ModulePath.empty()) {
return std::get<0>(ModulePath.back());
}
return Name;
}
Token* NamedRecordPattern::getLastToken() const {
return RBrace;
}
Token* NamedTuplePattern::getFirstToken() const {
return Name;
}
Token* NamedTuplePattern::getLastToken() const {
if (Patterns.size()) {
return Patterns.back()->getLastToken();
}

View file

@ -1113,13 +1113,13 @@ namespace bolt {
return Ty;
}
case NodeKind::NamedPattern:
case NodeKind::NamedTuplePattern:
{
auto P = static_cast<NamedPattern*>(Pattern);
auto P = static_cast<NamedTuplePattern*>(Pattern);
auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
std::vector<Type*> ParamTypes;
std::vector<Type*> ElementTypes;
for (auto P2: P->Patterns) {
ParamTypes.push_back(inferPattern(P2, Constraints, TVs));
ElementTypes.push_back(inferPattern(P2, Constraints, TVs));
}
if (!Scm) {
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
@ -1127,7 +1127,32 @@ namespace bolt {
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
makeEqual(Ty, Type::buildArrow(ParamTypes, RetTy), P);
makeEqual(Ty, Type::buildArrow(ElementTypes, RetTy), P);
return RetTy;
}
case NodeKind::NamedRecordPattern:
{
auto P = static_cast<NamedRecordPattern*>(Pattern);
auto Scm = lookup(P->Name->getCanonicalText(), SymKind::Var);
if (Scm == nullptr) {
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar();
}
auto RecordTy = new Type(TNil());
for (auto [Field, Comma]: P->Fields) {
Type* FieldTy;
if (Field->Pattern) {
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else {
FieldTy = createTypeVar();
addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
}
RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
}
auto Ty = instantiate(Scm, P);
auto RetTy = createTypeVar();
makeEqual(Ty, new Type(TArrow(RecordTy, RetTy)), P);
return RetTy;
}

View file

@ -89,7 +89,7 @@ namespace bolt {
return "a tuple type expression";
case NodeKind::BindPattern:
return "a variable binder";
case NodeKind::NamedPattern:
case NodeKind::NamedTuplePattern:
return "a pattern for a variant member";
case NodeKind::TuplePattern:
return "a pattern for a tuple";

View file

@ -150,6 +150,39 @@ finish:
return new ListPattern { LBracket, Elements, RBracket };
}
std::optional<std::vector<std::tuple<RecordPatternField*, Comma*>>> Parser::parseRecordPatternFields() {
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields;
for (;;) {
auto T0 = Tokens.peek();
if (T0->getKind() == NodeKind::RBrace) {
break;
}
auto Name = expectToken<Identifier>();
Equals* Equals = nullptr;
Pattern* Pattern = nullptr;
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::Equals) {
Tokens.get();
Equals = static_cast<class Equals*>(T1);
Pattern = parseWidePattern();
}
auto Field = new RecordPatternField(Name, Equals, Pattern);
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::RBrace) {
Fields.push_back(std::make_tuple(Field, nullptr));
break;
}
if (T2->getKind() != NodeKind::Comma) {
DE.add<UnexpectedTokenDiagnostic>(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma });
return {};
}
Tokens.get();
auto Comma = static_cast<class Comma*>(T2);
Fields.push_back(std::make_tuple(Field, Comma));
}
return Fields;
}
Pattern* Parser::parsePrimitivePattern(bool IsNarrow) {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
@ -165,7 +198,19 @@ finish:
Tokens.get();
auto Name = static_cast<IdentifierAlt*>(T0);
if (IsNarrow) {
return new NamedPattern(Name, {});
return new NamedTuplePattern(Name, {});
}
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::LBrace) {
auto LBrace = static_cast<class LBrace*>(T1);
Tokens.get();
auto Fields = parseRecordPatternFields();
if (!Fields) {
skipToRBrace();
return nullptr;
}
auto RBrace = static_cast<class RBrace*>(Tokens.get());
return new NamedRecordPattern({}, Name, LBrace, *Fields, RBrace);
}
std::vector<Pattern*> Patterns;
for (;;) {
@ -190,7 +235,7 @@ finish:
}
Patterns.push_back(P);
}
return new NamedPattern { Name, Patterns };
return new NamedTuplePattern { Name, Patterns };
}
case NodeKind::LBracket:
return parseListPattern();
@ -523,20 +568,20 @@ after_tuple_element:
}
auto Pattern = parseWidePattern();
if (!Pattern) {
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
auto RArrowAlt = expectToken<class RArrowAlt>();
if (!RArrowAlt) {
Pattern->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
auto Expression = parseExpression();
if (!Expression) {
Pattern->unref();
RArrowAlt->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
checkLineFoldEnd();
@ -853,7 +898,7 @@ finish:
ExpressionStatement* Parser::parseExpressionStatement() {
auto E = parseExpression();
if (!E) {
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
checkLineFoldEnd();
@ -875,7 +920,7 @@ finish:
Expression = parseExpression();
if (!Expression) {
ReturnKeyword->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
checkLineFoldEnd();
@ -890,14 +935,14 @@ finish:
auto Test = parseExpression();
if (!Test) {
IfKeyword->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
auto T1 = expectToken<BlockStart>();
if (!T1) {
IfKeyword->unref();
Test->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<Node*> Then;
@ -980,7 +1025,7 @@ finish:
if (Foreign) {
Foreign->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
Let = static_cast<LetKeyword*>(T0);
@ -1002,7 +1047,7 @@ finish:
if (Mut) {
Mut->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
@ -1034,7 +1079,7 @@ after_params:
if (TE) {
TA = new TypeAssert(static_cast<Colon*>(T2), TE);
} else {
skipToLineFoldEnd();
skipPastLineFoldEnd();
goto finish;
}
T2 = Tokens.peek();
@ -1064,7 +1109,7 @@ after_params:
Tokens.get();
auto E = parseExpression();
if (!E) {
skipToLineFoldEnd();
skipPastLineFoldEnd();
goto finish;
}
Body = new LetExprBody(static_cast<Equals*>(T2), E);
@ -1195,13 +1240,13 @@ after_vars:
InstanceDeclaration* Parser::parseInstanceDeclaration() {
auto InstanceKeyword = expectToken<class InstanceKeyword>();
if (!InstanceKeyword) {
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
auto Name = expectToken<IdentifierAlt>();
if (!Name) {
InstanceKeyword->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<TypeExpression*> TypeExps;
@ -1217,7 +1262,7 @@ after_vars:
for (auto TE: TypeExps) {
TE->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
TypeExps.push_back(TE);
@ -1229,7 +1274,7 @@ after_vars:
for (auto TE: TypeExps) {
TE->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<Node*> Elements;
@ -1266,7 +1311,7 @@ after_vars:
if (PubKeyword) {
PubKeyword->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
auto Name = expectToken<IdentifierAlt>();
@ -1275,7 +1320,7 @@ after_vars:
PubKeyword->unref();
}
ClassKeyword->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<VarTypeExpression*> TypeVars;
@ -1293,7 +1338,7 @@ after_vars:
for (auto TV: TypeVars) {
TV->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
TypeVars.push_back(TE);
@ -1307,7 +1352,7 @@ after_vars:
for (auto TV: TypeVars) {
TV->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<Node*> Elements;
@ -1333,7 +1378,7 @@ after_vars:
);
}
std::vector<RecordDeclarationField*> Parser::parseRecordFields() {
std::vector<RecordDeclarationField*> Parser::parseRecordDeclarationFields() {
std::vector<RecordDeclarationField*> Fields;
for (;;) {
auto T1 = Tokens.peek();
@ -1343,20 +1388,20 @@ after_vars:
}
auto Name = expectToken<Identifier>();
if (!Name) {
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
auto Colon = expectToken<class Colon>();
if (!Colon) {
Name->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
auto TE = parseTypeExpression();
if (!TE) {
Name->unref();
Colon->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
checkLineFoldEnd();
@ -1377,7 +1422,7 @@ after_vars:
if (Pub) {
Pub->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
auto Name = expectToken<IdentifierAlt>();
@ -1386,7 +1431,7 @@ after_vars:
Pub->unref();
}
Struct->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<VarTypeExpression*> Vars;
@ -1407,10 +1452,10 @@ after_vars:
}
Struct->unref();
Name->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
auto Fields = parseRecordFields();
auto Fields = parseRecordDeclarationFields();
Tokens.get()->unref(); // Always a LineFoldEnd
return new RecordDeclaration { Pub, Struct, Name, Vars, BS, Fields };
}
@ -1427,7 +1472,7 @@ after_vars:
if (Pub) {
Pub->unref();
}
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
auto Name = expectToken<IdentifierAlt>();
@ -1436,7 +1481,7 @@ after_vars:
Pub->unref();
}
Enum->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<VarTypeExpression*> TVs;
@ -1457,7 +1502,7 @@ after_vars:
}
Enum->unref();
Name->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
std::vector<VariantDeclarationMember*> Members;
@ -1470,14 +1515,14 @@ next_member:
}
auto Name = expectToken<IdentifierAlt>();
if (!Name) {
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::BlockStart) {
Tokens.get();
auto BS = static_cast<BlockStart*>(T1);
auto Fields = parseRecordFields();
auto Fields = parseRecordDeclarationFields();
// TODO continue; on error in Fields
Members.push_back(new RecordVariantDeclarationMember { Name, BS, Fields });
} else {
@ -1514,7 +1559,7 @@ next_member:
// TODO
default:
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector<NodeKind> { NodeKind::LetKeyword, NodeKind::TypeKeyword });
skipToLineFoldEnd();
skipPastLineFoldEnd();
return nullptr;
}
}
@ -1584,7 +1629,7 @@ next_member:
auto E = parseExpression();
if (!E) {
At->unref();
skipToLineFoldEnd();
skipPastLineFoldEnd();
continue;
}
checkLineFoldEnd();
@ -1602,8 +1647,69 @@ next_annotation:;
return Annotations;
}
void Parser::skipToLineFoldEnd() {
unsigned Level = 0;
void Parser::skipToRBrace() {
unsigned ParenLevel = 0;
unsigned BracketLevel = 0;
unsigned BraceLevel = 0;
unsigned BlockLevel = 0;
for (;;) {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
case NodeKind::EndOfFile:
return;
case NodeKind::LineFoldEnd:
Tokens.get()->unref();
if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) {
return;
}
break;
case NodeKind::BlockStart:
Tokens.get()->unref();
BlockLevel++;
break;
case NodeKind::BlockEnd:
Tokens.get()->unref();
BlockLevel--;
break;
case NodeKind::LParen:
Tokens.get()->unref();
ParenLevel++;
break;
case NodeKind::LBracket:
Tokens.get()->unref();
BracketLevel++;
break;
case NodeKind::LBrace:
Tokens.get()->unref();
BraceLevel++;
break;
case NodeKind::RParen:
Tokens.get()->unref();
ParenLevel--;
break;
case NodeKind::RBracket:
Tokens.get()->unref();
BracketLevel--;
break;
case NodeKind::RBrace:
if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) {
return;
}
Tokens.get()->unref();
BraceLevel--;
break;
default:
Tokens.get()->unref();
break;
}
}
}
void Parser::skipPastLineFoldEnd() {
unsigned ParenLevel = 0;
unsigned BracketLevel = 0;
unsigned BraceLevel = 0;
unsigned BlockLevel = 0;
for (;;) {
auto T0 = Tokens.get();
switch (T0->getKind()) {
@ -1611,17 +1717,41 @@ next_annotation:;
return;
case NodeKind::LineFoldEnd:
T0->unref();
if (Level == 0) {
if (BlockLevel == 0 && ParenLevel == 0 && BracketLevel == 0 && BlockLevel == 0) {
return;
}
break;
case NodeKind::BlockStart:
T0->unref();
Level++;
BlockLevel++;
break;
case NodeKind::BlockEnd:
T0->unref();
Level--;
BlockLevel--;
break;
case NodeKind::LParen:
T0->unref();
ParenLevel++;
break;
case NodeKind::LBracket:
T0->unref();
BracketLevel++;
break;
case NodeKind::LBrace:
T0->unref();
BraceLevel++;
break;
case NodeKind::RParen:
T0->unref();
ParenLevel--;
break;
case NodeKind::RBracket:
T0->unref();
BracketLevel--;
break;
case NodeKind::RBrace:
T0->unref();
BraceLevel--;
break;
default:
T0->unref();
@ -1636,7 +1766,7 @@ next_annotation:;
Tokens.get()->unref();
} else {
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::LineFoldEnd });
skipToLineFoldEnd();
skipPastLineFoldEnd();
}
}