Add support for record type expressions

This commit is contained in:
Sam Vervaeck 2024-01-21 08:51:50 +01:00
parent 1b5c32fe29
commit 00dbada7ac
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
6 changed files with 250 additions and 5 deletions

View file

@ -87,6 +87,7 @@ namespace bolt {
};
enum class NodeKind {
VBar,
Equals,
Colon,
Comma,
@ -132,6 +133,8 @@ namespace bolt {
TypeAssertAnnotation,
TypeclassConstraintExpression,
EqualityConstraintExpression,
RecordTypeExpressionField,
RecordTypeExpression,
QualifiedTypeExpression,
ReferenceTypeExpression,
ArrowTypeExpression,
@ -363,6 +366,20 @@ namespace bolt {
};
class VBar : public Token {
public:
inline VBar(TextLoc StartLoc):
Token(NodeKind::VBar, StartLoc) {}
std::string getText() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::VBar;
}
};
class Colon : public Token {
public:
@ -1085,6 +1102,54 @@ namespace bolt {
};
class RecordTypeExpressionField : public Node {
public:
Identifier* Name;
Colon* Colon;
TypeExpression* TE;
inline RecordTypeExpressionField(
Identifier* Name,
class Colon* Colon,
TypeExpression* TE
): Node(NodeKind::RecordTypeExpressionField),
Name(Name),
Colon(Colon),
TE(TE) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class RecordTypeExpression : public TypeExpression {
public:
LBrace* LBrace;
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields;
VBar* VBar;
TypeExpression* Rest;
RBrace* RBrace;
inline RecordTypeExpression(
class LBrace* LBrace,
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields,
class VBar* VBar,
TypeExpression* Rest,
class RBrace* RBrace
): TypeExpression(NodeKind::RecordTypeExpression),
LBrace(LBrace),
Fields(Fields),
VBar(VBar),
Rest(Rest),
RBrace(RBrace) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class VarTypeExpression;
class TypeclassConstraintExpression : public ConstraintExpression {
@ -2079,7 +2144,7 @@ namespace bolt {
bool isVariable() const noexcept {
// Variables in classes and instances are never possible, so we reflect this by excluding them here.
return !isSignature() && !isClass() && !isInstance() && (Pattern->getKind() != NodeKind::BindPattern || !Body);
return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body);
}
bool isFunction() const noexcept {

View file

@ -19,6 +19,7 @@ namespace bolt {
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
switch (N->getKind()) {
BOLT_GEN_CASE(VBar)
BOLT_GEN_CASE(Equals)
BOLT_GEN_CASE(Colon)
BOLT_GEN_CASE(Comma)
@ -64,6 +65,8 @@ namespace bolt {
BOLT_GEN_CASE(TypeAssertAnnotation)
BOLT_GEN_CASE(TypeclassConstraintExpression)
BOLT_GEN_CASE(EqualityConstraintExpression)
BOLT_GEN_CASE(RecordTypeExpressionField)
BOLT_GEN_CASE(RecordTypeExpression)
BOLT_GEN_CASE(QualifiedTypeExpression)
BOLT_GEN_CASE(ReferenceTypeExpression)
BOLT_GEN_CASE(ArrowTypeExpression)
@ -121,6 +124,10 @@ namespace bolt {
static_cast<D*>(this)->visitNode(N);
}
void visitVBar(VBar* N) {
static_cast<D*>(this)->visitToken(N);
}
void visitEquals(Equals* N) {
static_cast<D*>(this)->visitToken(N);
}
@ -313,6 +320,14 @@ namespace bolt {
static_cast<D*>(this)->visitNode(N);
}
void visitRecordTypeExpressionField(RecordTypeExpressionField * N) {
static_cast<D*>(this)->visitNode(N);
}
void visitRecordTypeExpression(RecordTypeExpression* N) {
static_cast<D*>(this)->visitTypeExpression(N);
}
void visitQualifiedTypeExpression(QualifiedTypeExpression* N) {
static_cast<D*>(this)->visitTypeExpression(N);
}
@ -519,6 +534,7 @@ namespace bolt {
break;
switch (N->getKind()) {
BOLT_GEN_CHILD_CASE(VBar)
BOLT_GEN_CHILD_CASE(Equals)
BOLT_GEN_CHILD_CASE(Colon)
BOLT_GEN_CHILD_CASE(Comma)
@ -564,6 +580,8 @@ namespace bolt {
BOLT_GEN_CHILD_CASE(TypeAssertAnnotation)
BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression)
BOLT_GEN_CHILD_CASE(EqualityConstraintExpression)
BOLT_GEN_CHILD_CASE(RecordTypeExpressionField)
BOLT_GEN_CHILD_CASE(RecordTypeExpression)
BOLT_GEN_CHILD_CASE(QualifiedTypeExpression)
BOLT_GEN_CHILD_CASE(ReferenceTypeExpression)
BOLT_GEN_CHILD_CASE(ArrowTypeExpression)
@ -613,6 +631,9 @@ namespace bolt {
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
void visitEachChild(VBar* N) {
}
void visitEachChild(Equals* N) {
}
@ -760,6 +781,29 @@ namespace bolt {
BOLT_VISIT(N->Right);
}
void visitEachChild(RecordTypeExpressionField* N) {
BOLT_VISIT(N->Name);
BOLT_VISIT(N->Colon);
BOLT_VISIT(N->TE);
}
void visitEachChild(RecordTypeExpression* N) {
BOLT_VISIT(N->LBrace);
for (auto [Field, Comma]: N->Fields) {
BOLT_VISIT(Field);
if (Comma) {
BOLT_VISIT(Comma);
}
}
if (N->VBar) {
BOLT_VISIT(N->VBar);
}
if (N->Rest) {
BOLT_VISIT(N->Rest);
}
BOLT_VISIT(N->RBrace);
}
void visitEachChild(QualifiedTypeExpression* N) {
for (auto [CE, Comma]: N->Constraints) {
BOLT_VISIT(CE);

View file

@ -393,6 +393,22 @@ namespace bolt {
return Left->getLastToken();
}
Token* RecordTypeExpressionField::getFirstToken() const {
return Name;
}
Token* RecordTypeExpressionField::getLastToken() const {
return TE->getLastToken();
}
Token* RecordTypeExpression::getFirstToken() const {
return LBrace;
}
Token* RecordTypeExpression::getLastToken() const {
return RBrace;
}
Token* QualifiedTypeExpression::getFirstToken() const {
if (!Constraints.empty()) {
return std::get<0>(Constraints.front())->getFirstToken();
@ -840,6 +856,10 @@ namespace bolt {
return nullptr;
}
std::string VBar::getText() const {
return "|";
}
std::string Equals::getText() const {
return "=";
}

View file

@ -677,9 +677,12 @@ namespace bolt {
case NodeKind::LetDeclaration:
{
// Function declarations are handled separately in inferLetDeclaration()
// Function declarations are handled separately in inferFunctionDeclaration()
auto Decl = static_cast<LetDeclaration*>(N);
if (Decl->isFunction() && !Decl->Visited) {
if (Decl->Visited) {
break;
}
if (Decl->isFunction()) {
Decl->IsCycleActive = true;
Decl->Visited = true;
inferFunctionDeclaration(Decl);
@ -854,6 +857,17 @@ namespace bolt {
return Ty;
}
case NodeKind::RecordTypeExpression:
{
auto RecTE = static_cast<RecordTypeExpression*>(N);
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil());
for (auto [Field, Comma]: RecTE->Fields) {
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty));
}
N->setType(Ty);
return Ty;
}
case NodeKind::TupleTypeExpression:
{
auto TupleTE = static_cast<TupleTypeExpression*>(N);

View file

@ -1,6 +1,7 @@
// TODO check for memory leaks everywhere a nullptr is returned
#include <tuple>
#include <vector>
#include "bolt/Common.hpp"
@ -385,6 +386,9 @@ after_constraints:
LParen->unref();
for (auto [CE, Comma]: Constraints) {
CE->unref();
if (Comma) {
Comma->unref();
}
}
RParen->unref();
RArrowAlt->unref();
@ -398,6 +402,93 @@ after_constraints:
switch (T0->getKind()) {
case NodeKind::Identifier:
return parseVarTypeExpression();
case NodeKind::LBrace:
{
Tokens.get();
auto LBrace = static_cast<class LBrace*>(T0);
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields;
VBar* VBar = nullptr;
TypeExpression* Rest = nullptr;
for (;;) {
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::RBrace) {
break;
}
auto Name = expectToken<Identifier>();
if (Name == nullptr) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
return nullptr;
}
auto Colon = expectToken<class Colon>();
if (Colon == nullptr) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
return nullptr;
}
auto TE = parseTypeExpression();
if (TE == nullptr) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
Colon->unref();
return nullptr;
}
auto Field = new RecordTypeExpressionField(Name, Colon, TE);
auto T3 = Tokens.peek();
if (T3->getKind() == NodeKind::RBrace) {
Fields.push_back(std::make_tuple(Field, nullptr));
break;
}
if (T3->getKind() == NodeKind::VBar) {
Tokens.get();
VBar = static_cast<class VBar*>(T3);
Rest = parseTypeExpression();
if (!Rest) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Field->unref();
return nullptr;
}
auto T4 = Tokens.peek();
if (T4->getKind() != NodeKind::RBrace) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Field->unref();
Rest->unref();
DE.add<UnexpectedTokenDiagnostic>(File, T4, std::vector { NodeKind::RBrace });
return nullptr;
}
break;
}
if (T3->getKind() == NodeKind::Comma) {
Tokens.get();
auto Comma = static_cast<class Comma*>(T3);
Fields.push_back(std::make_tuple(Field, Comma));
continue;
}
DE.add<UnexpectedTokenDiagnostic>(File, T3, std::vector { NodeKind::RBrace, NodeKind::Comma, NodeKind::VBar });
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Field->unref();
return nullptr;
}
auto RBrace = static_cast<class RBrace*>(Tokens.get());
return new RecordTypeExpression(LBrace, Fields, VBar, Rest, RBrace);
}
case NodeKind::LParen:
{
Tokens.get();
@ -488,7 +579,16 @@ after_tuple_element:
for (;;) {
auto T1 = Tokens.peek();
auto Kind = T1->getKind();
if (Kind == NodeKind::RArrow || Kind == NodeKind::Equals || Kind == NodeKind::BlockStart || Kind == NodeKind::LineFoldEnd || Kind == NodeKind::EndOfFile || Kind == NodeKind::RParen) {
if (Kind == NodeKind::Comma
|| Kind == NodeKind::RArrow
|| Kind == NodeKind::Equals
|| Kind == NodeKind::BlockStart
|| Kind == NodeKind::LineFoldEnd
|| Kind == NodeKind::EndOfFile
|| Kind == NodeKind::RParen
|| Kind == NodeKind::RBracket
|| Kind == NodeKind::RBrace
|| Kind == NodeKind::VBar) {
break;
}
auto TE = parsePrimitiveTypeExpression();

View file

@ -379,7 +379,9 @@ after_string_contents:
Text.push_back(static_cast<char>(C1));
getChar();
}
if (Text == "->") {
if (Text == "|") {
return new VBar(StartLoc);
} else if (Text == "->") {
return new RArrow(StartLoc);
} else if (Text == "=>") {
return new RArrowAlt(StartLoc);