Add support for nested/tuple type expressions

This commit is contained in:
Sam Vervaeck 2023-05-23 22:36:01 +02:00
parent 1bba5facc7
commit 5ac162cd72
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
5 changed files with 149 additions and 1 deletions

View file

@ -70,6 +70,8 @@ namespace bolt {
ReferenceTypeExpression, ReferenceTypeExpression,
ArrowTypeExpression, ArrowTypeExpression,
VarTypeExpression, VarTypeExpression,
NestedTypeExpression,
TupleTypeExpression,
BindPattern, BindPattern,
LiteralPattern, LiteralPattern,
ReferenceExpression, ReferenceExpression,
@ -1005,6 +1007,48 @@ namespace bolt {
}; };
class NestedTypeExpression : public TypeExpression {
public:
LParen* LParen;
TypeExpression* TE;
RParen* RParen;
inline NestedTypeExpression(
class LParen* LParen,
TypeExpression* TE,
class RParen* RParen
): TypeExpression(NodeKind::NestedTypeExpression),
LParen(LParen),
TE(TE),
RParen(RParen) {}
Token* getFirstToken() override;
Token* getLastToken() override;
};
class TupleTypeExpression : public TypeExpression {
public:
LParen* LParen;
std::vector<std::tuple<TypeExpression*, Comma*>> Elements;
RParen* RParen;
inline TupleTypeExpression(
class LParen* LParen,
std::vector<std::tuple<TypeExpression*, Comma*>> Elements,
class RParen* RParen
): TypeExpression(NodeKind::TupleTypeExpression),
LParen(LParen),
Elements(Elements),
RParen(RParen) {}
Token* getFirstToken() override;
Token* getLastToken() override;
};
class Pattern : public Node { class Pattern : public Node {
protected: protected:

View file

@ -99,6 +99,10 @@ namespace bolt {
return static_cast<D*>(this)->visitArrowTypeExpression(static_cast<ArrowTypeExpression*>(N)); return static_cast<D*>(this)->visitArrowTypeExpression(static_cast<ArrowTypeExpression*>(N));
case NodeKind::VarTypeExpression: case NodeKind::VarTypeExpression:
return static_cast<D*>(this)->visitVarTypeExpression(static_cast<VarTypeExpression*>(N)); return static_cast<D*>(this)->visitVarTypeExpression(static_cast<VarTypeExpression*>(N));
case NodeKind::NestedTypeExpression:
return static_cast<D*>(this)->visitNestedTypeExpression(static_cast<NestedTypeExpression*>(N));
case NodeKind::TupleTypeExpression:
return static_cast<D*>(this)->visitTupleTypeExpression(static_cast<TupleTypeExpression*>(N));
case NodeKind::BindPattern: case NodeKind::BindPattern:
return static_cast<D*>(this)->visitBindPattern(static_cast<BindPattern*>(N)); return static_cast<D*>(this)->visitBindPattern(static_cast<BindPattern*>(N));
case NodeKind::LiteralPattern: case NodeKind::LiteralPattern:
@ -348,6 +352,14 @@ namespace bolt {
visitTypeExpression(N); visitTypeExpression(N);
} }
void visitNestedTypeExpression(NestedTypeExpression* N) {
visitTypeExpression(N);
}
void visitTupleTypeExpression(TupleTypeExpression* N) {
visitTypeExpression(N);
}
void visitPattern(Pattern* N) { void visitPattern(Pattern* N) {
visitNode(N); visitNode(N);
} }
@ -604,6 +616,12 @@ namespace bolt {
case NodeKind::VarTypeExpression: case NodeKind::VarTypeExpression:
visitEachChild(static_cast<VarTypeExpression*>(N)); visitEachChild(static_cast<VarTypeExpression*>(N));
break; break;
case NodeKind::NestedTypeExpression:
visitEachChild(static_cast<NestedTypeExpression*>(N));
break;
case NodeKind::TupleTypeExpression:
visitEachChild(static_cast<TupleTypeExpression*>(N));
break;
case NodeKind::BindPattern: case NodeKind::BindPattern:
visitEachChild(static_cast<BindPattern*>(N)); visitEachChild(static_cast<BindPattern*>(N));
break; break;
@ -846,6 +864,23 @@ namespace bolt {
BOLT_VISIT(N->Name); BOLT_VISIT(N->Name);
} }
void visitEachChild(NestedTypeExpression* N) {
BOLT_VISIT(N->LParen);
BOLT_VISIT(N->TE);
BOLT_VISIT(N->RParen);
}
void visitEachChild(TupleTypeExpression* N) {
BOLT_VISIT(N->LParen);
for (auto [TE, Comma]: N->Elements) {
if (Comma) {
BOLT_VISIT(Comma);
}
BOLT_VISIT(TE);
}
BOLT_VISIT(N->RParen);
}
void visitEachChild(BindPattern* N) { void visitEachChild(BindPattern* N) {
BOLT_VISIT(N->Name); BOLT_VISIT(N->Name);
} }

View file

@ -229,6 +229,22 @@ namespace bolt {
return Name; return Name;
} }
Token* NestedTypeExpression::getLastToken() {
return LParen;
}
Token* NestedTypeExpression::getFirstToken() {
return RParen;
}
Token* TupleTypeExpression::getLastToken() {
return LParen;
}
Token* TupleTypeExpression::getFirstToken() {
return RParen;
}
Token* BindPattern::getFirstToken() { Token* BindPattern::getFirstToken() {
return Name; return Name;
} }

View file

@ -585,6 +585,26 @@ namespace bolt {
return Ty; return Ty;
} }
case NodeKind::TupleTypeExpression:
{
auto TupleTE = static_cast<TupleTypeExpression*>(N);
std::vector<Type*> ElementTypes;
for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE));
}
auto Ty = new TTuple(ElementTypes);
N->setType(Ty);
return Ty;
}
case NodeKind::NestedTypeExpression:
{
auto NestedTE = static_cast<NestedTypeExpression*>(N);
auto Ty = inferTypeExpression(NestedTE->TE);
N->setType(Ty);
return Ty;
}
case NodeKind::ArrowTypeExpression: case NodeKind::ArrowTypeExpression:
{ {
auto ArrowTE = static_cast<ArrowTypeExpression*>(N); auto ArrowTE = static_cast<ArrowTypeExpression*>(N);

View file

@ -173,6 +173,39 @@ after_constraints:
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::Identifier: case NodeKind::Identifier:
return parseVarTypeExpression(); return parseVarTypeExpression();
case NodeKind::LParen:
{
Tokens.get();
auto LParen = static_cast<class LParen*>(T0);
std::vector<std::tuple<TypeExpression*, Comma*>> Elements;
RParen* RParen;
for (;;) {
auto T1 = Tokens.peek();
if (llvm::isa<class RParen>(T1)) {
Tokens.get();
RParen = static_cast<class RParen*>(T1);
break;
}
auto TE = parseTypeExpression();
auto T2 = Tokens.get();
switch (T2->getKind()) {
case NodeKind::RParen:
RParen = static_cast<class RParen*>(T1);
Elements.push_back({ TE, nullptr });
goto after_tuple_element;
case NodeKind::Comma:
Elements.push_back({ TE, static_cast<Comma*>(T2) });
continue;
default:
throw UnexpectedTokenDiagnostic(File, T2, { NodeKind::Comma, NodeKind::RParen });
}
}
after_tuple_element:
if (Elements.size() == 1) {
return new NestedTypeExpression { LParen, std::get<0>(Elements.front()), RParen };
}
return new TupleTypeExpression { LParen, Elements, RParen };
}
case NodeKind::IdentifierAlt: case NodeKind::IdentifierAlt:
{ {
std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath; std::vector<std::tuple<IdentifierAlt*, Dot*>> ModulePath;
@ -547,7 +580,7 @@ after_params:
return parseExpressionStatement(); return parseExpressionStatement();
} }
} }
#
ConstraintExpression* Parser::parseConstraintExpression() { ConstraintExpression* Parser::parseConstraintExpression() {
bool HasTilde = false; bool HasTilde = false;
for (std::size_t I = 0; ; I++) { for (std::size_t I = 0; ; I++) {