Split let-declaration up into function/variable declarations

This commit is contained in:
Sam Vervaeck 2023-05-30 21:34:40 +02:00
parent 053d45868e
commit 87af4686b7
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
11 changed files with 545 additions and 551 deletions

View file

@ -103,6 +103,7 @@ namespace bolt {
RArrow,
RArrowAlt,
LetKeyword,
FnKeyword,
MutKeyword,
PubKeyword,
TypeKeyword,
@ -160,7 +161,8 @@ namespace bolt {
Parameter,
LetBlockBody,
LetExprBody,
LetDeclaration,
FunctionDeclaration,
VariableDeclaration,
RecordDeclarationField,
RecordDeclaration,
VariantDeclaration,
@ -549,6 +551,20 @@ namespace bolt {
};
class FnKeyword : public Token {
public:
inline FnKeyword(TextLoc StartLoc):
Token(NodeKind::FnKeyword, StartLoc) {}
std::string getText() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::FnKeyword;
}
};
class MutKeyword : public Token {
public:
@ -1661,7 +1677,7 @@ namespace bolt {
};
class LetDeclaration : public Node {
class FunctionDeclaration : public Node {
Scope* TheScope = nullptr;
@ -1672,26 +1688,23 @@ namespace bolt {
class Type* Ty;
class PubKeyword* PubKeyword;
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
class FnKeyword* FnKeyword;
class Identifier* Name;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
LetDeclaration(
FunctionDeclaration(
class PubKeyword* PubKeyword,
class LetKeyword* LetKeywod,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
class FnKeyword* FnKeyword,
class Identifier* Name,
std::vector<Parameter*> Params,
class TypeAssert* TypeAssert,
LetBody* Body
): Node(NodeKind::LetDeclaration),
): Node(NodeKind::FunctionDeclaration),
PubKeyword(PubKeyword),
LetKeyword(LetKeywod),
MutKeyword(MutKeyword),
Pattern(Pattern),
FnKeyword(FnKeyword),
Name(Name),
Params(Params),
TypeAssert(TypeAssert),
Body(Body) {}
@ -1703,14 +1716,6 @@ namespace bolt {
return TheScope;
}
bool isFunc() const noexcept {
return !Params.empty();
}
bool isVar() const noexcept {
return !isFunc();
}
bool isInstance() const noexcept {
return Parent->getKind() == NodeKind::InstanceDeclaration;
}
@ -1723,11 +1728,47 @@ namespace bolt {
Token* getLastToken() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::LetDeclaration;
return N->getKind() == NodeKind::FunctionDeclaration;
}
};
class VariableDeclaration : public TypedNode {
Scope* TheScope = nullptr;
public:
bool IsCycleActive = false;
class PubKeyword* PubKeyword;
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
VariableDeclaration(
class PubKeyword* PubKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::VariableDeclaration),
PubKeyword(PubKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
TypeAssert(TypeAssert),
Body(Body) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class InstanceDeclaration : public Node {
public:
@ -1977,6 +2018,7 @@ namespace bolt {
template<> inline NodeKind getNodeType<RArrow>() { return NodeKind::RArrow; }
template<> inline NodeKind getNodeType<RArrowAlt>() { return NodeKind::RArrowAlt; }
template<> inline NodeKind getNodeType<LetKeyword>() { return NodeKind::LetKeyword; }
template<> inline NodeKind getNodeType<FnKeyword>() { return NodeKind::FnKeyword; }
template<> inline NodeKind getNodeType<MutKeyword>() { return NodeKind::MutKeyword; }
template<> inline NodeKind getNodeType<PubKeyword>() { return NodeKind::PubKeyword; }
template<> inline NodeKind getNodeType<TypeKeyword>() { return NodeKind::TypeKeyword; }
@ -2019,7 +2061,8 @@ namespace bolt {
template<> inline NodeKind getNodeType<Parameter>() { return NodeKind::Parameter; }
template<> inline NodeKind getNodeType<LetBlockBody>() { return NodeKind::LetBlockBody; }
template<> inline NodeKind getNodeType<LetExprBody>() { return NodeKind::LetExprBody; }
template<> inline NodeKind getNodeType<LetDeclaration>() { return NodeKind::LetDeclaration; }
template<> inline NodeKind getNodeType<FunctionDeclaration>() { return NodeKind::FunctionDeclaration; }
template<> inline NodeKind getNodeType<VariableDeclaration>() { return NodeKind::VariableDeclaration; }
template<> inline NodeKind getNodeType<RecordDeclarationField>() { return NodeKind::RecordDeclarationField; }
template<> inline NodeKind getNodeType<RecordDeclaration>() { return NodeKind::RecordDeclaration; }
template<> inline NodeKind getNodeType<ClassDeclaration>() { return NodeKind::ClassDeclaration; }

View file

@ -12,170 +12,96 @@ namespace bolt {
public:
void visit(Node* N) {
#define BOLT_GEN_CASE(name) \
case NodeKind::name: \
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
switch (N->getKind()) {
case NodeKind::Equals:
return static_cast<D*>(this)->visitEquals(static_cast<Equals*>(N));
case NodeKind::Colon:
return static_cast<D*>(this)->visitColon(static_cast<Colon*>(N));
case NodeKind::Comma:
return static_cast<D*>(this)->visitComma(static_cast<Comma*>(N));
case NodeKind::Dot:
return static_cast<D*>(this)->visitDot(static_cast<Dot*>(N));
case NodeKind::DotDot:
return static_cast<D*>(this)->visitDotDot(static_cast<DotDot*>(N));
case NodeKind::Tilde:
return static_cast<D*>(this)->visitTilde(static_cast<Tilde*>(N));
case NodeKind::LParen:
return static_cast<D*>(this)->visitLParen(static_cast<LParen*>(N));
case NodeKind::RParen:
return static_cast<D*>(this)->visitRParen(static_cast<RParen*>(N));
case NodeKind::LBracket:
return static_cast<D*>(this)->visitLBracket(static_cast<LBracket*>(N));
case NodeKind::RBracket:
return static_cast<D*>(this)->visitRBracket(static_cast<RBracket*>(N));
case NodeKind::LBrace:
return static_cast<D*>(this)->visitLBrace(static_cast<LBrace*>(N));
case NodeKind::RBrace:
return static_cast<D*>(this)->visitRBrace(static_cast<RBrace*>(N));
case NodeKind::RArrow:
return static_cast<D*>(this)->visitRArrow(static_cast<RArrow*>(N));
case NodeKind::RArrowAlt:
return static_cast<D*>(this)->visitRArrowAlt(static_cast<RArrowAlt*>(N));
case NodeKind::LetKeyword:
return static_cast<D*>(this)->visitLetKeyword(static_cast<LetKeyword*>(N));
case NodeKind::MutKeyword:
return static_cast<D*>(this)->visitMutKeyword(static_cast<MutKeyword*>(N));
case NodeKind::PubKeyword:
return static_cast<D*>(this)->visitPubKeyword(static_cast<PubKeyword*>(N));
case NodeKind::TypeKeyword:
return static_cast<D*>(this)->visitTypeKeyword(static_cast<TypeKeyword*>(N));
case NodeKind::ReturnKeyword:
return static_cast<D*>(this)->visitReturnKeyword(static_cast<ReturnKeyword*>(N));
case NodeKind::ModKeyword:
return static_cast<D*>(this)->visitModKeyword(static_cast<ModKeyword*>(N));
case NodeKind::StructKeyword:
return static_cast<D*>(this)->visitStructKeyword(static_cast<StructKeyword*>(N));
case NodeKind::EnumKeyword:
return static_cast<D*>(this)->visitEnumKeyword(static_cast<EnumKeyword*>(N));
case NodeKind::ClassKeyword:
return static_cast<D*>(this)->visitClassKeyword(static_cast<ClassKeyword*>(N));
case NodeKind::InstanceKeyword:
return static_cast<D*>(this)->visitInstanceKeyword(static_cast<InstanceKeyword*>(N));
case NodeKind::ElifKeyword:
return static_cast<D*>(this)->visitElifKeyword(static_cast<ElifKeyword*>(N));
case NodeKind::IfKeyword:
return static_cast<D*>(this)->visitIfKeyword(static_cast<IfKeyword*>(N));
case NodeKind::ElseKeyword:
return static_cast<D*>(this)->visitElseKeyword(static_cast<ElseKeyword*>(N));
case NodeKind::MatchKeyword:
return static_cast<D*>(this)->visitMatchKeyword(static_cast<MatchKeyword*>(N));
case NodeKind::Invalid:
return static_cast<D*>(this)->visitInvalid(static_cast<Invalid*>(N));
case NodeKind::EndOfFile:
return static_cast<D*>(this)->visitEndOfFile(static_cast<EndOfFile*>(N));
case NodeKind::BlockStart:
return static_cast<D*>(this)->visitBlockStart(static_cast<BlockStart*>(N));
case NodeKind::BlockEnd:
return static_cast<D*>(this)->visitBlockEnd(static_cast<BlockEnd*>(N));
case NodeKind::LineFoldEnd:
return static_cast<D*>(this)->visitLineFoldEnd(static_cast<LineFoldEnd*>(N));
case NodeKind::CustomOperator:
return static_cast<D*>(this)->visitCustomOperator(static_cast<CustomOperator*>(N));
case NodeKind::Assignment:
return static_cast<D*>(this)->visitAssignment(static_cast<Assignment*>(N));
case NodeKind::Identifier:
return static_cast<D*>(this)->visitIdentifier(static_cast<Identifier*>(N));
case NodeKind::IdentifierAlt:
return static_cast<D*>(this)->visitIdentifierAlt(static_cast<IdentifierAlt*>(N));
case NodeKind::StringLiteral:
return static_cast<D*>(this)->visitStringLiteral(static_cast<StringLiteral*>(N));
case NodeKind::IntegerLiteral:
return static_cast<D*>(this)->visitIntegerLiteral(static_cast<IntegerLiteral*>(N));
case NodeKind::TypeclassConstraintExpression:
return static_cast<D*>(this)->visitTypeclassConstraintExpression(static_cast<TypeclassConstraintExpression*>(N));
case NodeKind::EqualityConstraintExpression:
return static_cast<D*>(this)->visitEqualityConstraintExpression(static_cast<EqualityConstraintExpression*>(N));
case NodeKind::QualifiedTypeExpression:
return static_cast<D*>(this)->visitQualifiedTypeExpression(static_cast<QualifiedTypeExpression*>(N));
case NodeKind::ReferenceTypeExpression:
return static_cast<D*>(this)->visitReferenceTypeExpression(static_cast<ReferenceTypeExpression*>(N));
case NodeKind::ArrowTypeExpression:
return static_cast<D*>(this)->visitArrowTypeExpression(static_cast<ArrowTypeExpression*>(N));
case NodeKind::AppTypeExpression:
return static_cast<D*>(this)->visitAppTypeExpression(static_cast<AppTypeExpression*>(N));
case NodeKind::VarTypeExpression:
return static_cast<D*>(this)->visitVarTypeExpression(static_cast<VarTypeExpression*>(N));
case NodeKind::NestedTypeExpression:
return static_cast<D*>(this)->visitNestedTypeExpression(static_cast<NestedTypeExpression*>(N));
case NodeKind::TupleTypeExpression:
return static_cast<D*>(this)->visitTupleTypeExpression(static_cast<TupleTypeExpression*>(N));
case NodeKind::BindPattern:
return static_cast<D*>(this)->visitBindPattern(static_cast<BindPattern*>(N));
case NodeKind::LiteralPattern:
return static_cast<D*>(this)->visitLiteralPattern(static_cast<LiteralPattern*>(N));
case NodeKind::NamedPattern:
return static_cast<D*>(this)->visitNamedPattern(static_cast<NamedPattern*>(N));
case NodeKind::NestedPattern:
return static_cast<D*>(this)->visitNestedPattern(static_cast<NestedPattern*>(N));
case NodeKind::ReferenceExpression:
return static_cast<D*>(this)->visitReferenceExpression(static_cast<ReferenceExpression*>(N));
case NodeKind::MatchCase:
return static_cast<D*>(this)->visitMatchCase(static_cast<MatchCase*>(N));
case NodeKind::MatchExpression:
return static_cast<D*>(this)->visitMatchExpression(static_cast<MatchExpression*>(N));
case NodeKind::MemberExpression:
return static_cast<D*>(this)->visitMemberExpression(static_cast<MemberExpression*>(N));
case NodeKind::TupleExpression:
return static_cast<D*>(this)->visitTupleExpression(static_cast<TupleExpression*>(N));
case NodeKind::NestedExpression:
return static_cast<D*>(this)->visitNestedExpression(static_cast<NestedExpression*>(N));
case NodeKind::ConstantExpression:
return static_cast<D*>(this)->visitConstantExpression(static_cast<ConstantExpression*>(N));
case NodeKind::CallExpression:
return static_cast<D*>(this)->visitCallExpression(static_cast<CallExpression*>(N));
case NodeKind::InfixExpression:
return static_cast<D*>(this)->visitInfixExpression(static_cast<InfixExpression*>(N));
case NodeKind::PrefixExpression:
return static_cast<D*>(this)->visitPrefixExpression(static_cast<PrefixExpression*>(N));
case NodeKind::RecordExpressionField:
return static_cast<D*>(this)->visitRecordExpressionField(static_cast<RecordExpressionField*>(N));
case NodeKind::RecordExpression:
return static_cast<D*>(this)->visitRecordExpression(static_cast<RecordExpression*>(N));
case NodeKind::ExpressionStatement:
return static_cast<D*>(this)->visitExpressionStatement(static_cast<ExpressionStatement*>(N));
case NodeKind::ReturnStatement:
return static_cast<D*>(this)->visitReturnStatement(static_cast<ReturnStatement*>(N));
case NodeKind::IfStatement:
return static_cast<D*>(this)->visitIfStatement(static_cast<IfStatement*>(N));
case NodeKind::IfStatementPart:
return static_cast<D*>(this)->visitIfStatementPart(static_cast<IfStatementPart*>(N));
case NodeKind::TypeAssert:
return static_cast<D*>(this)->visitTypeAssert(static_cast<TypeAssert*>(N));
case NodeKind::Parameter:
return static_cast<D*>(this)->visitParameter(static_cast<Parameter*>(N));
case NodeKind::LetBlockBody:
return static_cast<D*>(this)->visitLetBlockBody(static_cast<LetBlockBody*>(N));
case NodeKind::LetExprBody:
return static_cast<D*>(this)->visitLetExprBody(static_cast<LetExprBody*>(N));
case NodeKind::LetDeclaration:
return static_cast<D*>(this)->visitLetDeclaration(static_cast<LetDeclaration*>(N));
case NodeKind::RecordDeclarationField:
return static_cast<D*>(this)->visitRecordDeclarationField(static_cast<RecordDeclarationField*>(N));
case NodeKind::RecordDeclaration:
return static_cast<D*>(this)->visitRecordDeclaration(static_cast<RecordDeclaration*>(N));
case NodeKind::VariantDeclaration:
return static_cast<D*>(this)->visitVariantDeclaration(static_cast<VariantDeclaration*>(N));
case NodeKind::TupleVariantDeclarationMember:
return static_cast<D*>(this)->visitTupleVariantDeclarationMember(static_cast<TupleVariantDeclarationMember*>(N));
case NodeKind::RecordVariantDeclarationMember:
return static_cast<D*>(this)->visitRecordVariantDeclarationMember(static_cast<RecordVariantDeclarationMember*>(N));
case NodeKind::ClassDeclaration:
return static_cast<D*>(this)->visitClassDeclaration(static_cast<ClassDeclaration*>(N));
case NodeKind::InstanceDeclaration:
return static_cast<D*>(this)->visitInstanceDeclaration(static_cast<InstanceDeclaration*>(N));
case NodeKind::SourceFile:
return static_cast<D*>(this)->visitSourceFile(static_cast<SourceFile*>(N));
}
BOLT_GEN_CASE(Equals)
BOLT_GEN_CASE(Colon)
BOLT_GEN_CASE(Comma)
BOLT_GEN_CASE(Dot)
BOLT_GEN_CASE(DotDot)
BOLT_GEN_CASE(Tilde)
BOLT_GEN_CASE(LParen)
BOLT_GEN_CASE(RParen)
BOLT_GEN_CASE(LBracket)
BOLT_GEN_CASE(RBracket)
BOLT_GEN_CASE(LBrace)
BOLT_GEN_CASE(RBrace)
BOLT_GEN_CASE(RArrow)
BOLT_GEN_CASE(RArrowAlt)
BOLT_GEN_CASE(LetKeyword)
BOLT_GEN_CASE(FnKeyword)
BOLT_GEN_CASE(MutKeyword)
BOLT_GEN_CASE(PubKeyword)
BOLT_GEN_CASE(TypeKeyword)
BOLT_GEN_CASE(ReturnKeyword)
BOLT_GEN_CASE(ModKeyword)
BOLT_GEN_CASE(StructKeyword)
BOLT_GEN_CASE(EnumKeyword)
BOLT_GEN_CASE(ClassKeyword)
BOLT_GEN_CASE(InstanceKeyword)
BOLT_GEN_CASE(ElifKeyword)
BOLT_GEN_CASE(IfKeyword)
BOLT_GEN_CASE(ElseKeyword)
BOLT_GEN_CASE(MatchKeyword)
BOLT_GEN_CASE(Invalid)
BOLT_GEN_CASE(EndOfFile)
BOLT_GEN_CASE(BlockStart)
BOLT_GEN_CASE(BlockEnd)
BOLT_GEN_CASE(LineFoldEnd)
BOLT_GEN_CASE(CustomOperator)
BOLT_GEN_CASE(Assignment)
BOLT_GEN_CASE(Identifier)
BOLT_GEN_CASE(IdentifierAlt)
BOLT_GEN_CASE(StringLiteral)
BOLT_GEN_CASE(IntegerLiteral)
BOLT_GEN_CASE(TypeclassConstraintExpression)
BOLT_GEN_CASE(EqualityConstraintExpression)
BOLT_GEN_CASE(QualifiedTypeExpression)
BOLT_GEN_CASE(ReferenceTypeExpression)
BOLT_GEN_CASE(ArrowTypeExpression)
BOLT_GEN_CASE(AppTypeExpression)
BOLT_GEN_CASE(VarTypeExpression)
BOLT_GEN_CASE(NestedTypeExpression)
BOLT_GEN_CASE(TupleTypeExpression)
BOLT_GEN_CASE(BindPattern)
BOLT_GEN_CASE(LiteralPattern)
BOLT_GEN_CASE(NamedPattern)
BOLT_GEN_CASE(NestedPattern)
BOLT_GEN_CASE(ReferenceExpression)
BOLT_GEN_CASE(MatchCase)
BOLT_GEN_CASE(MatchExpression)
BOLT_GEN_CASE(MemberExpression)
BOLT_GEN_CASE(TupleExpression)
BOLT_GEN_CASE(NestedExpression)
BOLT_GEN_CASE(ConstantExpression)
BOLT_GEN_CASE(CallExpression)
BOLT_GEN_CASE(InfixExpression)
BOLT_GEN_CASE(PrefixExpression)
BOLT_GEN_CASE(RecordExpressionField)
BOLT_GEN_CASE(RecordExpression)
BOLT_GEN_CASE(ExpressionStatement)
BOLT_GEN_CASE(ReturnStatement)
BOLT_GEN_CASE(IfStatement)
BOLT_GEN_CASE(IfStatementPart)
BOLT_GEN_CASE(TypeAssert)
BOLT_GEN_CASE(Parameter)
BOLT_GEN_CASE(LetBlockBody)
BOLT_GEN_CASE(LetExprBody)
BOLT_GEN_CASE(FunctionDeclaration)
BOLT_GEN_CASE(VariableDeclaration)
BOLT_GEN_CASE(RecordDeclaration)
BOLT_GEN_CASE(RecordDeclarationField)
BOLT_GEN_CASE(VariantDeclaration)
BOLT_GEN_CASE(TupleVariantDeclarationMember)
BOLT_GEN_CASE(RecordVariantDeclarationMember)
BOLT_GEN_CASE(ClassDeclaration)
BOLT_GEN_CASE(InstanceDeclaration)
BOLT_GEN_CASE(SourceFile)
}
}
protected:
@ -248,6 +174,10 @@ namespace bolt {
visitToken(N);
}
void visitFnKeyword(FnKeyword* N) {
visitToken(N);
}
void visitMutKeyword(MutKeyword* N) {
visitToken(N);
}
@ -500,7 +430,11 @@ namespace bolt {
visitLetBody(N);
}
void visitLetDeclaration(LetDeclaration* N) {
void visitFunctionDeclaration(FunctionDeclaration* N) {
visitNode(N);
}
void visitVariableDeclaration(VariableDeclaration* N) {
visitNode(N);
}
@ -543,252 +477,96 @@ namespace bolt {
public:
void visitEachChild(Node* N) {
#define BOLT_GEN_CHILD_CASE(name) \
case NodeKind::name: \
visitEachChild(static_cast<name*>(N)); \
break;
switch (N->getKind()) {
case NodeKind::Equals:
visitEachChild(static_cast<Equals*>(N));
break;
case NodeKind::Colon:
visitEachChild(static_cast<Colon*>(N));
break;
case NodeKind::Comma:
visitEachChild(static_cast<Comma*>(N));
break;
case NodeKind::Dot:
visitEachChild(static_cast<Dot*>(N));
break;
case NodeKind::DotDot:
visitEachChild(static_cast<DotDot*>(N));
break;
case NodeKind::Tilde:
visitEachChild(static_cast<Tilde*>(N));
break;
case NodeKind::LParen:
visitEachChild(static_cast<LParen*>(N));
break;
case NodeKind::RParen:
visitEachChild(static_cast<RParen*>(N));
break;
case NodeKind::LBracket:
visitEachChild(static_cast<LBracket*>(N));
break;
case NodeKind::RBracket:
visitEachChild(static_cast<RBracket*>(N));
break;
case NodeKind::LBrace:
visitEachChild(static_cast<LBrace*>(N));
break;
case NodeKind::RBrace:
visitEachChild(static_cast<RBrace*>(N));
break;
case NodeKind::RArrow:
visitEachChild(static_cast<RArrow*>(N));
break;
case NodeKind::RArrowAlt:
visitEachChild(static_cast<RArrowAlt*>(N));
break;
case NodeKind::LetKeyword:
visitEachChild(static_cast<LetKeyword*>(N));
break;
case NodeKind::MutKeyword:
visitEachChild(static_cast<MutKeyword*>(N));
break;
case NodeKind::PubKeyword:
visitEachChild(static_cast<PubKeyword*>(N));
break;
case NodeKind::TypeKeyword:
visitEachChild(static_cast<TypeKeyword*>(N));
break;
case NodeKind::ReturnKeyword:
visitEachChild(static_cast<ReturnKeyword*>(N));
break;
case NodeKind::ModKeyword:
visitEachChild(static_cast<ModKeyword*>(N));
break;
case NodeKind::StructKeyword:
visitEachChild(static_cast<StructKeyword*>(N));
break;
case NodeKind::EnumKeyword:
visitEachChild(static_cast<EnumKeyword*>(N));
break;
case NodeKind::ClassKeyword:
visitEachChild(static_cast<ClassKeyword*>(N));
break;
case NodeKind::InstanceKeyword:
visitEachChild(static_cast<InstanceKeyword*>(N));
break;
case NodeKind::ElifKeyword:
visitEachChild(static_cast<ElifKeyword*>(N));
break;
case NodeKind::IfKeyword:
visitEachChild(static_cast<IfKeyword*>(N));
break;
case NodeKind::ElseKeyword:
visitEachChild(static_cast<ElseKeyword*>(N));
break;
case NodeKind::MatchKeyword:
visitEachChild(static_cast<MatchKeyword*>(N));
break;
case NodeKind::Invalid:
visitEachChild(static_cast<Invalid*>(N));
break;
case NodeKind::EndOfFile:
visitEachChild(static_cast<EndOfFile*>(N));
break;
case NodeKind::BlockStart:
visitEachChild(static_cast<BlockStart*>(N));
break;
case NodeKind::BlockEnd:
visitEachChild(static_cast<BlockEnd*>(N));
break;
case NodeKind::LineFoldEnd:
visitEachChild(static_cast<LineFoldEnd*>(N));
break;
case NodeKind::CustomOperator:
visitEachChild(static_cast<CustomOperator*>(N));
break;
case NodeKind::Assignment:
visitEachChild(static_cast<Assignment*>(N));
break;
case NodeKind::Identifier:
visitEachChild(static_cast<Identifier*>(N));
break;
case NodeKind::IdentifierAlt:
visitEachChild(static_cast<IdentifierAlt*>(N));
break;
case NodeKind::StringLiteral:
visitEachChild(static_cast<StringLiteral*>(N));
break;
case NodeKind::IntegerLiteral:
visitEachChild(static_cast<IntegerLiteral*>(N));
break;
case NodeKind::TypeclassConstraintExpression:
visitEachChild(static_cast<TypeclassConstraintExpression*>(N));
break;
case NodeKind::EqualityConstraintExpression:
visitEachChild(static_cast<EqualityConstraintExpression*>(N));
break;
case NodeKind::QualifiedTypeExpression:
visitEachChild(static_cast<QualifiedTypeExpression*>(N));
break;
case NodeKind::ReferenceTypeExpression:
visitEachChild(static_cast<ReferenceTypeExpression*>(N));
break;
case NodeKind::ArrowTypeExpression:
visitEachChild(static_cast<ArrowTypeExpression*>(N));
break;
case NodeKind::AppTypeExpression:
visitEachChild(static_cast<AppTypeExpression*>(N));
break;
case NodeKind::VarTypeExpression:
visitEachChild(static_cast<VarTypeExpression*>(N));
break;
case NodeKind::NestedTypeExpression:
visitEachChild(static_cast<NestedTypeExpression*>(N));
break;
case NodeKind::TupleTypeExpression:
visitEachChild(static_cast<TupleTypeExpression*>(N));
break;
case NodeKind::BindPattern:
visitEachChild(static_cast<BindPattern*>(N));
break;
case NodeKind::LiteralPattern:
visitEachChild(static_cast<LiteralPattern*>(N));
break;
case NodeKind::NamedPattern:
visitEachChild(static_cast<NamedPattern*>(N));
break;
case NodeKind::NestedPattern:
visitEachChild(static_cast<NestedPattern*>(N));
break;
case NodeKind::ReferenceExpression:
visitEachChild(static_cast<ReferenceExpression*>(N));
break;
case NodeKind::MatchCase:
visitEachChild(static_cast<MatchCase*>(N));
break;
case NodeKind::MatchExpression:
visitEachChild(static_cast<MatchExpression*>(N));
break;
case NodeKind::MemberExpression:
visitEachChild(static_cast<MemberExpression*>(N));
break;
case NodeKind::TupleExpression:
visitEachChild(static_cast<TupleExpression*>(N));
break;
case NodeKind::NestedExpression:
visitEachChild(static_cast<NestedExpression*>(N));
break;
case NodeKind::ConstantExpression:
visitEachChild(static_cast<ConstantExpression*>(N));
break;
case NodeKind::CallExpression:
visitEachChild(static_cast<CallExpression*>(N));
break;
case NodeKind::InfixExpression:
visitEachChild(static_cast<InfixExpression*>(N));
break;
case NodeKind::PrefixExpression:
visitEachChild(static_cast<PrefixExpression*>(N));
break;
case NodeKind::RecordExpressionField:
visitEachChild(static_cast<RecordExpressionField*>(N));
break;
case NodeKind::RecordExpression:
visitEachChild(static_cast<RecordExpression*>(N));
break;
case NodeKind::ExpressionStatement:
visitEachChild(static_cast<ExpressionStatement*>(N));
break;
case NodeKind::ReturnStatement:
visitEachChild(static_cast<ReturnStatement*>(N));
break;
case NodeKind::IfStatement:
visitEachChild(static_cast<IfStatement*>(N));
break;
case NodeKind::IfStatementPart:
visitEachChild(static_cast<IfStatementPart*>(N));
break;
case NodeKind::TypeAssert:
visitEachChild(static_cast<TypeAssert*>(N));
break;
case NodeKind::Parameter:
visitEachChild(static_cast<Parameter*>(N));
break;
case NodeKind::LetBlockBody:
visitEachChild(static_cast<LetBlockBody*>(N));
break;
case NodeKind::LetExprBody:
visitEachChild(static_cast<LetExprBody*>(N));
break;
case NodeKind::LetDeclaration:
visitEachChild(static_cast<LetDeclaration*>(N));
break;
case NodeKind::RecordDeclaration:
visitEachChild(static_cast<RecordDeclaration*>(N));
break;
case NodeKind::RecordDeclarationField:
visitEachChild(static_cast<RecordDeclarationField*>(N));
break;
case NodeKind::VariantDeclaration:
visitEachChild(static_cast<VariantDeclaration*>(N));
break;
case NodeKind::TupleVariantDeclarationMember:
visitEachChild(static_cast<TupleVariantDeclarationMember*>(N));
break;
case NodeKind::RecordVariantDeclarationMember:
visitEachChild(static_cast<RecordVariantDeclarationMember*>(N));
break;
case NodeKind::ClassDeclaration:
visitEachChild(static_cast<ClassDeclaration*>(N));
break;
case NodeKind::InstanceDeclaration:
visitEachChild(static_cast<InstanceDeclaration*>(N));
break;
case NodeKind::SourceFile:
visitEachChild(static_cast<SourceFile*>(N));
break;
default:
ZEN_UNREACHABLE
BOLT_GEN_CHILD_CASE(Equals)
BOLT_GEN_CHILD_CASE(Colon)
BOLT_GEN_CHILD_CASE(Comma)
BOLT_GEN_CHILD_CASE(Dot)
BOLT_GEN_CHILD_CASE(DotDot)
BOLT_GEN_CHILD_CASE(Tilde)
BOLT_GEN_CHILD_CASE(LParen)
BOLT_GEN_CHILD_CASE(RParen)
BOLT_GEN_CHILD_CASE(LBracket)
BOLT_GEN_CHILD_CASE(RBracket)
BOLT_GEN_CHILD_CASE(LBrace)
BOLT_GEN_CHILD_CASE(RBrace)
BOLT_GEN_CHILD_CASE(RArrow)
BOLT_GEN_CHILD_CASE(RArrowAlt)
BOLT_GEN_CHILD_CASE(LetKeyword)
BOLT_GEN_CHILD_CASE(FnKeyword)
BOLT_GEN_CHILD_CASE(MutKeyword)
BOLT_GEN_CHILD_CASE(PubKeyword)
BOLT_GEN_CHILD_CASE(TypeKeyword)
BOLT_GEN_CHILD_CASE(ReturnKeyword)
BOLT_GEN_CHILD_CASE(ModKeyword)
BOLT_GEN_CHILD_CASE(StructKeyword)
BOLT_GEN_CHILD_CASE(EnumKeyword)
BOLT_GEN_CHILD_CASE(ClassKeyword)
BOLT_GEN_CHILD_CASE(InstanceKeyword)
BOLT_GEN_CHILD_CASE(ElifKeyword)
BOLT_GEN_CHILD_CASE(IfKeyword)
BOLT_GEN_CHILD_CASE(ElseKeyword)
BOLT_GEN_CHILD_CASE(MatchKeyword)
BOLT_GEN_CHILD_CASE(Invalid)
BOLT_GEN_CHILD_CASE(EndOfFile)
BOLT_GEN_CHILD_CASE(BlockStart)
BOLT_GEN_CHILD_CASE(BlockEnd)
BOLT_GEN_CHILD_CASE(LineFoldEnd)
BOLT_GEN_CHILD_CASE(CustomOperator)
BOLT_GEN_CHILD_CASE(Assignment)
BOLT_GEN_CHILD_CASE(Identifier)
BOLT_GEN_CHILD_CASE(IdentifierAlt)
BOLT_GEN_CHILD_CASE(StringLiteral)
BOLT_GEN_CHILD_CASE(IntegerLiteral)
BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression)
BOLT_GEN_CHILD_CASE(EqualityConstraintExpression)
BOLT_GEN_CHILD_CASE(QualifiedTypeExpression)
BOLT_GEN_CHILD_CASE(ReferenceTypeExpression)
BOLT_GEN_CHILD_CASE(ArrowTypeExpression)
BOLT_GEN_CHILD_CASE(AppTypeExpression)
BOLT_GEN_CHILD_CASE(VarTypeExpression)
BOLT_GEN_CHILD_CASE(NestedTypeExpression)
BOLT_GEN_CHILD_CASE(TupleTypeExpression)
BOLT_GEN_CHILD_CASE(BindPattern)
BOLT_GEN_CHILD_CASE(LiteralPattern)
BOLT_GEN_CHILD_CASE(NamedPattern)
BOLT_GEN_CHILD_CASE(NestedPattern)
BOLT_GEN_CHILD_CASE(ReferenceExpression)
BOLT_GEN_CHILD_CASE(MatchCase)
BOLT_GEN_CHILD_CASE(MatchExpression)
BOLT_GEN_CHILD_CASE(MemberExpression)
BOLT_GEN_CHILD_CASE(TupleExpression)
BOLT_GEN_CHILD_CASE(NestedExpression)
BOLT_GEN_CHILD_CASE(ConstantExpression)
BOLT_GEN_CHILD_CASE(CallExpression)
BOLT_GEN_CHILD_CASE(InfixExpression)
BOLT_GEN_CHILD_CASE(PrefixExpression)
BOLT_GEN_CHILD_CASE(RecordExpressionField)
BOLT_GEN_CHILD_CASE(RecordExpression)
BOLT_GEN_CHILD_CASE(ExpressionStatement)
BOLT_GEN_CHILD_CASE(ReturnStatement)
BOLT_GEN_CHILD_CASE(IfStatement)
BOLT_GEN_CHILD_CASE(IfStatementPart)
BOLT_GEN_CHILD_CASE(TypeAssert)
BOLT_GEN_CHILD_CASE(Parameter)
BOLT_GEN_CHILD_CASE(LetBlockBody)
BOLT_GEN_CHILD_CASE(LetExprBody)
BOLT_GEN_CHILD_CASE(FunctionDeclaration)
BOLT_GEN_CHILD_CASE(VariableDeclaration)
BOLT_GEN_CHILD_CASE(RecordDeclaration)
BOLT_GEN_CHILD_CASE(RecordDeclarationField)
BOLT_GEN_CHILD_CASE(VariantDeclaration)
BOLT_GEN_CHILD_CASE(TupleVariantDeclarationMember)
BOLT_GEN_CHILD_CASE(RecordVariantDeclarationMember)
BOLT_GEN_CHILD_CASE(ClassDeclaration)
BOLT_GEN_CHILD_CASE(InstanceDeclaration)
BOLT_GEN_CHILD_CASE(SourceFile)
}
}
@ -839,6 +617,9 @@ namespace bolt {
void visitEachChild(LetKeyword* N) {
}
void visitEachChild(FnKeyword* N) {
}
void visitEachChild(MutKeyword* N) {
}
@ -1136,18 +917,29 @@ namespace bolt {
BOLT_VISIT(N->Expression);
}
void visitEachChild(LetDeclaration* N) {
void visitEachChild(FunctionDeclaration* N) {
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
BOLT_VISIT(N->FnKeyword);
BOLT_VISIT(N->Name);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(VariableDeclaration* N) {
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
BOLT_VISIT(N->LetKeyword);
if (N->MutKeyword) {
BOLT_VISIT(N->MutKeyword);
}
BOLT_VISIT(N->Pattern);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}

View file

@ -197,16 +197,16 @@ namespace bolt {
void addConstraint(Constraint* Constraint);
void forwardDeclare(Node* Node);
void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE);
Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true);
Type* inferLiteral(Literal* Lit);
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void infer(Node* node);
void inferLetDeclaration(LetDeclaration* N);
void inferFunctionDeclaration(FunctionDeclaration* N);
Constraint* convertToConstraint(ConstraintExpression* C);

View file

@ -9,7 +9,7 @@ namespace bolt {
ConfigFlags_TypeVarsRequireForall = 1 << 0,
};
unsigned Flags;
unsigned Flags = 0;
public:

View file

@ -109,9 +109,9 @@ namespace bolt {
public:
TypeclassSignature Sig;
LetDeclaration* Decl;
FunctionDeclaration* Decl;
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, LetDeclaration* Decl):
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, FunctionDeclaration* Decl):
Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {}
inline Node* getNode() const override {

View file

@ -124,7 +124,9 @@ namespace bolt {
Node* parseLetBodyElement();
LetDeclaration* parseLetDeclaration();
FunctionDeclaration* parseFunctionDeclaration();
VariableDeclaration* parseVariableDeclaration();
Node* parseClassElement();

View file

@ -69,9 +69,9 @@ namespace bolt {
}
break;
}
case NodeKind::LetDeclaration:
case NodeKind::FunctionDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
auto Decl = static_cast<FunctionDeclaration*>(X);
for (auto Param: Decl->Params) {
visitPattern(Param->Pattern, Param);
}
@ -112,12 +112,18 @@ namespace bolt {
}
break;
}
case NodeKind::LetDeclaration:
case NodeKind::VariableDeclaration:
{
auto Decl = static_cast<LetDeclaration*>(X);
auto Decl = static_cast<VariableDeclaration*>(X);
visitPattern(Decl->Pattern, Decl);
break;
}
case NodeKind::FunctionDeclaration:
{
auto Decl = static_cast<FunctionDeclaration*>(X);
addSymbol(Decl->Name->getCanonicalText(), Decl, SymbolKind::Var);
break;
}
case NodeKind::RecordDeclaration:
{
auto Decl = static_cast<RecordDeclaration*>(X);
@ -597,14 +603,14 @@ namespace bolt {
return Expression->getLastToken();
}
Token* LetDeclaration::getFirstToken() const {
Token* FunctionDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
return LetKeyword;
return FnKeyword;
}
Token* LetDeclaration::getLastToken() const {
Token* FunctionDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
@ -614,6 +620,23 @@ namespace bolt {
if (Params.size()) {
return Params.back()->getLastToken();
}
return Name;
}
Token* VariableDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
return LetKeyword;
}
Token* VariableDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
return Pattern->getLastToken();
}
@ -766,6 +789,10 @@ namespace bolt {
return "let";
}
std::string FnKeyword::getText() const {
return "fn";
}
std::string MutKeyword::getText() const {
return "mut";
}

View file

@ -12,6 +12,12 @@
// TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node
// TODO When a forall variable is missing, do not just insert a blank one into the env. It will result in too few diagnostics being emitted.
// Same goes for reference expressions.
// If running the compiler as a language server, this matters.
// TODO Add a pattern that only performs a type assert
#include <algorithm>
#include <iterator>
#include <stack>
@ -296,10 +302,14 @@ namespace bolt {
break;
}
case NodeKind::LetDeclaration:
case NodeKind::FunctionDeclaration:
// These declarations will be handled separately in check()
break;
case NodeKind::VariableDeclaration:
// All of this node's semantics will be handled in infer()
break;
case NodeKind::VariantDeclaration:
{
auto Decl = static_cast<VariantDeclaration*>(X);
@ -376,8 +386,8 @@ namespace bolt {
for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV);
}
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
popContext();
addBinding(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
break;
}
@ -423,18 +433,18 @@ namespace bolt {
Contexts.pop();
}
void visitLetDeclaration(LetDeclaration* Let) {
if (Let->isFunc()) {
Let->Ctx = createDerivedContext();
Contexts.push(Let->Ctx);
visitEachChild(Let);
Contexts.pop();
} else {
Let->Ctx = Contexts.top();
visitEachChild(Let);
}
void visitFunctionDeclaration(FunctionDeclaration* Let) {
Let->Ctx = createDerivedContext();
Contexts.push(Let->Ctx);
visitEachChild(Let);
Contexts.pop();
}
// void visitVariableDeclaration(VariableDeclaration* Var) {
// Var->Ctx = Contexts.top();
// visitEachChild(Var);
// }
};
Init I { {}, *this };
@ -442,9 +452,7 @@ namespace bolt {
}
void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) {
auto Let = static_cast<LetDeclaration*>(N);
void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) {
setContext(Let->Ctx);
@ -495,7 +503,7 @@ namespace bolt {
Params.push_back(TV);
}
auto SigLet = llvm::cast<LetDeclaration>(Class->getScope()->lookupDirect({ {}, llvm::cast<BindPattern>(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var));
auto SigLet = llvm::cast<FunctionDeclaration>(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var));
// It would be very strange if there was no type assert in the type
// class let-declaration but we rather not let the compiler crash if that happens.
@ -520,9 +528,7 @@ namespace bolt {
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Let->Body);
if (Let->isFunc()) {
Let->Ctx->ReturnType = createTypeVar();
}
Let->Ctx->ReturnType = createTypeVar();
for (auto Element: Block->Elements) {
forwardDeclare(Element);
}
@ -533,19 +539,11 @@ namespace bolt {
}
}
Type* BindTy;
if (Let->isFunc()) {
popContext();
BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
} else {
BindTy = inferPattern(Let->Pattern);
}
addConstraint(new CEqual(BindTy, Ty, Let));
Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty));
}
void Checker::inferLetDeclaration(LetDeclaration* Decl) {
void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) {
setContext(Decl->Ctx);
@ -553,7 +551,6 @@ namespace bolt {
Type* RetType;
for (auto Param: Decl->Params) {
// TODO incorporate Param->TypeAssert or make it a kind of pattern
ParamTypes.push_back(inferPattern(Param->Pattern));
}
@ -568,7 +565,6 @@ namespace bolt {
case NodeKind::LetBlockBody:
{
auto Block = static_cast<LetBlockBody*>(Decl->Body);
ZEN_ASSERT(Decl->isFunc());
RetType = Decl->Ctx->ReturnType;
for (auto Element: Block->Elements) {
infer(Element);
@ -582,13 +578,7 @@ namespace bolt {
RetType = createTypeVar();
}
if (Decl->isFunc()) {
popContext();
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
} else {
// Declaration is a plain (typed) variable
addConstraint(new CEqual { Decl->Ty, RetType, Decl });
}
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
}
@ -642,7 +632,7 @@ namespace bolt {
break;
}
case NodeKind::LetDeclaration:
case NodeKind::FunctionDeclaration:
break;
case NodeKind::ReturnStatement:
@ -658,6 +648,33 @@ namespace bolt {
break;
}
case NodeKind::VariableDeclaration:
{
auto Decl = static_cast<VariableDeclaration*>(N);
Type* Ty = nullptr;
if (Decl->TypeAssert) {
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false);
}
if (Decl->Body) {
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
if (Ty) {
addConstraint(new CEqual(Ty, Ty2, Decl));
} else {
Ty = Ty2;
}
}
auto Ty3 = inferPattern(Decl->Pattern);
if (Ty) {
addConstraint(new CEqual(Ty, Ty3, Decl));
} else {
Ty = Ty3;
}
Decl->setType(Ty);
break;
}
case NodeKind::ExpressionStatement:
{
auto ExprStmt = static_cast<ExpressionStatement*>(N);
@ -764,7 +781,7 @@ namespace bolt {
}
}
Type* Checker::inferTypeExpression(TypeExpression* N) {
Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) {
switch (N->getKind()) {
@ -786,9 +803,9 @@ namespace bolt {
case NodeKind::AppTypeExpression:
{
auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op);
Type* Ty = inferTypeExpression(AppTE->Op, IsPoly);
for (auto Arg: AppTE->Args) {
Ty = new TApp(Ty, inferTypeExpression(Arg));
Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly));
}
return Ty;
}
@ -798,10 +815,10 @@ namespace bolt {
auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(VarTE->Name->getCanonicalText());
if (Ty == nullptr) {
if (Config.typeVarsRequireForall()) {
if (IsPoly && Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
}
Ty = createRigidVar(VarTE->Name->getCanonicalText());
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
}
ZEN_ASSERT(Ty->getKind() == TypeKind::Var);
@ -814,7 +831,7 @@ namespace bolt {
auto TupleTE = static_cast<TupleTypeExpression*>(N);
std::vector<Type*> ElementTypes;
for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE));
ElementTypes.push_back(inferTypeExpression(TE, IsPoly));
}
auto Ty = new TTuple(ElementTypes);
N->setType(Ty);
@ -824,7 +841,7 @@ namespace bolt {
case NodeKind::NestedTypeExpression:
{
auto NestedTE = static_cast<NestedTypeExpression*>(N);
auto Ty = inferTypeExpression(NestedTE->TE);
auto Ty = inferTypeExpression(NestedTE->TE, IsPoly);
N->setType(Ty);
return Ty;
}
@ -834,9 +851,9 @@ namespace bolt {
auto ArrowTE = static_cast<ArrowTypeExpression*>(N);
std::vector<Type*> ParamTypes;
for (auto ParamType: ArrowTE->ParamTypes) {
ParamTypes.push_back(inferTypeExpression(ParamType));
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
}
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType);
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
auto Ty = new TArrow(ParamTypes, ReturnType);
N->setType(Ty);
return Ty;
@ -848,7 +865,7 @@ namespace bolt {
for (auto [C, Comma]: QTE->Constraints) {
addConstraint(convertToConstraint(C));
}
auto Ty = inferTypeExpression(QTE->TE);
auto Ty = inferTypeExpression(QTE->TE, IsPoly);
N->setType(Ty);
return Ty;
}
@ -889,12 +906,13 @@ namespace bolt {
}
Ty = createTypeVar();
for (auto Case: Match->Cases) {
auto OldCtx = &getContext();
setContext(Case->Ctx);
auto PattTy = inferPattern(Case->Pattern);
addConstraint(new CEqual(PattTy, ValTy, Case));
auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
popContext();
setContext(OldCtx);
}
if (!Match->Value) {
Ty = new TArrow({ ValTy }, Ty);
@ -925,8 +943,8 @@ namespace bolt {
auto Ref = static_cast<ReferenceExpression*>(X);
ZEN_ASSERT(Ref->ModulePath.empty());
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
if (Target && llvm::isa<LetDeclaration>(Target)) {
auto Let = static_cast<LetDeclaration*>(Target);
if (Target && llvm::isa<FunctionDeclaration>(Target)) {
auto Let = static_cast<FunctionDeclaration*>(Target);
if (Let->IsCycleActive) {
return Let->Ty;
}
@ -1100,7 +1118,7 @@ namespace bolt {
std::stack<Node*> Stack;
void visitLetDeclaration(LetDeclaration* N) {
void visitFunctionDeclaration(FunctionDeclaration* N) {
RefGraph.addVertex(N);
Stack.push(N);
visitEachChild(N);
@ -1121,7 +1139,7 @@ namespace bolt {
RefGraph.addEdge(Stack.top(), Def->Parent);
return;
}
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration);
ZEN_ASSERT(Def->getKind() == NodeKind::FunctionDeclaration || Def->getKind() == NodeKind::VariableDeclaration);
if (!Stack.empty()) {
RefGraph.addEdge(Def, Stack.top());
}
@ -1140,7 +1158,7 @@ namespace bolt {
Checker& C;
void visitLetDeclaration(LetDeclaration* Decl) {
void visitLetDeclaration(FunctionDeclaration* Decl) {
// Only inspect those let-declarations that look like a function
if (Decl->Params.empty()) {
@ -1289,26 +1307,26 @@ namespace bolt {
auto TVs = new TVSet;
auto Constraints = new ConstraintSet;
for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N);
forwardDeclareLetDeclaration(Decl, TVs, Constraints);
auto Decl = static_cast<FunctionDeclaration*>(N);
forwardDeclareFunctionDeclaration(Decl, TVs, Constraints);
}
}
for (auto Nodes: SCCs) {
for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N);
auto Decl = static_cast<FunctionDeclaration*>(N);
Decl->IsCycleActive = true;
}
for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N);
inferLetDeclaration(Decl);
auto Decl = static_cast<FunctionDeclaration*>(N);
inferFunctionDeclaration(Decl);
}
for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N);
auto Decl = static_cast<FunctionDeclaration*>(N);
Decl->IsCycleActive = false;
}
}
setContext(SF->Ctx);
infer(SF);
popContext();
solve(new CMany(*SF->Ctx->Constraints));
checkTypeclassSigs(SF);
}

View file

@ -100,6 +100,8 @@ namespace bolt {
return "'pub'";
case NodeKind::LetKeyword:
return "'let'";
case NodeKind::FnKeyword:
return "'fn'";
case NodeKind::MutKeyword:
return "'mut'";
case NodeKind::MatchKeyword:
@ -108,8 +110,10 @@ namespace bolt {
return "'return'";
case NodeKind::TypeKeyword:
return "'type'";
case NodeKind::LetDeclaration:
return "a let-declaration";
case NodeKind::FunctionDeclaration:
return "a function declaration";
case NodeKind::VariableDeclaration:
return "a variable declaration";
case NodeKind::CallExpression:
return "a call-expression";
case NodeKind::InfixExpression:

View file

@ -853,7 +853,7 @@ finish:
return new IfStatement(Parts);
}
LetDeclaration* Parser::parseLetDeclaration() {
VariableDeclaration* Parser::parseVariableDeclaration() {
PubKeyword* Pub = nullptr;
LetKeyword* Let;
@ -881,8 +881,8 @@ finish:
Tokens.get();
}
auto Patt = parsePattern();
if (!Patt) {
auto P = parsePattern();
if (!P) {
if (Pub) {
Pub->unref();
}
@ -894,27 +894,7 @@ finish:
return nullptr;
}
std::vector<Parameter*> Params;
Token* T2;
for (;;) {
T2 = Tokens.peek();
switch (T2->getKind()) {
case NodeKind::LineFoldEnd:
case NodeKind::BlockStart:
case NodeKind::Equals:
case NodeKind::Colon:
goto after_params;
default:
auto P = parsePattern();
if (P == nullptr) {
P = new BindPattern(new Identifier("_"));
}
Params.push_back(new Parameter(P, nullptr));
}
}
after_params:
auto T2 = Tokens.peek();
if (T2->getKind() == NodeKind::Colon) {
Tokens.get();
auto TE = parseTypeExpression();
@ -972,16 +952,137 @@ after_params:
DE.add<UnexpectedTokenDiagnostic>(File, T2, Expected);
}
after_body:
checkLineFoldEnd();
finish:
return new VariableDeclaration { Pub, Let, Mut, P, TA, Body };
}
FunctionDeclaration* Parser::parseFunctionDeclaration() {
PubKeyword* Pub = nullptr;
FnKeyword* Fn;
MutKeyword* Mut = nullptr;
TypeAssert* TA = nullptr;
LetBody* Body = nullptr;
auto T0 = Tokens.get();
if (T0->getKind() == NodeKind::PubKeyword) {
Pub = static_cast<PubKeyword*>(T0);
T0 = Tokens.get();
}
if (T0->getKind() != NodeKind::FnKeyword) {
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::FnKeyword });
if (Pub) {
Pub->unref();
}
skipToLineFoldEnd();
return nullptr;
}
Fn = static_cast<FnKeyword*>(T0);
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::MutKeyword) {
Mut = static_cast<MutKeyword*>(T1);
Tokens.get();
}
auto Name = expectToken<Identifier>();
if (!Name) {
if (Pub) {
Pub->unref();
}
Fn->unref();
if (Mut) {
Mut->unref();
}
skipToLineFoldEnd();
return nullptr;
}
std::vector<Parameter*> Params;
Token* T2;
for (;;) {
T2 = Tokens.peek();
switch (T2->getKind()) {
case NodeKind::LineFoldEnd:
case NodeKind::BlockStart:
case NodeKind::Equals:
case NodeKind::Colon:
goto after_params;
default:
auto P = parsePattern();
if (!P) {
P = new BindPattern(new Identifier("_"));
}
Params.push_back(new Parameter(P, nullptr));
}
}
after_params:
if (T2->getKind() == NodeKind::Colon) {
Tokens.get();
auto TE = parseTypeExpression();
if (TE) {
TA = new TypeAssert(static_cast<Colon*>(T2), TE);
} else {
skipToLineFoldEnd();
goto finish;
}
T2 = Tokens.peek();
}
switch (T2->getKind()) {
case NodeKind::BlockStart:
{
Tokens.get();
std::vector<Node*> Elements;
for (;;) {
auto T3 = Tokens.peek();
if (T3->getKind() == NodeKind::BlockEnd) {
break;
}
auto Element = parseLetBodyElement();
if (Element) {
Elements.push_back(Element);
}
}
Tokens.get()->unref(); // Always a BlockEnd
Body = new LetBlockBody(static_cast<BlockStart*>(T2), Elements);
break;
}
case NodeKind::Equals:
{
Tokens.get();
auto E = parseExpression();
if (!E) {
skipToLineFoldEnd();
goto finish;
}
Body = new LetExprBody(static_cast<Equals*>(T2), E);
break;
}
case NodeKind::LineFoldEnd:
break;
default:
std::vector<NodeKind> Expected { NodeKind::BlockStart, NodeKind::LineFoldEnd, NodeKind::Equals };
if (TA == nullptr) {
// First tokens of TypeAssert
Expected.push_back(NodeKind::Colon);
// First tokens of Pattern
Expected.push_back(NodeKind::Identifier);
}
DE.add<UnexpectedTokenDiagnostic>(File, T2, Expected);
}
checkLineFoldEnd();
finish:
return new LetDeclaration(
return new FunctionDeclaration(
Pub,
Let,
Mut,
Patt,
Fn,
Name,
Params,
TA,
Body
@ -992,7 +1093,9 @@ finish:
auto T0 = peekFirstTokenAfterModifiers();
switch (T0->getKind()) {
case NodeKind::LetKeyword:
return parseLetDeclaration();
return parseVariableDeclaration();
case NodeKind::FnKeyword:
return parseFunctionDeclaration();
case NodeKind::ReturnKeyword:
return parseReturnStatement();
case NodeKind::IfKeyword:
@ -1396,12 +1499,12 @@ next_member:
Node* Parser::parseClassElement() {
auto T0 = Tokens.peek();
switch (T0->getKind()) {
case NodeKind::LetKeyword:
return parseLetDeclaration();
case NodeKind::FnKeyword:
return parseFunctionDeclaration();
case NodeKind::TypeKeyword:
// TODO
default:
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector<NodeKind> { NodeKind::LetKeyword, NodeKind::TypeKeyword });
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector<NodeKind> { NodeKind::FnKeyword, NodeKind::TypeKeyword });
skipToLineFoldEnd();
return nullptr;
}
@ -1411,7 +1514,9 @@ next_member:
auto T0 = peekFirstTokenAfterModifiers();
switch (T0->getKind()) {
case NodeKind::LetKeyword:
return parseLetDeclaration();
return parseVariableDeclaration();
case NodeKind::FnKeyword:
return parseFunctionDeclaration();
case NodeKind::IfKeyword:
return parseIfStatement();
case NodeKind::ClassKeyword:

View file

@ -62,6 +62,7 @@ namespace bolt {
std::unordered_map<ByteString, NodeKind> Keywords = {
{ "pub", NodeKind::PubKeyword },
{ "let", NodeKind::LetKeyword },
{ "fn", NodeKind::FnKeyword },
{ "mut", NodeKind::MutKeyword },
{ "return", NodeKind::ReturnKeyword },
{ "type", NodeKind::TypeKeyword },
@ -226,6 +227,8 @@ digit_finish:
return new PubKeyword(StartLoc);
case NodeKind::LetKeyword:
return new LetKeyword(StartLoc);
case NodeKind::FnKeyword:
return new FnKeyword(StartLoc);
case NodeKind::MutKeyword:
return new MutKeyword(StartLoc);
case NodeKind::TypeKeyword: