Add support for record type expressions

This commit is contained in:
Sam Vervaeck 2024-01-21 08:51:50 +01:00
parent 1b5c32fe29
commit 00dbada7ac
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
6 changed files with 250 additions and 5 deletions

View file

@ -87,6 +87,7 @@ namespace bolt {
}; };
enum class NodeKind { enum class NodeKind {
VBar,
Equals, Equals,
Colon, Colon,
Comma, Comma,
@ -132,6 +133,8 @@ namespace bolt {
TypeAssertAnnotation, TypeAssertAnnotation,
TypeclassConstraintExpression, TypeclassConstraintExpression,
EqualityConstraintExpression, EqualityConstraintExpression,
RecordTypeExpressionField,
RecordTypeExpression,
QualifiedTypeExpression, QualifiedTypeExpression,
ReferenceTypeExpression, ReferenceTypeExpression,
ArrowTypeExpression, ArrowTypeExpression,
@ -363,6 +366,20 @@ namespace bolt {
}; };
class VBar : public Token {
public:
inline VBar(TextLoc StartLoc):
Token(NodeKind::VBar, StartLoc) {}
std::string getText() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::VBar;
}
};
class Colon : public Token { class Colon : public Token {
public: public:
@ -1085,6 +1102,54 @@ namespace bolt {
}; };
class RecordTypeExpressionField : public Node {
public:
Identifier* Name;
Colon* Colon;
TypeExpression* TE;
inline RecordTypeExpressionField(
Identifier* Name,
class Colon* Colon,
TypeExpression* TE
): Node(NodeKind::RecordTypeExpressionField),
Name(Name),
Colon(Colon),
TE(TE) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class RecordTypeExpression : public TypeExpression {
public:
LBrace* LBrace;
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields;
VBar* VBar;
TypeExpression* Rest;
RBrace* RBrace;
inline RecordTypeExpression(
class LBrace* LBrace,
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields,
class VBar* VBar,
TypeExpression* Rest,
class RBrace* RBrace
): TypeExpression(NodeKind::RecordTypeExpression),
LBrace(LBrace),
Fields(Fields),
VBar(VBar),
Rest(Rest),
RBrace(RBrace) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class VarTypeExpression; class VarTypeExpression;
class TypeclassConstraintExpression : public ConstraintExpression { class TypeclassConstraintExpression : public ConstraintExpression {
@ -2079,7 +2144,7 @@ namespace bolt {
bool isVariable() const noexcept { bool isVariable() const noexcept {
// Variables in classes and instances are never possible, so we reflect this by excluding them here. // Variables in classes and instances are never possible, so we reflect this by excluding them here.
return !isSignature() && !isClass() && !isInstance() && (Pattern->getKind() != NodeKind::BindPattern || !Body); return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body);
} }
bool isFunction() const noexcept { bool isFunction() const noexcept {

View file

@ -19,6 +19,7 @@ namespace bolt {
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N)); return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
switch (N->getKind()) { switch (N->getKind()) {
BOLT_GEN_CASE(VBar)
BOLT_GEN_CASE(Equals) BOLT_GEN_CASE(Equals)
BOLT_GEN_CASE(Colon) BOLT_GEN_CASE(Colon)
BOLT_GEN_CASE(Comma) BOLT_GEN_CASE(Comma)
@ -64,6 +65,8 @@ namespace bolt {
BOLT_GEN_CASE(TypeAssertAnnotation) BOLT_GEN_CASE(TypeAssertAnnotation)
BOLT_GEN_CASE(TypeclassConstraintExpression) BOLT_GEN_CASE(TypeclassConstraintExpression)
BOLT_GEN_CASE(EqualityConstraintExpression) BOLT_GEN_CASE(EqualityConstraintExpression)
BOLT_GEN_CASE(RecordTypeExpressionField)
BOLT_GEN_CASE(RecordTypeExpression)
BOLT_GEN_CASE(QualifiedTypeExpression) BOLT_GEN_CASE(QualifiedTypeExpression)
BOLT_GEN_CASE(ReferenceTypeExpression) BOLT_GEN_CASE(ReferenceTypeExpression)
BOLT_GEN_CASE(ArrowTypeExpression) BOLT_GEN_CASE(ArrowTypeExpression)
@ -121,6 +124,10 @@ namespace bolt {
static_cast<D*>(this)->visitNode(N); static_cast<D*>(this)->visitNode(N);
} }
void visitVBar(VBar* N) {
static_cast<D*>(this)->visitToken(N);
}
void visitEquals(Equals* N) { void visitEquals(Equals* N) {
static_cast<D*>(this)->visitToken(N); static_cast<D*>(this)->visitToken(N);
} }
@ -313,6 +320,14 @@ namespace bolt {
static_cast<D*>(this)->visitNode(N); static_cast<D*>(this)->visitNode(N);
} }
void visitRecordTypeExpressionField(RecordTypeExpressionField * N) {
static_cast<D*>(this)->visitNode(N);
}
void visitRecordTypeExpression(RecordTypeExpression* N) {
static_cast<D*>(this)->visitTypeExpression(N);
}
void visitQualifiedTypeExpression(QualifiedTypeExpression* N) { void visitQualifiedTypeExpression(QualifiedTypeExpression* N) {
static_cast<D*>(this)->visitTypeExpression(N); static_cast<D*>(this)->visitTypeExpression(N);
} }
@ -519,6 +534,7 @@ namespace bolt {
break; break;
switch (N->getKind()) { switch (N->getKind()) {
BOLT_GEN_CHILD_CASE(VBar)
BOLT_GEN_CHILD_CASE(Equals) BOLT_GEN_CHILD_CASE(Equals)
BOLT_GEN_CHILD_CASE(Colon) BOLT_GEN_CHILD_CASE(Colon)
BOLT_GEN_CHILD_CASE(Comma) BOLT_GEN_CHILD_CASE(Comma)
@ -564,6 +580,8 @@ namespace bolt {
BOLT_GEN_CHILD_CASE(TypeAssertAnnotation) BOLT_GEN_CHILD_CASE(TypeAssertAnnotation)
BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression) BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression)
BOLT_GEN_CHILD_CASE(EqualityConstraintExpression) BOLT_GEN_CHILD_CASE(EqualityConstraintExpression)
BOLT_GEN_CHILD_CASE(RecordTypeExpressionField)
BOLT_GEN_CHILD_CASE(RecordTypeExpression)
BOLT_GEN_CHILD_CASE(QualifiedTypeExpression) BOLT_GEN_CHILD_CASE(QualifiedTypeExpression)
BOLT_GEN_CHILD_CASE(ReferenceTypeExpression) BOLT_GEN_CHILD_CASE(ReferenceTypeExpression)
BOLT_GEN_CHILD_CASE(ArrowTypeExpression) BOLT_GEN_CHILD_CASE(ArrowTypeExpression)
@ -613,6 +631,9 @@ namespace bolt {
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node) #define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
void visitEachChild(VBar* N) {
}
void visitEachChild(Equals* N) { void visitEachChild(Equals* N) {
} }
@ -760,6 +781,29 @@ namespace bolt {
BOLT_VISIT(N->Right); BOLT_VISIT(N->Right);
} }
void visitEachChild(RecordTypeExpressionField* N) {
BOLT_VISIT(N->Name);
BOLT_VISIT(N->Colon);
BOLT_VISIT(N->TE);
}
void visitEachChild(RecordTypeExpression* N) {
BOLT_VISIT(N->LBrace);
for (auto [Field, Comma]: N->Fields) {
BOLT_VISIT(Field);
if (Comma) {
BOLT_VISIT(Comma);
}
}
if (N->VBar) {
BOLT_VISIT(N->VBar);
}
if (N->Rest) {
BOLT_VISIT(N->Rest);
}
BOLT_VISIT(N->RBrace);
}
void visitEachChild(QualifiedTypeExpression* N) { void visitEachChild(QualifiedTypeExpression* N) {
for (auto [CE, Comma]: N->Constraints) { for (auto [CE, Comma]: N->Constraints) {
BOLT_VISIT(CE); BOLT_VISIT(CE);

View file

@ -393,6 +393,22 @@ namespace bolt {
return Left->getLastToken(); return Left->getLastToken();
} }
Token* RecordTypeExpressionField::getFirstToken() const {
return Name;
}
Token* RecordTypeExpressionField::getLastToken() const {
return TE->getLastToken();
}
Token* RecordTypeExpression::getFirstToken() const {
return LBrace;
}
Token* RecordTypeExpression::getLastToken() const {
return RBrace;
}
Token* QualifiedTypeExpression::getFirstToken() const { Token* QualifiedTypeExpression::getFirstToken() const {
if (!Constraints.empty()) { if (!Constraints.empty()) {
return std::get<0>(Constraints.front())->getFirstToken(); return std::get<0>(Constraints.front())->getFirstToken();
@ -840,6 +856,10 @@ namespace bolt {
return nullptr; return nullptr;
} }
std::string VBar::getText() const {
return "|";
}
std::string Equals::getText() const { std::string Equals::getText() const {
return "="; return "=";
} }

View file

@ -677,9 +677,12 @@ namespace bolt {
case NodeKind::LetDeclaration: case NodeKind::LetDeclaration:
{ {
// Function declarations are handled separately in inferLetDeclaration() // Function declarations are handled separately in inferFunctionDeclaration()
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<LetDeclaration*>(N);
if (Decl->isFunction() && !Decl->Visited) { if (Decl->Visited) {
break;
}
if (Decl->isFunction()) {
Decl->IsCycleActive = true; Decl->IsCycleActive = true;
Decl->Visited = true; Decl->Visited = true;
inferFunctionDeclaration(Decl); inferFunctionDeclaration(Decl);
@ -854,6 +857,17 @@ namespace bolt {
return Ty; return Ty;
} }
case NodeKind::RecordTypeExpression:
{
auto RecTE = static_cast<RecordTypeExpression*>(N);
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil());
for (auto [Field, Comma]: RecTE->Fields) {
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty));
}
N->setType(Ty);
return Ty;
}
case NodeKind::TupleTypeExpression: case NodeKind::TupleTypeExpression:
{ {
auto TupleTE = static_cast<TupleTypeExpression*>(N); auto TupleTE = static_cast<TupleTypeExpression*>(N);

View file

@ -1,6 +1,7 @@
// TODO check for memory leaks everywhere a nullptr is returned // TODO check for memory leaks everywhere a nullptr is returned
#include <tuple>
#include <vector> #include <vector>
#include "bolt/Common.hpp" #include "bolt/Common.hpp"
@ -385,6 +386,9 @@ after_constraints:
LParen->unref(); LParen->unref();
for (auto [CE, Comma]: Constraints) { for (auto [CE, Comma]: Constraints) {
CE->unref(); CE->unref();
if (Comma) {
Comma->unref();
}
} }
RParen->unref(); RParen->unref();
RArrowAlt->unref(); RArrowAlt->unref();
@ -398,6 +402,93 @@ after_constraints:
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::Identifier: case NodeKind::Identifier:
return parseVarTypeExpression(); return parseVarTypeExpression();
case NodeKind::LBrace:
{
Tokens.get();
auto LBrace = static_cast<class LBrace*>(T0);
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields;
VBar* VBar = nullptr;
TypeExpression* Rest = nullptr;
for (;;) {
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::RBrace) {
break;
}
auto Name = expectToken<Identifier>();
if (Name == nullptr) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
return nullptr;
}
auto Colon = expectToken<class Colon>();
if (Colon == nullptr) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
return nullptr;
}
auto TE = parseTypeExpression();
if (TE == nullptr) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Name->unref();
Colon->unref();
return nullptr;
}
auto Field = new RecordTypeExpressionField(Name, Colon, TE);
auto T3 = Tokens.peek();
if (T3->getKind() == NodeKind::RBrace) {
Fields.push_back(std::make_tuple(Field, nullptr));
break;
}
if (T3->getKind() == NodeKind::VBar) {
Tokens.get();
VBar = static_cast<class VBar*>(T3);
Rest = parseTypeExpression();
if (!Rest) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Field->unref();
return nullptr;
}
auto T4 = Tokens.peek();
if (T4->getKind() != NodeKind::RBrace) {
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Field->unref();
Rest->unref();
DE.add<UnexpectedTokenDiagnostic>(File, T4, std::vector { NodeKind::RBrace });
return nullptr;
}
break;
}
if (T3->getKind() == NodeKind::Comma) {
Tokens.get();
auto Comma = static_cast<class Comma*>(T3);
Fields.push_back(std::make_tuple(Field, Comma));
continue;
}
DE.add<UnexpectedTokenDiagnostic>(File, T3, std::vector { NodeKind::RBrace, NodeKind::Comma, NodeKind::VBar });
for (auto [Field, Comma]: Fields) {
Field->unref();
Comma->unref();
}
Field->unref();
return nullptr;
}
auto RBrace = static_cast<class RBrace*>(Tokens.get());
return new RecordTypeExpression(LBrace, Fields, VBar, Rest, RBrace);
}
case NodeKind::LParen: case NodeKind::LParen:
{ {
Tokens.get(); Tokens.get();
@ -488,7 +579,16 @@ after_tuple_element:
for (;;) { for (;;) {
auto T1 = Tokens.peek(); auto T1 = Tokens.peek();
auto Kind = T1->getKind(); auto Kind = T1->getKind();
if (Kind == NodeKind::RArrow || Kind == NodeKind::Equals || Kind == NodeKind::BlockStart || Kind == NodeKind::LineFoldEnd || Kind == NodeKind::EndOfFile || Kind == NodeKind::RParen) { if (Kind == NodeKind::Comma
|| Kind == NodeKind::RArrow
|| Kind == NodeKind::Equals
|| Kind == NodeKind::BlockStart
|| Kind == NodeKind::LineFoldEnd
|| Kind == NodeKind::EndOfFile
|| Kind == NodeKind::RParen
|| Kind == NodeKind::RBracket
|| Kind == NodeKind::RBrace
|| Kind == NodeKind::VBar) {
break; break;
} }
auto TE = parsePrimitiveTypeExpression(); auto TE = parsePrimitiveTypeExpression();

View file

@ -379,7 +379,9 @@ after_string_contents:
Text.push_back(static_cast<char>(C1)); Text.push_back(static_cast<char>(C1));
getChar(); getChar();
} }
if (Text == "->") { if (Text == "|") {
return new VBar(StartLoc);
} else if (Text == "->") {
return new RArrow(StartLoc); return new RArrow(StartLoc);
} else if (Text == "=>") { } else if (Text == "=>") {
return new RArrowAlt(StartLoc); return new RArrowAlt(StartLoc);