Add support for record type expressions
This commit is contained in:
parent
1b5c32fe29
commit
00dbada7ac
6 changed files with 250 additions and 5 deletions
|
@ -87,6 +87,7 @@ namespace bolt {
|
|||
};
|
||||
|
||||
enum class NodeKind {
|
||||
VBar,
|
||||
Equals,
|
||||
Colon,
|
||||
Comma,
|
||||
|
@ -132,6 +133,8 @@ namespace bolt {
|
|||
TypeAssertAnnotation,
|
||||
TypeclassConstraintExpression,
|
||||
EqualityConstraintExpression,
|
||||
RecordTypeExpressionField,
|
||||
RecordTypeExpression,
|
||||
QualifiedTypeExpression,
|
||||
ReferenceTypeExpression,
|
||||
ArrowTypeExpression,
|
||||
|
@ -363,6 +366,20 @@ namespace bolt {
|
|||
|
||||
};
|
||||
|
||||
class VBar : public Token {
|
||||
public:
|
||||
|
||||
inline VBar(TextLoc StartLoc):
|
||||
Token(NodeKind::VBar, StartLoc) {}
|
||||
|
||||
std::string getText() const override;
|
||||
|
||||
static bool classof(const Node* N) {
|
||||
return N->getKind() == NodeKind::VBar;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class Colon : public Token {
|
||||
public:
|
||||
|
||||
|
@ -1085,6 +1102,54 @@ namespace bolt {
|
|||
|
||||
};
|
||||
|
||||
class RecordTypeExpressionField : public Node {
|
||||
public:
|
||||
|
||||
Identifier* Name;
|
||||
Colon* Colon;
|
||||
TypeExpression* TE;
|
||||
|
||||
inline RecordTypeExpressionField(
|
||||
Identifier* Name,
|
||||
class Colon* Colon,
|
||||
TypeExpression* TE
|
||||
): Node(NodeKind::RecordTypeExpressionField),
|
||||
Name(Name),
|
||||
Colon(Colon),
|
||||
TE(TE) {}
|
||||
|
||||
Token* getFirstToken() const override;
|
||||
Token* getLastToken() const override;
|
||||
|
||||
};
|
||||
|
||||
class RecordTypeExpression : public TypeExpression {
|
||||
public:
|
||||
|
||||
LBrace* LBrace;
|
||||
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields;
|
||||
VBar* VBar;
|
||||
TypeExpression* Rest;
|
||||
RBrace* RBrace;
|
||||
|
||||
inline RecordTypeExpression(
|
||||
class LBrace* LBrace,
|
||||
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields,
|
||||
class VBar* VBar,
|
||||
TypeExpression* Rest,
|
||||
class RBrace* RBrace
|
||||
): TypeExpression(NodeKind::RecordTypeExpression),
|
||||
LBrace(LBrace),
|
||||
Fields(Fields),
|
||||
VBar(VBar),
|
||||
Rest(Rest),
|
||||
RBrace(RBrace) {}
|
||||
|
||||
Token* getFirstToken() const override;
|
||||
Token* getLastToken() const override;
|
||||
|
||||
};
|
||||
|
||||
class VarTypeExpression;
|
||||
|
||||
class TypeclassConstraintExpression : public ConstraintExpression {
|
||||
|
@ -2079,7 +2144,7 @@ namespace bolt {
|
|||
|
||||
bool isVariable() const noexcept {
|
||||
// Variables in classes and instances are never possible, so we reflect this by excluding them here.
|
||||
return !isSignature() && !isClass() && !isInstance() && (Pattern->getKind() != NodeKind::BindPattern || !Body);
|
||||
return !isSignature() && !isClass() && !isInstance() && Params.empty() && (Pattern->getKind() != NodeKind::BindPattern || !Body);
|
||||
}
|
||||
|
||||
bool isFunction() const noexcept {
|
||||
|
|
|
@ -19,6 +19,7 @@ namespace bolt {
|
|||
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
|
||||
|
||||
switch (N->getKind()) {
|
||||
BOLT_GEN_CASE(VBar)
|
||||
BOLT_GEN_CASE(Equals)
|
||||
BOLT_GEN_CASE(Colon)
|
||||
BOLT_GEN_CASE(Comma)
|
||||
|
@ -64,6 +65,8 @@ namespace bolt {
|
|||
BOLT_GEN_CASE(TypeAssertAnnotation)
|
||||
BOLT_GEN_CASE(TypeclassConstraintExpression)
|
||||
BOLT_GEN_CASE(EqualityConstraintExpression)
|
||||
BOLT_GEN_CASE(RecordTypeExpressionField)
|
||||
BOLT_GEN_CASE(RecordTypeExpression)
|
||||
BOLT_GEN_CASE(QualifiedTypeExpression)
|
||||
BOLT_GEN_CASE(ReferenceTypeExpression)
|
||||
BOLT_GEN_CASE(ArrowTypeExpression)
|
||||
|
@ -121,6 +124,10 @@ namespace bolt {
|
|||
static_cast<D*>(this)->visitNode(N);
|
||||
}
|
||||
|
||||
void visitVBar(VBar* N) {
|
||||
static_cast<D*>(this)->visitToken(N);
|
||||
}
|
||||
|
||||
void visitEquals(Equals* N) {
|
||||
static_cast<D*>(this)->visitToken(N);
|
||||
}
|
||||
|
@ -313,6 +320,14 @@ namespace bolt {
|
|||
static_cast<D*>(this)->visitNode(N);
|
||||
}
|
||||
|
||||
void visitRecordTypeExpressionField(RecordTypeExpressionField * N) {
|
||||
static_cast<D*>(this)->visitNode(N);
|
||||
}
|
||||
|
||||
void visitRecordTypeExpression(RecordTypeExpression* N) {
|
||||
static_cast<D*>(this)->visitTypeExpression(N);
|
||||
}
|
||||
|
||||
void visitQualifiedTypeExpression(QualifiedTypeExpression* N) {
|
||||
static_cast<D*>(this)->visitTypeExpression(N);
|
||||
}
|
||||
|
@ -519,6 +534,7 @@ namespace bolt {
|
|||
break;
|
||||
|
||||
switch (N->getKind()) {
|
||||
BOLT_GEN_CHILD_CASE(VBar)
|
||||
BOLT_GEN_CHILD_CASE(Equals)
|
||||
BOLT_GEN_CHILD_CASE(Colon)
|
||||
BOLT_GEN_CHILD_CASE(Comma)
|
||||
|
@ -564,6 +580,8 @@ namespace bolt {
|
|||
BOLT_GEN_CHILD_CASE(TypeAssertAnnotation)
|
||||
BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression)
|
||||
BOLT_GEN_CHILD_CASE(EqualityConstraintExpression)
|
||||
BOLT_GEN_CHILD_CASE(RecordTypeExpressionField)
|
||||
BOLT_GEN_CHILD_CASE(RecordTypeExpression)
|
||||
BOLT_GEN_CHILD_CASE(QualifiedTypeExpression)
|
||||
BOLT_GEN_CHILD_CASE(ReferenceTypeExpression)
|
||||
BOLT_GEN_CHILD_CASE(ArrowTypeExpression)
|
||||
|
@ -613,6 +631,9 @@ namespace bolt {
|
|||
|
||||
#define BOLT_VISIT(node) static_cast<D*>(this)->visit(node)
|
||||
|
||||
void visitEachChild(VBar* N) {
|
||||
}
|
||||
|
||||
void visitEachChild(Equals* N) {
|
||||
}
|
||||
|
||||
|
@ -760,6 +781,29 @@ namespace bolt {
|
|||
BOLT_VISIT(N->Right);
|
||||
}
|
||||
|
||||
void visitEachChild(RecordTypeExpressionField* N) {
|
||||
BOLT_VISIT(N->Name);
|
||||
BOLT_VISIT(N->Colon);
|
||||
BOLT_VISIT(N->TE);
|
||||
}
|
||||
|
||||
void visitEachChild(RecordTypeExpression* N) {
|
||||
BOLT_VISIT(N->LBrace);
|
||||
for (auto [Field, Comma]: N->Fields) {
|
||||
BOLT_VISIT(Field);
|
||||
if (Comma) {
|
||||
BOLT_VISIT(Comma);
|
||||
}
|
||||
}
|
||||
if (N->VBar) {
|
||||
BOLT_VISIT(N->VBar);
|
||||
}
|
||||
if (N->Rest) {
|
||||
BOLT_VISIT(N->Rest);
|
||||
}
|
||||
BOLT_VISIT(N->RBrace);
|
||||
}
|
||||
|
||||
void visitEachChild(QualifiedTypeExpression* N) {
|
||||
for (auto [CE, Comma]: N->Constraints) {
|
||||
BOLT_VISIT(CE);
|
||||
|
|
|
@ -393,6 +393,22 @@ namespace bolt {
|
|||
return Left->getLastToken();
|
||||
}
|
||||
|
||||
Token* RecordTypeExpressionField::getFirstToken() const {
|
||||
return Name;
|
||||
}
|
||||
|
||||
Token* RecordTypeExpressionField::getLastToken() const {
|
||||
return TE->getLastToken();
|
||||
}
|
||||
|
||||
Token* RecordTypeExpression::getFirstToken() const {
|
||||
return LBrace;
|
||||
}
|
||||
|
||||
Token* RecordTypeExpression::getLastToken() const {
|
||||
return RBrace;
|
||||
}
|
||||
|
||||
Token* QualifiedTypeExpression::getFirstToken() const {
|
||||
if (!Constraints.empty()) {
|
||||
return std::get<0>(Constraints.front())->getFirstToken();
|
||||
|
@ -840,6 +856,10 @@ namespace bolt {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::string VBar::getText() const {
|
||||
return "|";
|
||||
}
|
||||
|
||||
std::string Equals::getText() const {
|
||||
return "=";
|
||||
}
|
||||
|
|
|
@ -677,9 +677,12 @@ namespace bolt {
|
|||
|
||||
case NodeKind::LetDeclaration:
|
||||
{
|
||||
// Function declarations are handled separately in inferLetDeclaration()
|
||||
// Function declarations are handled separately in inferFunctionDeclaration()
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
if (Decl->isFunction() && !Decl->Visited) {
|
||||
if (Decl->Visited) {
|
||||
break;
|
||||
}
|
||||
if (Decl->isFunction()) {
|
||||
Decl->IsCycleActive = true;
|
||||
Decl->Visited = true;
|
||||
inferFunctionDeclaration(Decl);
|
||||
|
@ -854,6 +857,17 @@ namespace bolt {
|
|||
return Ty;
|
||||
}
|
||||
|
||||
case NodeKind::RecordTypeExpression:
|
||||
{
|
||||
auto RecTE = static_cast<RecordTypeExpression*>(N);
|
||||
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil());
|
||||
for (auto [Field, Comma]: RecTE->Fields) {
|
||||
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty));
|
||||
}
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
}
|
||||
|
||||
case NodeKind::TupleTypeExpression:
|
||||
{
|
||||
auto TupleTE = static_cast<TupleTypeExpression*>(N);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
// TODO check for memory leaks everywhere a nullptr is returned
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "bolt/Common.hpp"
|
||||
|
@ -385,6 +386,9 @@ after_constraints:
|
|||
LParen->unref();
|
||||
for (auto [CE, Comma]: Constraints) {
|
||||
CE->unref();
|
||||
if (Comma) {
|
||||
Comma->unref();
|
||||
}
|
||||
}
|
||||
RParen->unref();
|
||||
RArrowAlt->unref();
|
||||
|
@ -398,6 +402,93 @@ after_constraints:
|
|||
switch (T0->getKind()) {
|
||||
case NodeKind::Identifier:
|
||||
return parseVarTypeExpression();
|
||||
case NodeKind::LBrace:
|
||||
{
|
||||
Tokens.get();
|
||||
auto LBrace = static_cast<class LBrace*>(T0);
|
||||
std::vector<std::tuple<RecordTypeExpressionField*, Comma*>> Fields;
|
||||
VBar* VBar = nullptr;
|
||||
TypeExpression* Rest = nullptr;
|
||||
for (;;) {
|
||||
auto T1 = Tokens.peek();
|
||||
if (T1->getKind() == NodeKind::RBrace) {
|
||||
break;
|
||||
}
|
||||
auto Name = expectToken<Identifier>();
|
||||
if (Name == nullptr) {
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
Field->unref();
|
||||
Comma->unref();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
auto Colon = expectToken<class Colon>();
|
||||
if (Colon == nullptr) {
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
Field->unref();
|
||||
Comma->unref();
|
||||
}
|
||||
Name->unref();
|
||||
return nullptr;
|
||||
}
|
||||
auto TE = parseTypeExpression();
|
||||
if (TE == nullptr) {
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
Field->unref();
|
||||
Comma->unref();
|
||||
}
|
||||
Name->unref();
|
||||
Colon->unref();
|
||||
return nullptr;
|
||||
}
|
||||
auto Field = new RecordTypeExpressionField(Name, Colon, TE);
|
||||
auto T3 = Tokens.peek();
|
||||
if (T3->getKind() == NodeKind::RBrace) {
|
||||
Fields.push_back(std::make_tuple(Field, nullptr));
|
||||
break;
|
||||
}
|
||||
if (T3->getKind() == NodeKind::VBar) {
|
||||
Tokens.get();
|
||||
VBar = static_cast<class VBar*>(T3);
|
||||
Rest = parseTypeExpression();
|
||||
if (!Rest) {
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
Field->unref();
|
||||
Comma->unref();
|
||||
}
|
||||
Field->unref();
|
||||
return nullptr;
|
||||
}
|
||||
auto T4 = Tokens.peek();
|
||||
if (T4->getKind() != NodeKind::RBrace) {
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
Field->unref();
|
||||
Comma->unref();
|
||||
}
|
||||
Field->unref();
|
||||
Rest->unref();
|
||||
DE.add<UnexpectedTokenDiagnostic>(File, T4, std::vector { NodeKind::RBrace });
|
||||
return nullptr;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (T3->getKind() == NodeKind::Comma) {
|
||||
Tokens.get();
|
||||
auto Comma = static_cast<class Comma*>(T3);
|
||||
Fields.push_back(std::make_tuple(Field, Comma));
|
||||
continue;
|
||||
}
|
||||
DE.add<UnexpectedTokenDiagnostic>(File, T3, std::vector { NodeKind::RBrace, NodeKind::Comma, NodeKind::VBar });
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
Field->unref();
|
||||
Comma->unref();
|
||||
}
|
||||
Field->unref();
|
||||
return nullptr;
|
||||
}
|
||||
auto RBrace = static_cast<class RBrace*>(Tokens.get());
|
||||
return new RecordTypeExpression(LBrace, Fields, VBar, Rest, RBrace);
|
||||
}
|
||||
case NodeKind::LParen:
|
||||
{
|
||||
Tokens.get();
|
||||
|
@ -488,7 +579,16 @@ after_tuple_element:
|
|||
for (;;) {
|
||||
auto T1 = Tokens.peek();
|
||||
auto Kind = T1->getKind();
|
||||
if (Kind == NodeKind::RArrow || Kind == NodeKind::Equals || Kind == NodeKind::BlockStart || Kind == NodeKind::LineFoldEnd || Kind == NodeKind::EndOfFile || Kind == NodeKind::RParen) {
|
||||
if (Kind == NodeKind::Comma
|
||||
|| Kind == NodeKind::RArrow
|
||||
|| Kind == NodeKind::Equals
|
||||
|| Kind == NodeKind::BlockStart
|
||||
|| Kind == NodeKind::LineFoldEnd
|
||||
|| Kind == NodeKind::EndOfFile
|
||||
|| Kind == NodeKind::RParen
|
||||
|| Kind == NodeKind::RBracket
|
||||
|| Kind == NodeKind::RBrace
|
||||
|| Kind == NodeKind::VBar) {
|
||||
break;
|
||||
}
|
||||
auto TE = parsePrimitiveTypeExpression();
|
||||
|
|
|
@ -379,7 +379,9 @@ after_string_contents:
|
|||
Text.push_back(static_cast<char>(C1));
|
||||
getChar();
|
||||
}
|
||||
if (Text == "->") {
|
||||
if (Text == "|") {
|
||||
return new VBar(StartLoc);
|
||||
} else if (Text == "->") {
|
||||
return new RArrow(StartLoc);
|
||||
} else if (Text == "=>") {
|
||||
return new RArrowAlt(StartLoc);
|
||||
|
|
Loading…
Reference in a new issue