Split let-declaration up into function/variable declarations

This commit is contained in:
Sam Vervaeck 2023-05-30 21:34:40 +02:00
parent 053d45868e
commit 87af4686b7
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
11 changed files with 545 additions and 551 deletions

View file

@ -103,6 +103,7 @@ namespace bolt {
RArrow, RArrow,
RArrowAlt, RArrowAlt,
LetKeyword, LetKeyword,
FnKeyword,
MutKeyword, MutKeyword,
PubKeyword, PubKeyword,
TypeKeyword, TypeKeyword,
@ -160,7 +161,8 @@ namespace bolt {
Parameter, Parameter,
LetBlockBody, LetBlockBody,
LetExprBody, LetExprBody,
LetDeclaration, FunctionDeclaration,
VariableDeclaration,
RecordDeclarationField, RecordDeclarationField,
RecordDeclaration, RecordDeclaration,
VariantDeclaration, VariantDeclaration,
@ -549,6 +551,20 @@ namespace bolt {
}; };
class FnKeyword : public Token {
public:
inline FnKeyword(TextLoc StartLoc):
Token(NodeKind::FnKeyword, StartLoc) {}
std::string getText() const override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::FnKeyword;
}
};
class MutKeyword : public Token { class MutKeyword : public Token {
public: public:
@ -1661,7 +1677,7 @@ namespace bolt {
}; };
class LetDeclaration : public Node { class FunctionDeclaration : public Node {
Scope* TheScope = nullptr; Scope* TheScope = nullptr;
@ -1672,26 +1688,23 @@ namespace bolt {
class Type* Ty; class Type* Ty;
class PubKeyword* PubKeyword; class PubKeyword* PubKeyword;
class LetKeyword* LetKeyword; class FnKeyword* FnKeyword;
class MutKeyword* MutKeyword; class Identifier* Name;
class Pattern* Pattern;
std::vector<Parameter*> Params; std::vector<Parameter*> Params;
class TypeAssert* TypeAssert; class TypeAssert* TypeAssert;
LetBody* Body; LetBody* Body;
LetDeclaration( FunctionDeclaration(
class PubKeyword* PubKeyword, class PubKeyword* PubKeyword,
class LetKeyword* LetKeywod, class FnKeyword* FnKeyword,
class MutKeyword* MutKeyword, class Identifier* Name,
class Pattern* Pattern,
std::vector<Parameter*> Params, std::vector<Parameter*> Params,
class TypeAssert* TypeAssert, class TypeAssert* TypeAssert,
LetBody* Body LetBody* Body
): Node(NodeKind::LetDeclaration), ): Node(NodeKind::FunctionDeclaration),
PubKeyword(PubKeyword), PubKeyword(PubKeyword),
LetKeyword(LetKeywod), FnKeyword(FnKeyword),
MutKeyword(MutKeyword), Name(Name),
Pattern(Pattern),
Params(Params), Params(Params),
TypeAssert(TypeAssert), TypeAssert(TypeAssert),
Body(Body) {} Body(Body) {}
@ -1703,14 +1716,6 @@ namespace bolt {
return TheScope; return TheScope;
} }
bool isFunc() const noexcept {
return !Params.empty();
}
bool isVar() const noexcept {
return !isFunc();
}
bool isInstance() const noexcept { bool isInstance() const noexcept {
return Parent->getKind() == NodeKind::InstanceDeclaration; return Parent->getKind() == NodeKind::InstanceDeclaration;
} }
@ -1723,11 +1728,47 @@ namespace bolt {
Token* getLastToken() const override; Token* getLastToken() const override;
static bool classof(const Node* N) { static bool classof(const Node* N) {
return N->getKind() == NodeKind::LetDeclaration; return N->getKind() == NodeKind::FunctionDeclaration;
} }
}; };
class VariableDeclaration : public TypedNode {
Scope* TheScope = nullptr;
public:
bool IsCycleActive = false;
class PubKeyword* PubKeyword;
class LetKeyword* LetKeyword;
class MutKeyword* MutKeyword;
class Pattern* Pattern;
std::vector<Parameter*> Params;
class TypeAssert* TypeAssert;
LetBody* Body;
VariableDeclaration(
class PubKeyword* PubKeyword,
class LetKeyword* LetKeyword,
class MutKeyword* MutKeyword,
class Pattern* Pattern,
class TypeAssert* TypeAssert,
LetBody* Body
): TypedNode(NodeKind::VariableDeclaration),
PubKeyword(PubKeyword),
LetKeyword(LetKeyword),
MutKeyword(MutKeyword),
Pattern(Pattern),
TypeAssert(TypeAssert),
Body(Body) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class InstanceDeclaration : public Node { class InstanceDeclaration : public Node {
public: public:
@ -1977,6 +2018,7 @@ namespace bolt {
template<> inline NodeKind getNodeType<RArrow>() { return NodeKind::RArrow; } template<> inline NodeKind getNodeType<RArrow>() { return NodeKind::RArrow; }
template<> inline NodeKind getNodeType<RArrowAlt>() { return NodeKind::RArrowAlt; } template<> inline NodeKind getNodeType<RArrowAlt>() { return NodeKind::RArrowAlt; }
template<> inline NodeKind getNodeType<LetKeyword>() { return NodeKind::LetKeyword; } template<> inline NodeKind getNodeType<LetKeyword>() { return NodeKind::LetKeyword; }
template<> inline NodeKind getNodeType<FnKeyword>() { return NodeKind::FnKeyword; }
template<> inline NodeKind getNodeType<MutKeyword>() { return NodeKind::MutKeyword; } template<> inline NodeKind getNodeType<MutKeyword>() { return NodeKind::MutKeyword; }
template<> inline NodeKind getNodeType<PubKeyword>() { return NodeKind::PubKeyword; } template<> inline NodeKind getNodeType<PubKeyword>() { return NodeKind::PubKeyword; }
template<> inline NodeKind getNodeType<TypeKeyword>() { return NodeKind::TypeKeyword; } template<> inline NodeKind getNodeType<TypeKeyword>() { return NodeKind::TypeKeyword; }
@ -2019,7 +2061,8 @@ namespace bolt {
template<> inline NodeKind getNodeType<Parameter>() { return NodeKind::Parameter; } template<> inline NodeKind getNodeType<Parameter>() { return NodeKind::Parameter; }
template<> inline NodeKind getNodeType<LetBlockBody>() { return NodeKind::LetBlockBody; } template<> inline NodeKind getNodeType<LetBlockBody>() { return NodeKind::LetBlockBody; }
template<> inline NodeKind getNodeType<LetExprBody>() { return NodeKind::LetExprBody; } template<> inline NodeKind getNodeType<LetExprBody>() { return NodeKind::LetExprBody; }
template<> inline NodeKind getNodeType<LetDeclaration>() { return NodeKind::LetDeclaration; } template<> inline NodeKind getNodeType<FunctionDeclaration>() { return NodeKind::FunctionDeclaration; }
template<> inline NodeKind getNodeType<VariableDeclaration>() { return NodeKind::VariableDeclaration; }
template<> inline NodeKind getNodeType<RecordDeclarationField>() { return NodeKind::RecordDeclarationField; } template<> inline NodeKind getNodeType<RecordDeclarationField>() { return NodeKind::RecordDeclarationField; }
template<> inline NodeKind getNodeType<RecordDeclaration>() { return NodeKind::RecordDeclaration; } template<> inline NodeKind getNodeType<RecordDeclaration>() { return NodeKind::RecordDeclaration; }
template<> inline NodeKind getNodeType<ClassDeclaration>() { return NodeKind::ClassDeclaration; } template<> inline NodeKind getNodeType<ClassDeclaration>() { return NodeKind::ClassDeclaration; }

View file

@ -12,170 +12,96 @@ namespace bolt {
public: public:
void visit(Node* N) { void visit(Node* N) {
#define BOLT_GEN_CASE(name) \
case NodeKind::name: \
return static_cast<D*>(this)->visit ## name(static_cast<name*>(N));
switch (N->getKind()) { switch (N->getKind()) {
case NodeKind::Equals: BOLT_GEN_CASE(Equals)
return static_cast<D*>(this)->visitEquals(static_cast<Equals*>(N)); BOLT_GEN_CASE(Colon)
case NodeKind::Colon: BOLT_GEN_CASE(Comma)
return static_cast<D*>(this)->visitColon(static_cast<Colon*>(N)); BOLT_GEN_CASE(Dot)
case NodeKind::Comma: BOLT_GEN_CASE(DotDot)
return static_cast<D*>(this)->visitComma(static_cast<Comma*>(N)); BOLT_GEN_CASE(Tilde)
case NodeKind::Dot: BOLT_GEN_CASE(LParen)
return static_cast<D*>(this)->visitDot(static_cast<Dot*>(N)); BOLT_GEN_CASE(RParen)
case NodeKind::DotDot: BOLT_GEN_CASE(LBracket)
return static_cast<D*>(this)->visitDotDot(static_cast<DotDot*>(N)); BOLT_GEN_CASE(RBracket)
case NodeKind::Tilde: BOLT_GEN_CASE(LBrace)
return static_cast<D*>(this)->visitTilde(static_cast<Tilde*>(N)); BOLT_GEN_CASE(RBrace)
case NodeKind::LParen: BOLT_GEN_CASE(RArrow)
return static_cast<D*>(this)->visitLParen(static_cast<LParen*>(N)); BOLT_GEN_CASE(RArrowAlt)
case NodeKind::RParen: BOLT_GEN_CASE(LetKeyword)
return static_cast<D*>(this)->visitRParen(static_cast<RParen*>(N)); BOLT_GEN_CASE(FnKeyword)
case NodeKind::LBracket: BOLT_GEN_CASE(MutKeyword)
return static_cast<D*>(this)->visitLBracket(static_cast<LBracket*>(N)); BOLT_GEN_CASE(PubKeyword)
case NodeKind::RBracket: BOLT_GEN_CASE(TypeKeyword)
return static_cast<D*>(this)->visitRBracket(static_cast<RBracket*>(N)); BOLT_GEN_CASE(ReturnKeyword)
case NodeKind::LBrace: BOLT_GEN_CASE(ModKeyword)
return static_cast<D*>(this)->visitLBrace(static_cast<LBrace*>(N)); BOLT_GEN_CASE(StructKeyword)
case NodeKind::RBrace: BOLT_GEN_CASE(EnumKeyword)
return static_cast<D*>(this)->visitRBrace(static_cast<RBrace*>(N)); BOLT_GEN_CASE(ClassKeyword)
case NodeKind::RArrow: BOLT_GEN_CASE(InstanceKeyword)
return static_cast<D*>(this)->visitRArrow(static_cast<RArrow*>(N)); BOLT_GEN_CASE(ElifKeyword)
case NodeKind::RArrowAlt: BOLT_GEN_CASE(IfKeyword)
return static_cast<D*>(this)->visitRArrowAlt(static_cast<RArrowAlt*>(N)); BOLT_GEN_CASE(ElseKeyword)
case NodeKind::LetKeyword: BOLT_GEN_CASE(MatchKeyword)
return static_cast<D*>(this)->visitLetKeyword(static_cast<LetKeyword*>(N)); BOLT_GEN_CASE(Invalid)
case NodeKind::MutKeyword: BOLT_GEN_CASE(EndOfFile)
return static_cast<D*>(this)->visitMutKeyword(static_cast<MutKeyword*>(N)); BOLT_GEN_CASE(BlockStart)
case NodeKind::PubKeyword: BOLT_GEN_CASE(BlockEnd)
return static_cast<D*>(this)->visitPubKeyword(static_cast<PubKeyword*>(N)); BOLT_GEN_CASE(LineFoldEnd)
case NodeKind::TypeKeyword: BOLT_GEN_CASE(CustomOperator)
return static_cast<D*>(this)->visitTypeKeyword(static_cast<TypeKeyword*>(N)); BOLT_GEN_CASE(Assignment)
case NodeKind::ReturnKeyword: BOLT_GEN_CASE(Identifier)
return static_cast<D*>(this)->visitReturnKeyword(static_cast<ReturnKeyword*>(N)); BOLT_GEN_CASE(IdentifierAlt)
case NodeKind::ModKeyword: BOLT_GEN_CASE(StringLiteral)
return static_cast<D*>(this)->visitModKeyword(static_cast<ModKeyword*>(N)); BOLT_GEN_CASE(IntegerLiteral)
case NodeKind::StructKeyword: BOLT_GEN_CASE(TypeclassConstraintExpression)
return static_cast<D*>(this)->visitStructKeyword(static_cast<StructKeyword*>(N)); BOLT_GEN_CASE(EqualityConstraintExpression)
case NodeKind::EnumKeyword: BOLT_GEN_CASE(QualifiedTypeExpression)
return static_cast<D*>(this)->visitEnumKeyword(static_cast<EnumKeyword*>(N)); BOLT_GEN_CASE(ReferenceTypeExpression)
case NodeKind::ClassKeyword: BOLT_GEN_CASE(ArrowTypeExpression)
return static_cast<D*>(this)->visitClassKeyword(static_cast<ClassKeyword*>(N)); BOLT_GEN_CASE(AppTypeExpression)
case NodeKind::InstanceKeyword: BOLT_GEN_CASE(VarTypeExpression)
return static_cast<D*>(this)->visitInstanceKeyword(static_cast<InstanceKeyword*>(N)); BOLT_GEN_CASE(NestedTypeExpression)
case NodeKind::ElifKeyword: BOLT_GEN_CASE(TupleTypeExpression)
return static_cast<D*>(this)->visitElifKeyword(static_cast<ElifKeyword*>(N)); BOLT_GEN_CASE(BindPattern)
case NodeKind::IfKeyword: BOLT_GEN_CASE(LiteralPattern)
return static_cast<D*>(this)->visitIfKeyword(static_cast<IfKeyword*>(N)); BOLT_GEN_CASE(NamedPattern)
case NodeKind::ElseKeyword: BOLT_GEN_CASE(NestedPattern)
return static_cast<D*>(this)->visitElseKeyword(static_cast<ElseKeyword*>(N)); BOLT_GEN_CASE(ReferenceExpression)
case NodeKind::MatchKeyword: BOLT_GEN_CASE(MatchCase)
return static_cast<D*>(this)->visitMatchKeyword(static_cast<MatchKeyword*>(N)); BOLT_GEN_CASE(MatchExpression)
case NodeKind::Invalid: BOLT_GEN_CASE(MemberExpression)
return static_cast<D*>(this)->visitInvalid(static_cast<Invalid*>(N)); BOLT_GEN_CASE(TupleExpression)
case NodeKind::EndOfFile: BOLT_GEN_CASE(NestedExpression)
return static_cast<D*>(this)->visitEndOfFile(static_cast<EndOfFile*>(N)); BOLT_GEN_CASE(ConstantExpression)
case NodeKind::BlockStart: BOLT_GEN_CASE(CallExpression)
return static_cast<D*>(this)->visitBlockStart(static_cast<BlockStart*>(N)); BOLT_GEN_CASE(InfixExpression)
case NodeKind::BlockEnd: BOLT_GEN_CASE(PrefixExpression)
return static_cast<D*>(this)->visitBlockEnd(static_cast<BlockEnd*>(N)); BOLT_GEN_CASE(RecordExpressionField)
case NodeKind::LineFoldEnd: BOLT_GEN_CASE(RecordExpression)
return static_cast<D*>(this)->visitLineFoldEnd(static_cast<LineFoldEnd*>(N)); BOLT_GEN_CASE(ExpressionStatement)
case NodeKind::CustomOperator: BOLT_GEN_CASE(ReturnStatement)
return static_cast<D*>(this)->visitCustomOperator(static_cast<CustomOperator*>(N)); BOLT_GEN_CASE(IfStatement)
case NodeKind::Assignment: BOLT_GEN_CASE(IfStatementPart)
return static_cast<D*>(this)->visitAssignment(static_cast<Assignment*>(N)); BOLT_GEN_CASE(TypeAssert)
case NodeKind::Identifier: BOLT_GEN_CASE(Parameter)
return static_cast<D*>(this)->visitIdentifier(static_cast<Identifier*>(N)); BOLT_GEN_CASE(LetBlockBody)
case NodeKind::IdentifierAlt: BOLT_GEN_CASE(LetExprBody)
return static_cast<D*>(this)->visitIdentifierAlt(static_cast<IdentifierAlt*>(N)); BOLT_GEN_CASE(FunctionDeclaration)
case NodeKind::StringLiteral: BOLT_GEN_CASE(VariableDeclaration)
return static_cast<D*>(this)->visitStringLiteral(static_cast<StringLiteral*>(N)); BOLT_GEN_CASE(RecordDeclaration)
case NodeKind::IntegerLiteral: BOLT_GEN_CASE(RecordDeclarationField)
return static_cast<D*>(this)->visitIntegerLiteral(static_cast<IntegerLiteral*>(N)); BOLT_GEN_CASE(VariantDeclaration)
case NodeKind::TypeclassConstraintExpression: BOLT_GEN_CASE(TupleVariantDeclarationMember)
return static_cast<D*>(this)->visitTypeclassConstraintExpression(static_cast<TypeclassConstraintExpression*>(N)); BOLT_GEN_CASE(RecordVariantDeclarationMember)
case NodeKind::EqualityConstraintExpression: BOLT_GEN_CASE(ClassDeclaration)
return static_cast<D*>(this)->visitEqualityConstraintExpression(static_cast<EqualityConstraintExpression*>(N)); BOLT_GEN_CASE(InstanceDeclaration)
case NodeKind::QualifiedTypeExpression: BOLT_GEN_CASE(SourceFile)
return static_cast<D*>(this)->visitQualifiedTypeExpression(static_cast<QualifiedTypeExpression*>(N)); }
case NodeKind::ReferenceTypeExpression:
return static_cast<D*>(this)->visitReferenceTypeExpression(static_cast<ReferenceTypeExpression*>(N));
case NodeKind::ArrowTypeExpression:
return static_cast<D*>(this)->visitArrowTypeExpression(static_cast<ArrowTypeExpression*>(N));
case NodeKind::AppTypeExpression:
return static_cast<D*>(this)->visitAppTypeExpression(static_cast<AppTypeExpression*>(N));
case NodeKind::VarTypeExpression:
return static_cast<D*>(this)->visitVarTypeExpression(static_cast<VarTypeExpression*>(N));
case NodeKind::NestedTypeExpression:
return static_cast<D*>(this)->visitNestedTypeExpression(static_cast<NestedTypeExpression*>(N));
case NodeKind::TupleTypeExpression:
return static_cast<D*>(this)->visitTupleTypeExpression(static_cast<TupleTypeExpression*>(N));
case NodeKind::BindPattern:
return static_cast<D*>(this)->visitBindPattern(static_cast<BindPattern*>(N));
case NodeKind::LiteralPattern:
return static_cast<D*>(this)->visitLiteralPattern(static_cast<LiteralPattern*>(N));
case NodeKind::NamedPattern:
return static_cast<D*>(this)->visitNamedPattern(static_cast<NamedPattern*>(N));
case NodeKind::NestedPattern:
return static_cast<D*>(this)->visitNestedPattern(static_cast<NestedPattern*>(N));
case NodeKind::ReferenceExpression:
return static_cast<D*>(this)->visitReferenceExpression(static_cast<ReferenceExpression*>(N));
case NodeKind::MatchCase:
return static_cast<D*>(this)->visitMatchCase(static_cast<MatchCase*>(N));
case NodeKind::MatchExpression:
return static_cast<D*>(this)->visitMatchExpression(static_cast<MatchExpression*>(N));
case NodeKind::MemberExpression:
return static_cast<D*>(this)->visitMemberExpression(static_cast<MemberExpression*>(N));
case NodeKind::TupleExpression:
return static_cast<D*>(this)->visitTupleExpression(static_cast<TupleExpression*>(N));
case NodeKind::NestedExpression:
return static_cast<D*>(this)->visitNestedExpression(static_cast<NestedExpression*>(N));
case NodeKind::ConstantExpression:
return static_cast<D*>(this)->visitConstantExpression(static_cast<ConstantExpression*>(N));
case NodeKind::CallExpression:
return static_cast<D*>(this)->visitCallExpression(static_cast<CallExpression*>(N));
case NodeKind::InfixExpression:
return static_cast<D*>(this)->visitInfixExpression(static_cast<InfixExpression*>(N));
case NodeKind::PrefixExpression:
return static_cast<D*>(this)->visitPrefixExpression(static_cast<PrefixExpression*>(N));
case NodeKind::RecordExpressionField:
return static_cast<D*>(this)->visitRecordExpressionField(static_cast<RecordExpressionField*>(N));
case NodeKind::RecordExpression:
return static_cast<D*>(this)->visitRecordExpression(static_cast<RecordExpression*>(N));
case NodeKind::ExpressionStatement:
return static_cast<D*>(this)->visitExpressionStatement(static_cast<ExpressionStatement*>(N));
case NodeKind::ReturnStatement:
return static_cast<D*>(this)->visitReturnStatement(static_cast<ReturnStatement*>(N));
case NodeKind::IfStatement:
return static_cast<D*>(this)->visitIfStatement(static_cast<IfStatement*>(N));
case NodeKind::IfStatementPart:
return static_cast<D*>(this)->visitIfStatementPart(static_cast<IfStatementPart*>(N));
case NodeKind::TypeAssert:
return static_cast<D*>(this)->visitTypeAssert(static_cast<TypeAssert*>(N));
case NodeKind::Parameter:
return static_cast<D*>(this)->visitParameter(static_cast<Parameter*>(N));
case NodeKind::LetBlockBody:
return static_cast<D*>(this)->visitLetBlockBody(static_cast<LetBlockBody*>(N));
case NodeKind::LetExprBody:
return static_cast<D*>(this)->visitLetExprBody(static_cast<LetExprBody*>(N));
case NodeKind::LetDeclaration:
return static_cast<D*>(this)->visitLetDeclaration(static_cast<LetDeclaration*>(N));
case NodeKind::RecordDeclarationField:
return static_cast<D*>(this)->visitRecordDeclarationField(static_cast<RecordDeclarationField*>(N));
case NodeKind::RecordDeclaration:
return static_cast<D*>(this)->visitRecordDeclaration(static_cast<RecordDeclaration*>(N));
case NodeKind::VariantDeclaration:
return static_cast<D*>(this)->visitVariantDeclaration(static_cast<VariantDeclaration*>(N));
case NodeKind::TupleVariantDeclarationMember:
return static_cast<D*>(this)->visitTupleVariantDeclarationMember(static_cast<TupleVariantDeclarationMember*>(N));
case NodeKind::RecordVariantDeclarationMember:
return static_cast<D*>(this)->visitRecordVariantDeclarationMember(static_cast<RecordVariantDeclarationMember*>(N));
case NodeKind::ClassDeclaration:
return static_cast<D*>(this)->visitClassDeclaration(static_cast<ClassDeclaration*>(N));
case NodeKind::InstanceDeclaration:
return static_cast<D*>(this)->visitInstanceDeclaration(static_cast<InstanceDeclaration*>(N));
case NodeKind::SourceFile:
return static_cast<D*>(this)->visitSourceFile(static_cast<SourceFile*>(N));
}
} }
protected: protected:
@ -248,6 +174,10 @@ namespace bolt {
visitToken(N); visitToken(N);
} }
void visitFnKeyword(FnKeyword* N) {
visitToken(N);
}
void visitMutKeyword(MutKeyword* N) { void visitMutKeyword(MutKeyword* N) {
visitToken(N); visitToken(N);
} }
@ -500,7 +430,11 @@ namespace bolt {
visitLetBody(N); visitLetBody(N);
} }
void visitLetDeclaration(LetDeclaration* N) { void visitFunctionDeclaration(FunctionDeclaration* N) {
visitNode(N);
}
void visitVariableDeclaration(VariableDeclaration* N) {
visitNode(N); visitNode(N);
} }
@ -543,252 +477,96 @@ namespace bolt {
public: public:
void visitEachChild(Node* N) { void visitEachChild(Node* N) {
#define BOLT_GEN_CHILD_CASE(name) \
case NodeKind::name: \
visitEachChild(static_cast<name*>(N)); \
break;
switch (N->getKind()) { switch (N->getKind()) {
case NodeKind::Equals: BOLT_GEN_CHILD_CASE(Equals)
visitEachChild(static_cast<Equals*>(N)); BOLT_GEN_CHILD_CASE(Colon)
break; BOLT_GEN_CHILD_CASE(Comma)
case NodeKind::Colon: BOLT_GEN_CHILD_CASE(Dot)
visitEachChild(static_cast<Colon*>(N)); BOLT_GEN_CHILD_CASE(DotDot)
break; BOLT_GEN_CHILD_CASE(Tilde)
case NodeKind::Comma: BOLT_GEN_CHILD_CASE(LParen)
visitEachChild(static_cast<Comma*>(N)); BOLT_GEN_CHILD_CASE(RParen)
break; BOLT_GEN_CHILD_CASE(LBracket)
case NodeKind::Dot: BOLT_GEN_CHILD_CASE(RBracket)
visitEachChild(static_cast<Dot*>(N)); BOLT_GEN_CHILD_CASE(LBrace)
break; BOLT_GEN_CHILD_CASE(RBrace)
case NodeKind::DotDot: BOLT_GEN_CHILD_CASE(RArrow)
visitEachChild(static_cast<DotDot*>(N)); BOLT_GEN_CHILD_CASE(RArrowAlt)
break; BOLT_GEN_CHILD_CASE(LetKeyword)
case NodeKind::Tilde: BOLT_GEN_CHILD_CASE(FnKeyword)
visitEachChild(static_cast<Tilde*>(N)); BOLT_GEN_CHILD_CASE(MutKeyword)
break; BOLT_GEN_CHILD_CASE(PubKeyword)
case NodeKind::LParen: BOLT_GEN_CHILD_CASE(TypeKeyword)
visitEachChild(static_cast<LParen*>(N)); BOLT_GEN_CHILD_CASE(ReturnKeyword)
break; BOLT_GEN_CHILD_CASE(ModKeyword)
case NodeKind::RParen: BOLT_GEN_CHILD_CASE(StructKeyword)
visitEachChild(static_cast<RParen*>(N)); BOLT_GEN_CHILD_CASE(EnumKeyword)
break; BOLT_GEN_CHILD_CASE(ClassKeyword)
case NodeKind::LBracket: BOLT_GEN_CHILD_CASE(InstanceKeyword)
visitEachChild(static_cast<LBracket*>(N)); BOLT_GEN_CHILD_CASE(ElifKeyword)
break; BOLT_GEN_CHILD_CASE(IfKeyword)
case NodeKind::RBracket: BOLT_GEN_CHILD_CASE(ElseKeyword)
visitEachChild(static_cast<RBracket*>(N)); BOLT_GEN_CHILD_CASE(MatchKeyword)
break; BOLT_GEN_CHILD_CASE(Invalid)
case NodeKind::LBrace: BOLT_GEN_CHILD_CASE(EndOfFile)
visitEachChild(static_cast<LBrace*>(N)); BOLT_GEN_CHILD_CASE(BlockStart)
break; BOLT_GEN_CHILD_CASE(BlockEnd)
case NodeKind::RBrace: BOLT_GEN_CHILD_CASE(LineFoldEnd)
visitEachChild(static_cast<RBrace*>(N)); BOLT_GEN_CHILD_CASE(CustomOperator)
break; BOLT_GEN_CHILD_CASE(Assignment)
case NodeKind::RArrow: BOLT_GEN_CHILD_CASE(Identifier)
visitEachChild(static_cast<RArrow*>(N)); BOLT_GEN_CHILD_CASE(IdentifierAlt)
break; BOLT_GEN_CHILD_CASE(StringLiteral)
case NodeKind::RArrowAlt: BOLT_GEN_CHILD_CASE(IntegerLiteral)
visitEachChild(static_cast<RArrowAlt*>(N)); BOLT_GEN_CHILD_CASE(TypeclassConstraintExpression)
break; BOLT_GEN_CHILD_CASE(EqualityConstraintExpression)
case NodeKind::LetKeyword: BOLT_GEN_CHILD_CASE(QualifiedTypeExpression)
visitEachChild(static_cast<LetKeyword*>(N)); BOLT_GEN_CHILD_CASE(ReferenceTypeExpression)
break; BOLT_GEN_CHILD_CASE(ArrowTypeExpression)
case NodeKind::MutKeyword: BOLT_GEN_CHILD_CASE(AppTypeExpression)
visitEachChild(static_cast<MutKeyword*>(N)); BOLT_GEN_CHILD_CASE(VarTypeExpression)
break; BOLT_GEN_CHILD_CASE(NestedTypeExpression)
case NodeKind::PubKeyword: BOLT_GEN_CHILD_CASE(TupleTypeExpression)
visitEachChild(static_cast<PubKeyword*>(N)); BOLT_GEN_CHILD_CASE(BindPattern)
break; BOLT_GEN_CHILD_CASE(LiteralPattern)
case NodeKind::TypeKeyword: BOLT_GEN_CHILD_CASE(NamedPattern)
visitEachChild(static_cast<TypeKeyword*>(N)); BOLT_GEN_CHILD_CASE(NestedPattern)
break; BOLT_GEN_CHILD_CASE(ReferenceExpression)
case NodeKind::ReturnKeyword: BOLT_GEN_CHILD_CASE(MatchCase)
visitEachChild(static_cast<ReturnKeyword*>(N)); BOLT_GEN_CHILD_CASE(MatchExpression)
break; BOLT_GEN_CHILD_CASE(MemberExpression)
case NodeKind::ModKeyword: BOLT_GEN_CHILD_CASE(TupleExpression)
visitEachChild(static_cast<ModKeyword*>(N)); BOLT_GEN_CHILD_CASE(NestedExpression)
break; BOLT_GEN_CHILD_CASE(ConstantExpression)
case NodeKind::StructKeyword: BOLT_GEN_CHILD_CASE(CallExpression)
visitEachChild(static_cast<StructKeyword*>(N)); BOLT_GEN_CHILD_CASE(InfixExpression)
break; BOLT_GEN_CHILD_CASE(PrefixExpression)
case NodeKind::EnumKeyword: BOLT_GEN_CHILD_CASE(RecordExpressionField)
visitEachChild(static_cast<EnumKeyword*>(N)); BOLT_GEN_CHILD_CASE(RecordExpression)
break; BOLT_GEN_CHILD_CASE(ExpressionStatement)
case NodeKind::ClassKeyword: BOLT_GEN_CHILD_CASE(ReturnStatement)
visitEachChild(static_cast<ClassKeyword*>(N)); BOLT_GEN_CHILD_CASE(IfStatement)
break; BOLT_GEN_CHILD_CASE(IfStatementPart)
case NodeKind::InstanceKeyword: BOLT_GEN_CHILD_CASE(TypeAssert)
visitEachChild(static_cast<InstanceKeyword*>(N)); BOLT_GEN_CHILD_CASE(Parameter)
break; BOLT_GEN_CHILD_CASE(LetBlockBody)
case NodeKind::ElifKeyword: BOLT_GEN_CHILD_CASE(LetExprBody)
visitEachChild(static_cast<ElifKeyword*>(N)); BOLT_GEN_CHILD_CASE(FunctionDeclaration)
break; BOLT_GEN_CHILD_CASE(VariableDeclaration)
case NodeKind::IfKeyword: BOLT_GEN_CHILD_CASE(RecordDeclaration)
visitEachChild(static_cast<IfKeyword*>(N)); BOLT_GEN_CHILD_CASE(RecordDeclarationField)
break; BOLT_GEN_CHILD_CASE(VariantDeclaration)
case NodeKind::ElseKeyword: BOLT_GEN_CHILD_CASE(TupleVariantDeclarationMember)
visitEachChild(static_cast<ElseKeyword*>(N)); BOLT_GEN_CHILD_CASE(RecordVariantDeclarationMember)
break; BOLT_GEN_CHILD_CASE(ClassDeclaration)
case NodeKind::MatchKeyword: BOLT_GEN_CHILD_CASE(InstanceDeclaration)
visitEachChild(static_cast<MatchKeyword*>(N)); BOLT_GEN_CHILD_CASE(SourceFile)
break;
case NodeKind::Invalid:
visitEachChild(static_cast<Invalid*>(N));
break;
case NodeKind::EndOfFile:
visitEachChild(static_cast<EndOfFile*>(N));
break;
case NodeKind::BlockStart:
visitEachChild(static_cast<BlockStart*>(N));
break;
case NodeKind::BlockEnd:
visitEachChild(static_cast<BlockEnd*>(N));
break;
case NodeKind::LineFoldEnd:
visitEachChild(static_cast<LineFoldEnd*>(N));
break;
case NodeKind::CustomOperator:
visitEachChild(static_cast<CustomOperator*>(N));
break;
case NodeKind::Assignment:
visitEachChild(static_cast<Assignment*>(N));
break;
case NodeKind::Identifier:
visitEachChild(static_cast<Identifier*>(N));
break;
case NodeKind::IdentifierAlt:
visitEachChild(static_cast<IdentifierAlt*>(N));
break;
case NodeKind::StringLiteral:
visitEachChild(static_cast<StringLiteral*>(N));
break;
case NodeKind::IntegerLiteral:
visitEachChild(static_cast<IntegerLiteral*>(N));
break;
case NodeKind::TypeclassConstraintExpression:
visitEachChild(static_cast<TypeclassConstraintExpression*>(N));
break;
case NodeKind::EqualityConstraintExpression:
visitEachChild(static_cast<EqualityConstraintExpression*>(N));
break;
case NodeKind::QualifiedTypeExpression:
visitEachChild(static_cast<QualifiedTypeExpression*>(N));
break;
case NodeKind::ReferenceTypeExpression:
visitEachChild(static_cast<ReferenceTypeExpression*>(N));
break;
case NodeKind::ArrowTypeExpression:
visitEachChild(static_cast<ArrowTypeExpression*>(N));
break;
case NodeKind::AppTypeExpression:
visitEachChild(static_cast<AppTypeExpression*>(N));
break;
case NodeKind::VarTypeExpression:
visitEachChild(static_cast<VarTypeExpression*>(N));
break;
case NodeKind::NestedTypeExpression:
visitEachChild(static_cast<NestedTypeExpression*>(N));
break;
case NodeKind::TupleTypeExpression:
visitEachChild(static_cast<TupleTypeExpression*>(N));
break;
case NodeKind::BindPattern:
visitEachChild(static_cast<BindPattern*>(N));
break;
case NodeKind::LiteralPattern:
visitEachChild(static_cast<LiteralPattern*>(N));
break;
case NodeKind::NamedPattern:
visitEachChild(static_cast<NamedPattern*>(N));
break;
case NodeKind::NestedPattern:
visitEachChild(static_cast<NestedPattern*>(N));
break;
case NodeKind::ReferenceExpression:
visitEachChild(static_cast<ReferenceExpression*>(N));
break;
case NodeKind::MatchCase:
visitEachChild(static_cast<MatchCase*>(N));
break;
case NodeKind::MatchExpression:
visitEachChild(static_cast<MatchExpression*>(N));
break;
case NodeKind::MemberExpression:
visitEachChild(static_cast<MemberExpression*>(N));
break;
case NodeKind::TupleExpression:
visitEachChild(static_cast<TupleExpression*>(N));
break;
case NodeKind::NestedExpression:
visitEachChild(static_cast<NestedExpression*>(N));
break;
case NodeKind::ConstantExpression:
visitEachChild(static_cast<ConstantExpression*>(N));
break;
case NodeKind::CallExpression:
visitEachChild(static_cast<CallExpression*>(N));
break;
case NodeKind::InfixExpression:
visitEachChild(static_cast<InfixExpression*>(N));
break;
case NodeKind::PrefixExpression:
visitEachChild(static_cast<PrefixExpression*>(N));
break;
case NodeKind::RecordExpressionField:
visitEachChild(static_cast<RecordExpressionField*>(N));
break;
case NodeKind::RecordExpression:
visitEachChild(static_cast<RecordExpression*>(N));
break;
case NodeKind::ExpressionStatement:
visitEachChild(static_cast<ExpressionStatement*>(N));
break;
case NodeKind::ReturnStatement:
visitEachChild(static_cast<ReturnStatement*>(N));
break;
case NodeKind::IfStatement:
visitEachChild(static_cast<IfStatement*>(N));
break;
case NodeKind::IfStatementPart:
visitEachChild(static_cast<IfStatementPart*>(N));
break;
case NodeKind::TypeAssert:
visitEachChild(static_cast<TypeAssert*>(N));
break;
case NodeKind::Parameter:
visitEachChild(static_cast<Parameter*>(N));
break;
case NodeKind::LetBlockBody:
visitEachChild(static_cast<LetBlockBody*>(N));
break;
case NodeKind::LetExprBody:
visitEachChild(static_cast<LetExprBody*>(N));
break;
case NodeKind::LetDeclaration:
visitEachChild(static_cast<LetDeclaration*>(N));
break;
case NodeKind::RecordDeclaration:
visitEachChild(static_cast<RecordDeclaration*>(N));
break;
case NodeKind::RecordDeclarationField:
visitEachChild(static_cast<RecordDeclarationField*>(N));
break;
case NodeKind::VariantDeclaration:
visitEachChild(static_cast<VariantDeclaration*>(N));
break;
case NodeKind::TupleVariantDeclarationMember:
visitEachChild(static_cast<TupleVariantDeclarationMember*>(N));
break;
case NodeKind::RecordVariantDeclarationMember:
visitEachChild(static_cast<RecordVariantDeclarationMember*>(N));
break;
case NodeKind::ClassDeclaration:
visitEachChild(static_cast<ClassDeclaration*>(N));
break;
case NodeKind::InstanceDeclaration:
visitEachChild(static_cast<InstanceDeclaration*>(N));
break;
case NodeKind::SourceFile:
visitEachChild(static_cast<SourceFile*>(N));
break;
default:
ZEN_UNREACHABLE
} }
} }
@ -839,6 +617,9 @@ namespace bolt {
void visitEachChild(LetKeyword* N) { void visitEachChild(LetKeyword* N) {
} }
void visitEachChild(FnKeyword* N) {
}
void visitEachChild(MutKeyword* N) { void visitEachChild(MutKeyword* N) {
} }
@ -1136,18 +917,29 @@ namespace bolt {
BOLT_VISIT(N->Expression); BOLT_VISIT(N->Expression);
} }
void visitEachChild(LetDeclaration* N) { void visitEachChild(FunctionDeclaration* N) {
if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword);
}
BOLT_VISIT(N->FnKeyword);
BOLT_VISIT(N->Name);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert);
}
if (N->Body) {
BOLT_VISIT(N->Body);
}
}
void visitEachChild(VariableDeclaration* N) {
if (N->PubKeyword) { if (N->PubKeyword) {
BOLT_VISIT(N->PubKeyword); BOLT_VISIT(N->PubKeyword);
} }
BOLT_VISIT(N->LetKeyword); BOLT_VISIT(N->LetKeyword);
if (N->MutKeyword) {
BOLT_VISIT(N->MutKeyword);
}
BOLT_VISIT(N->Pattern); BOLT_VISIT(N->Pattern);
for (auto Param: N->Params) {
BOLT_VISIT(Param);
}
if (N->TypeAssert) { if (N->TypeAssert) {
BOLT_VISIT(N->TypeAssert); BOLT_VISIT(N->TypeAssert);
} }

View file

@ -197,16 +197,16 @@ namespace bolt {
void addConstraint(Constraint* Constraint); void addConstraint(Constraint* Constraint);
void forwardDeclare(Node* Node); void forwardDeclare(Node* Node);
void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); void forwardDeclareFunctionDeclaration(FunctionDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression); Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE); Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true);
Type* inferLiteral(Literal* Lit); Type* inferLiteral(Literal* Lit);
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
void infer(Node* node); void infer(Node* node);
void inferLetDeclaration(LetDeclaration* N); void inferFunctionDeclaration(FunctionDeclaration* N);
Constraint* convertToConstraint(ConstraintExpression* C); Constraint* convertToConstraint(ConstraintExpression* C);

View file

@ -9,7 +9,7 @@ namespace bolt {
ConfigFlags_TypeVarsRequireForall = 1 << 0, ConfigFlags_TypeVarsRequireForall = 1 << 0,
}; };
unsigned Flags; unsigned Flags = 0;
public: public:

View file

@ -109,9 +109,9 @@ namespace bolt {
public: public:
TypeclassSignature Sig; TypeclassSignature Sig;
LetDeclaration* Decl; FunctionDeclaration* Decl;
inline TypeclassMissingDiagnostic(TypeclassSignature Sig, LetDeclaration* Decl): inline TypeclassMissingDiagnostic(TypeclassSignature Sig, FunctionDeclaration* Decl):
Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {} Diagnostic(DiagnosticKind::TypeclassMissing), Sig(Sig), Decl(Decl) {}
inline Node* getNode() const override { inline Node* getNode() const override {

View file

@ -124,7 +124,9 @@ namespace bolt {
Node* parseLetBodyElement(); Node* parseLetBodyElement();
LetDeclaration* parseLetDeclaration(); FunctionDeclaration* parseFunctionDeclaration();
VariableDeclaration* parseVariableDeclaration();
Node* parseClassElement(); Node* parseClassElement();

View file

@ -69,9 +69,9 @@ namespace bolt {
} }
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::FunctionDeclaration:
{ {
auto Decl = static_cast<LetDeclaration*>(X); auto Decl = static_cast<FunctionDeclaration*>(X);
for (auto Param: Decl->Params) { for (auto Param: Decl->Params) {
visitPattern(Param->Pattern, Param); visitPattern(Param->Pattern, Param);
} }
@ -112,12 +112,18 @@ namespace bolt {
} }
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::VariableDeclaration:
{ {
auto Decl = static_cast<LetDeclaration*>(X); auto Decl = static_cast<VariableDeclaration*>(X);
visitPattern(Decl->Pattern, Decl); visitPattern(Decl->Pattern, Decl);
break; break;
} }
case NodeKind::FunctionDeclaration:
{
auto Decl = static_cast<FunctionDeclaration*>(X);
addSymbol(Decl->Name->getCanonicalText(), Decl, SymbolKind::Var);
break;
}
case NodeKind::RecordDeclaration: case NodeKind::RecordDeclaration:
{ {
auto Decl = static_cast<RecordDeclaration*>(X); auto Decl = static_cast<RecordDeclaration*>(X);
@ -597,14 +603,14 @@ namespace bolt {
return Expression->getLastToken(); return Expression->getLastToken();
} }
Token* LetDeclaration::getFirstToken() const { Token* FunctionDeclaration::getFirstToken() const {
if (PubKeyword) { if (PubKeyword) {
return PubKeyword; return PubKeyword;
} }
return LetKeyword; return FnKeyword;
} }
Token* LetDeclaration::getLastToken() const { Token* FunctionDeclaration::getLastToken() const {
if (Body) { if (Body) {
return Body->getLastToken(); return Body->getLastToken();
} }
@ -614,6 +620,23 @@ namespace bolt {
if (Params.size()) { if (Params.size()) {
return Params.back()->getLastToken(); return Params.back()->getLastToken();
} }
return Name;
}
Token* VariableDeclaration::getFirstToken() const {
if (PubKeyword) {
return PubKeyword;
}
return LetKeyword;
}
Token* VariableDeclaration::getLastToken() const {
if (Body) {
return Body->getLastToken();
}
if (TypeAssert) {
return TypeAssert->getLastToken();
}
return Pattern->getLastToken(); return Pattern->getLastToken();
} }
@ -766,6 +789,10 @@ namespace bolt {
return "let"; return "let";
} }
std::string FnKeyword::getText() const {
return "fn";
}
std::string MutKeyword::getText() const { std::string MutKeyword::getText() const {
return "mut"; return "mut";
} }

View file

@ -12,6 +12,12 @@
// TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node // TODO see if we can merge UnificationError diagnostics so that we get a list of **all** types that were wrong on a given node
// TODO When a forall variable is missing, do not just insert a blank one into the env. It will result in too few diagnostics being emitted.
// Same goes for reference expressions.
// If running the compiler as a language server, this matters.
// TODO Add a pattern that only performs a type assert
#include <algorithm> #include <algorithm>
#include <iterator> #include <iterator>
#include <stack> #include <stack>
@ -296,10 +302,14 @@ namespace bolt {
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::FunctionDeclaration:
// These declarations will be handled separately in check() // These declarations will be handled separately in check()
break; break;
case NodeKind::VariableDeclaration:
// All of this node's semantics will be handled in infer()
break;
case NodeKind::VariantDeclaration: case NodeKind::VariantDeclaration:
{ {
auto Decl = static_cast<VariantDeclaration*>(X); auto Decl = static_cast<VariantDeclaration*>(X);
@ -376,8 +386,8 @@ namespace bolt {
for (auto TV: Vars) { for (auto TV: Vars) {
RetTy = new TApp(RetTy, TV); RetTy = new TApp(RetTy, TV);
} }
Decl->Ctx->Parent->Env.emplace(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
popContext(); popContext();
addBinding(Name, new Forall(Decl->Ctx->TVs, Decl->Ctx->Constraints, new TArrow({ FieldsTy }, RetTy)));
break; break;
} }
@ -423,18 +433,18 @@ namespace bolt {
Contexts.pop(); Contexts.pop();
} }
void visitLetDeclaration(LetDeclaration* Let) { void visitFunctionDeclaration(FunctionDeclaration* Let) {
if (Let->isFunc()) { Let->Ctx = createDerivedContext();
Let->Ctx = createDerivedContext(); Contexts.push(Let->Ctx);
Contexts.push(Let->Ctx); visitEachChild(Let);
visitEachChild(Let); Contexts.pop();
Contexts.pop();
} else {
Let->Ctx = Contexts.top();
visitEachChild(Let);
}
} }
// void visitVariableDeclaration(VariableDeclaration* Var) {
// Var->Ctx = Contexts.top();
// visitEachChild(Var);
// }
}; };
Init I { {}, *this }; Init I { {}, *this };
@ -442,9 +452,7 @@ namespace bolt {
} }
void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) { void Checker::forwardDeclareFunctionDeclaration(FunctionDeclaration* Let, TVSet* TVs, ConstraintSet* Constraints) {
auto Let = static_cast<LetDeclaration*>(N);
setContext(Let->Ctx); setContext(Let->Ctx);
@ -495,7 +503,7 @@ namespace bolt {
Params.push_back(TV); Params.push_back(TV);
} }
auto SigLet = llvm::cast<LetDeclaration>(Class->getScope()->lookupDirect({ {}, llvm::cast<BindPattern>(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var)); auto SigLet = llvm::cast<FunctionDeclaration>(Class->getScope()->lookupDirect({ {}, Let->Name->getCanonicalText() }, SymbolKind::Var));
// It would be very strange if there was no type assert in the type // It would be very strange if there was no type assert in the type
// class let-declaration but we rather not let the compiler crash if that happens. // class let-declaration but we rather not let the compiler crash if that happens.
@ -520,9 +528,7 @@ namespace bolt {
case NodeKind::LetBlockBody: case NodeKind::LetBlockBody:
{ {
auto Block = static_cast<LetBlockBody*>(Let->Body); auto Block = static_cast<LetBlockBody*>(Let->Body);
if (Let->isFunc()) { Let->Ctx->ReturnType = createTypeVar();
Let->Ctx->ReturnType = createTypeVar();
}
for (auto Element: Block->Elements) { for (auto Element: Block->Elements) {
forwardDeclare(Element); forwardDeclare(Element);
} }
@ -533,19 +539,11 @@ namespace bolt {
} }
} }
Let->Ctx->Parent->Env.emplace(Let->Name->getCanonicalText(), new Forall(Let->Ctx->TVs, Let->Ctx->Constraints, Ty));
Type* BindTy;
if (Let->isFunc()) {
popContext();
BindTy = inferPattern(Let->Pattern, Let->Ctx->Constraints, Let->Ctx->TVs);
} else {
BindTy = inferPattern(Let->Pattern);
}
addConstraint(new CEqual(BindTy, Ty, Let));
} }
void Checker::inferLetDeclaration(LetDeclaration* Decl) { void Checker::inferFunctionDeclaration(FunctionDeclaration* Decl) {
setContext(Decl->Ctx); setContext(Decl->Ctx);
@ -553,7 +551,6 @@ namespace bolt {
Type* RetType; Type* RetType;
for (auto Param: Decl->Params) { for (auto Param: Decl->Params) {
// TODO incorporate Param->TypeAssert or make it a kind of pattern
ParamTypes.push_back(inferPattern(Param->Pattern)); ParamTypes.push_back(inferPattern(Param->Pattern));
} }
@ -568,7 +565,6 @@ namespace bolt {
case NodeKind::LetBlockBody: case NodeKind::LetBlockBody:
{ {
auto Block = static_cast<LetBlockBody*>(Decl->Body); auto Block = static_cast<LetBlockBody*>(Decl->Body);
ZEN_ASSERT(Decl->isFunc());
RetType = Decl->Ctx->ReturnType; RetType = Decl->Ctx->ReturnType;
for (auto Element: Block->Elements) { for (auto Element: Block->Elements) {
infer(Element); infer(Element);
@ -582,13 +578,7 @@ namespace bolt {
RetType = createTypeVar(); RetType = createTypeVar();
} }
if (Decl->isFunc()) { addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
popContext();
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), Decl });
} else {
// Declaration is a plain (typed) variable
addConstraint(new CEqual { Decl->Ty, RetType, Decl });
}
} }
@ -642,7 +632,7 @@ namespace bolt {
break; break;
} }
case NodeKind::LetDeclaration: case NodeKind::FunctionDeclaration:
break; break;
case NodeKind::ReturnStatement: case NodeKind::ReturnStatement:
@ -658,6 +648,33 @@ namespace bolt {
break; break;
} }
case NodeKind::VariableDeclaration:
{
auto Decl = static_cast<VariableDeclaration*>(N);
Type* Ty = nullptr;
if (Decl->TypeAssert) {
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false);
}
if (Decl->Body) {
ZEN_ASSERT(Decl->Body->getKind() == NodeKind::LetExprBody);
auto E = static_cast<LetExprBody*>(Decl->Body);
auto Ty2 = inferExpression(E->Expression);
if (Ty) {
addConstraint(new CEqual(Ty, Ty2, Decl));
} else {
Ty = Ty2;
}
}
auto Ty3 = inferPattern(Decl->Pattern);
if (Ty) {
addConstraint(new CEqual(Ty, Ty3, Decl));
} else {
Ty = Ty3;
}
Decl->setType(Ty);
break;
}
case NodeKind::ExpressionStatement: case NodeKind::ExpressionStatement:
{ {
auto ExprStmt = static_cast<ExpressionStatement*>(N); auto ExprStmt = static_cast<ExpressionStatement*>(N);
@ -764,7 +781,7 @@ namespace bolt {
} }
} }
Type* Checker::inferTypeExpression(TypeExpression* N) { Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) {
switch (N->getKind()) { switch (N->getKind()) {
@ -786,9 +803,9 @@ namespace bolt {
case NodeKind::AppTypeExpression: case NodeKind::AppTypeExpression:
{ {
auto AppTE = static_cast<AppTypeExpression*>(N); auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op); Type* Ty = inferTypeExpression(AppTE->Op, IsPoly);
for (auto Arg: AppTE->Args) { for (auto Arg: AppTE->Args) {
Ty = new TApp(Ty, inferTypeExpression(Arg)); Ty = new TApp(Ty, inferTypeExpression(Arg, IsPoly));
} }
return Ty; return Ty;
} }
@ -798,10 +815,10 @@ namespace bolt {
auto VarTE = static_cast<VarTypeExpression*>(N); auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(VarTE->Name->getCanonicalText()); auto Ty = lookupMono(VarTE->Name->getCanonicalText());
if (Ty == nullptr) { if (Ty == nullptr) {
if (Config.typeVarsRequireForall()) { if (IsPoly && Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name); DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
} }
Ty = createRigidVar(VarTE->Name->getCanonicalText()); Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty)); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty));
} }
ZEN_ASSERT(Ty->getKind() == TypeKind::Var); ZEN_ASSERT(Ty->getKind() == TypeKind::Var);
@ -814,7 +831,7 @@ namespace bolt {
auto TupleTE = static_cast<TupleTypeExpression*>(N); auto TupleTE = static_cast<TupleTypeExpression*>(N);
std::vector<Type*> ElementTypes; std::vector<Type*> ElementTypes;
for (auto [TE, Comma]: TupleTE->Elements) { for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE)); ElementTypes.push_back(inferTypeExpression(TE, IsPoly));
} }
auto Ty = new TTuple(ElementTypes); auto Ty = new TTuple(ElementTypes);
N->setType(Ty); N->setType(Ty);
@ -824,7 +841,7 @@ namespace bolt {
case NodeKind::NestedTypeExpression: case NodeKind::NestedTypeExpression:
{ {
auto NestedTE = static_cast<NestedTypeExpression*>(N); auto NestedTE = static_cast<NestedTypeExpression*>(N);
auto Ty = inferTypeExpression(NestedTE->TE); auto Ty = inferTypeExpression(NestedTE->TE, IsPoly);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
@ -834,9 +851,9 @@ namespace bolt {
auto ArrowTE = static_cast<ArrowTypeExpression*>(N); auto ArrowTE = static_cast<ArrowTypeExpression*>(N);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto ParamType: ArrowTE->ParamTypes) { for (auto ParamType: ArrowTE->ParamTypes) {
ParamTypes.push_back(inferTypeExpression(ParamType)); ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
} }
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType); auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
auto Ty = new TArrow(ParamTypes, ReturnType); auto Ty = new TArrow(ParamTypes, ReturnType);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
@ -848,7 +865,7 @@ namespace bolt {
for (auto [C, Comma]: QTE->Constraints) { for (auto [C, Comma]: QTE->Constraints) {
addConstraint(convertToConstraint(C)); addConstraint(convertToConstraint(C));
} }
auto Ty = inferTypeExpression(QTE->TE); auto Ty = inferTypeExpression(QTE->TE, IsPoly);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
@ -889,12 +906,13 @@ namespace bolt {
} }
Ty = createTypeVar(); Ty = createTypeVar();
for (auto Case: Match->Cases) { for (auto Case: Match->Cases) {
auto OldCtx = &getContext();
setContext(Case->Ctx); setContext(Case->Ctx);
auto PattTy = inferPattern(Case->Pattern); auto PattTy = inferPattern(Case->Pattern);
addConstraint(new CEqual(PattTy, ValTy, Case)); addConstraint(new CEqual(PattTy, ValTy, Case));
auto ExprTy = inferExpression(Case->Expression); auto ExprTy = inferExpression(Case->Expression);
addConstraint(new CEqual(ExprTy, Ty, Case->Expression)); addConstraint(new CEqual(ExprTy, Ty, Case->Expression));
popContext(); setContext(OldCtx);
} }
if (!Match->Value) { if (!Match->Value) {
Ty = new TArrow({ ValTy }, Ty); Ty = new TArrow({ ValTy }, Ty);
@ -925,8 +943,8 @@ namespace bolt {
auto Ref = static_cast<ReferenceExpression*>(X); auto Ref = static_cast<ReferenceExpression*>(X);
ZEN_ASSERT(Ref->ModulePath.empty()); ZEN_ASSERT(Ref->ModulePath.empty());
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath()); auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
if (Target && llvm::isa<LetDeclaration>(Target)) { if (Target && llvm::isa<FunctionDeclaration>(Target)) {
auto Let = static_cast<LetDeclaration*>(Target); auto Let = static_cast<FunctionDeclaration*>(Target);
if (Let->IsCycleActive) { if (Let->IsCycleActive) {
return Let->Ty; return Let->Ty;
} }
@ -1100,7 +1118,7 @@ namespace bolt {
std::stack<Node*> Stack; std::stack<Node*> Stack;
void visitLetDeclaration(LetDeclaration* N) { void visitFunctionDeclaration(FunctionDeclaration* N) {
RefGraph.addVertex(N); RefGraph.addVertex(N);
Stack.push(N); Stack.push(N);
visitEachChild(N); visitEachChild(N);
@ -1121,7 +1139,7 @@ namespace bolt {
RefGraph.addEdge(Stack.top(), Def->Parent); RefGraph.addEdge(Stack.top(), Def->Parent);
return; return;
} }
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration); ZEN_ASSERT(Def->getKind() == NodeKind::FunctionDeclaration || Def->getKind() == NodeKind::VariableDeclaration);
if (!Stack.empty()) { if (!Stack.empty()) {
RefGraph.addEdge(Def, Stack.top()); RefGraph.addEdge(Def, Stack.top());
} }
@ -1140,7 +1158,7 @@ namespace bolt {
Checker& C; Checker& C;
void visitLetDeclaration(LetDeclaration* Decl) { void visitLetDeclaration(FunctionDeclaration* Decl) {
// Only inspect those let-declarations that look like a function // Only inspect those let-declarations that look like a function
if (Decl->Params.empty()) { if (Decl->Params.empty()) {
@ -1289,26 +1307,26 @@ namespace bolt {
auto TVs = new TVSet; auto TVs = new TVSet;
auto Constraints = new ConstraintSet; auto Constraints = new ConstraintSet;
for (auto N: Nodes) { for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<FunctionDeclaration*>(N);
forwardDeclareLetDeclaration(Decl, TVs, Constraints); forwardDeclareFunctionDeclaration(Decl, TVs, Constraints);
} }
} }
for (auto Nodes: SCCs) { for (auto Nodes: SCCs) {
for (auto N: Nodes) { for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<FunctionDeclaration*>(N);
Decl->IsCycleActive = true; Decl->IsCycleActive = true;
} }
for (auto N: Nodes) { for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<FunctionDeclaration*>(N);
inferLetDeclaration(Decl); inferFunctionDeclaration(Decl);
} }
for (auto N: Nodes) { for (auto N: Nodes) {
auto Decl = static_cast<LetDeclaration*>(N); auto Decl = static_cast<FunctionDeclaration*>(N);
Decl->IsCycleActive = false; Decl->IsCycleActive = false;
} }
} }
setContext(SF->Ctx);
infer(SF); infer(SF);
popContext();
solve(new CMany(*SF->Ctx->Constraints)); solve(new CMany(*SF->Ctx->Constraints));
checkTypeclassSigs(SF); checkTypeclassSigs(SF);
} }

View file

@ -100,6 +100,8 @@ namespace bolt {
return "'pub'"; return "'pub'";
case NodeKind::LetKeyword: case NodeKind::LetKeyword:
return "'let'"; return "'let'";
case NodeKind::FnKeyword:
return "'fn'";
case NodeKind::MutKeyword: case NodeKind::MutKeyword:
return "'mut'"; return "'mut'";
case NodeKind::MatchKeyword: case NodeKind::MatchKeyword:
@ -108,8 +110,10 @@ namespace bolt {
return "'return'"; return "'return'";
case NodeKind::TypeKeyword: case NodeKind::TypeKeyword:
return "'type'"; return "'type'";
case NodeKind::LetDeclaration: case NodeKind::FunctionDeclaration:
return "a let-declaration"; return "a function declaration";
case NodeKind::VariableDeclaration:
return "a variable declaration";
case NodeKind::CallExpression: case NodeKind::CallExpression:
return "a call-expression"; return "a call-expression";
case NodeKind::InfixExpression: case NodeKind::InfixExpression:

View file

@ -853,7 +853,7 @@ finish:
return new IfStatement(Parts); return new IfStatement(Parts);
} }
LetDeclaration* Parser::parseLetDeclaration() { VariableDeclaration* Parser::parseVariableDeclaration() {
PubKeyword* Pub = nullptr; PubKeyword* Pub = nullptr;
LetKeyword* Let; LetKeyword* Let;
@ -881,8 +881,8 @@ finish:
Tokens.get(); Tokens.get();
} }
auto Patt = parsePattern(); auto P = parsePattern();
if (!Patt) { if (!P) {
if (Pub) { if (Pub) {
Pub->unref(); Pub->unref();
} }
@ -894,27 +894,7 @@ finish:
return nullptr; return nullptr;
} }
std::vector<Parameter*> Params; auto T2 = Tokens.peek();
Token* T2;
for (;;) {
T2 = Tokens.peek();
switch (T2->getKind()) {
case NodeKind::LineFoldEnd:
case NodeKind::BlockStart:
case NodeKind::Equals:
case NodeKind::Colon:
goto after_params;
default:
auto P = parsePattern();
if (P == nullptr) {
P = new BindPattern(new Identifier("_"));
}
Params.push_back(new Parameter(P, nullptr));
}
}
after_params:
if (T2->getKind() == NodeKind::Colon) { if (T2->getKind() == NodeKind::Colon) {
Tokens.get(); Tokens.get();
auto TE = parseTypeExpression(); auto TE = parseTypeExpression();
@ -972,16 +952,137 @@ after_params:
DE.add<UnexpectedTokenDiagnostic>(File, T2, Expected); DE.add<UnexpectedTokenDiagnostic>(File, T2, Expected);
} }
after_body: checkLineFoldEnd();
finish:
return new VariableDeclaration { Pub, Let, Mut, P, TA, Body };
}
FunctionDeclaration* Parser::parseFunctionDeclaration() {
PubKeyword* Pub = nullptr;
FnKeyword* Fn;
MutKeyword* Mut = nullptr;
TypeAssert* TA = nullptr;
LetBody* Body = nullptr;
auto T0 = Tokens.get();
if (T0->getKind() == NodeKind::PubKeyword) {
Pub = static_cast<PubKeyword*>(T0);
T0 = Tokens.get();
}
if (T0->getKind() != NodeKind::FnKeyword) {
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector { NodeKind::FnKeyword });
if (Pub) {
Pub->unref();
}
skipToLineFoldEnd();
return nullptr;
}
Fn = static_cast<FnKeyword*>(T0);
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::MutKeyword) {
Mut = static_cast<MutKeyword*>(T1);
Tokens.get();
}
auto Name = expectToken<Identifier>();
if (!Name) {
if (Pub) {
Pub->unref();
}
Fn->unref();
if (Mut) {
Mut->unref();
}
skipToLineFoldEnd();
return nullptr;
}
std::vector<Parameter*> Params;
Token* T2;
for (;;) {
T2 = Tokens.peek();
switch (T2->getKind()) {
case NodeKind::LineFoldEnd:
case NodeKind::BlockStart:
case NodeKind::Equals:
case NodeKind::Colon:
goto after_params;
default:
auto P = parsePattern();
if (!P) {
P = new BindPattern(new Identifier("_"));
}
Params.push_back(new Parameter(P, nullptr));
}
}
after_params:
if (T2->getKind() == NodeKind::Colon) {
Tokens.get();
auto TE = parseTypeExpression();
if (TE) {
TA = new TypeAssert(static_cast<Colon*>(T2), TE);
} else {
skipToLineFoldEnd();
goto finish;
}
T2 = Tokens.peek();
}
switch (T2->getKind()) {
case NodeKind::BlockStart:
{
Tokens.get();
std::vector<Node*> Elements;
for (;;) {
auto T3 = Tokens.peek();
if (T3->getKind() == NodeKind::BlockEnd) {
break;
}
auto Element = parseLetBodyElement();
if (Element) {
Elements.push_back(Element);
}
}
Tokens.get()->unref(); // Always a BlockEnd
Body = new LetBlockBody(static_cast<BlockStart*>(T2), Elements);
break;
}
case NodeKind::Equals:
{
Tokens.get();
auto E = parseExpression();
if (!E) {
skipToLineFoldEnd();
goto finish;
}
Body = new LetExprBody(static_cast<Equals*>(T2), E);
break;
}
case NodeKind::LineFoldEnd:
break;
default:
std::vector<NodeKind> Expected { NodeKind::BlockStart, NodeKind::LineFoldEnd, NodeKind::Equals };
if (TA == nullptr) {
// First tokens of TypeAssert
Expected.push_back(NodeKind::Colon);
// First tokens of Pattern
Expected.push_back(NodeKind::Identifier);
}
DE.add<UnexpectedTokenDiagnostic>(File, T2, Expected);
}
checkLineFoldEnd(); checkLineFoldEnd();
finish: finish:
return new LetDeclaration( return new FunctionDeclaration(
Pub, Pub,
Let, Fn,
Mut, Name,
Patt,
Params, Params,
TA, TA,
Body Body
@ -992,7 +1093,9 @@ finish:
auto T0 = peekFirstTokenAfterModifiers(); auto T0 = peekFirstTokenAfterModifiers();
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::LetKeyword: case NodeKind::LetKeyword:
return parseLetDeclaration(); return parseVariableDeclaration();
case NodeKind::FnKeyword:
return parseFunctionDeclaration();
case NodeKind::ReturnKeyword: case NodeKind::ReturnKeyword:
return parseReturnStatement(); return parseReturnStatement();
case NodeKind::IfKeyword: case NodeKind::IfKeyword:
@ -1396,12 +1499,12 @@ next_member:
Node* Parser::parseClassElement() { Node* Parser::parseClassElement() {
auto T0 = Tokens.peek(); auto T0 = Tokens.peek();
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::LetKeyword: case NodeKind::FnKeyword:
return parseLetDeclaration(); return parseFunctionDeclaration();
case NodeKind::TypeKeyword: case NodeKind::TypeKeyword:
// TODO // TODO
default: default:
DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector<NodeKind> { NodeKind::LetKeyword, NodeKind::TypeKeyword }); DE.add<UnexpectedTokenDiagnostic>(File, T0, std::vector<NodeKind> { NodeKind::FnKeyword, NodeKind::TypeKeyword });
skipToLineFoldEnd(); skipToLineFoldEnd();
return nullptr; return nullptr;
} }
@ -1411,7 +1514,9 @@ next_member:
auto T0 = peekFirstTokenAfterModifiers(); auto T0 = peekFirstTokenAfterModifiers();
switch (T0->getKind()) { switch (T0->getKind()) {
case NodeKind::LetKeyword: case NodeKind::LetKeyword:
return parseLetDeclaration(); return parseVariableDeclaration();
case NodeKind::FnKeyword:
return parseFunctionDeclaration();
case NodeKind::IfKeyword: case NodeKind::IfKeyword:
return parseIfStatement(); return parseIfStatement();
case NodeKind::ClassKeyword: case NodeKind::ClassKeyword:

View file

@ -62,6 +62,7 @@ namespace bolt {
std::unordered_map<ByteString, NodeKind> Keywords = { std::unordered_map<ByteString, NodeKind> Keywords = {
{ "pub", NodeKind::PubKeyword }, { "pub", NodeKind::PubKeyword },
{ "let", NodeKind::LetKeyword }, { "let", NodeKind::LetKeyword },
{ "fn", NodeKind::FnKeyword },
{ "mut", NodeKind::MutKeyword }, { "mut", NodeKind::MutKeyword },
{ "return", NodeKind::ReturnKeyword }, { "return", NodeKind::ReturnKeyword },
{ "type", NodeKind::TypeKeyword }, { "type", NodeKind::TypeKeyword },
@ -226,6 +227,8 @@ digit_finish:
return new PubKeyword(StartLoc); return new PubKeyword(StartLoc);
case NodeKind::LetKeyword: case NodeKind::LetKeyword:
return new LetKeyword(StartLoc); return new LetKeyword(StartLoc);
case NodeKind::FnKeyword:
return new FnKeyword(StartLoc);
case NodeKind::MutKeyword: case NodeKind::MutKeyword:
return new MutKeyword(StartLoc); return new MutKeyword(StartLoc);
case NodeKind::TypeKeyword: case NodeKind::TypeKeyword: